diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 4d761c95..0d2380c9 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,19 +1,19 @@ # Frequenz channels Release Notes -## Upgrading +## New Features -- `FileWatcher`: The file polling mechanism is now forced by default. This provides reliable and consistent file monitoring on network file systems (e.g., CIFS). However, it may have a performance impact on local file systems or when monitoring a large number of files. - - To disable file polling, set the `force_polling` parameter to `False`. - - The `polling_interval` parameter defines the interval for polling changes. This is relevant only when polling is enabled and defaults to 1 second. +- There is a new `Receiver.triggered` method that can be used instead of `selected_from`: -## New Features + ```python + async for selected in select(recv1, recv2): + if recv1.triggered(selected): + print('Received from recv1:', selected.message) + if recv2.triggered(selected): + print('Received from recv2:', selected.message) + ``` -- `Timer.reset()` now supports setting the interval and will restart the timer with the new interval. +- `Receiver.filter()` can now properly handle `TypeGuard`s. The resulting receiver will now have the narrowed type when a `TypeGuard` is used. ## Bug Fixes -- `FileWatcher`: - - Fixed `ready()` method to return False when an error occurs. Before this fix, `select()` (and other code using `ready()`) never detected the `FileWatcher` was stopped and the `select()` loop was continuously waking up to inform the receiver was ready. - - Reports file events correctly on network file systems like CIFS. - -- `Timer.stop()` and `Timer.reset()` now immediately stop the timer if it is running. Before this fix, the timer would continue to run until the next interval. +- Fixed a memory leak in the timer. 
diff --git a/benchmarks/benchmark_broadcast.py b/benchmarks/benchmark_broadcast.py index 08181dfc..218fbe4a 100644 --- a/benchmarks/benchmark_broadcast.py +++ b/benchmarks/benchmark_broadcast.py @@ -141,7 +141,7 @@ def time_async_task(task: Coroutine[Any, Any, int]) -> tuple[float, Any]: return timeit.default_timer() - start, ret -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments,too-many-positional-arguments def run_one( benchmark_method: Callable[[int, int, int], Coroutine[Any, Any, int]], num_channels: int, diff --git a/pyproject.toml b/pyproject.toml index 12072ce0..523b3a21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dev-flake8 = [ "flake8 == 7.1.1", "flake8-docstrings == 1.7.0", "flake8-pyproject == 1.2.3", # For reading the flake8 config from pyproject.toml - "pydoclint == 0.5.6", + "pydoclint == 0.5.9", "pydocstyle == 6.3.0", ] dev-formatting = ["black == 24.8.0", "isort == 5.13.2"] @@ -52,12 +52,12 @@ dev-mkdocs = [ "markdown-svgbob == 202406.1023", "mike == 2.1.3", "mkdocs-gen-files == 0.5.0", - "mkdocs-include-markdown-plugin == 6.2.2", + "mkdocs-include-markdown-plugin == 7.0.0", "mkdocs-literate-nav == 0.6.1", - "mkdocs-macros-plugin == 1.0.5", - "mkdocs-material == 9.5.34", - "mkdocstrings[python] == 0.26.0", - "mkdocstrings-python == 1.10.9", + "mkdocs-macros-plugin == 1.2.0", + "mkdocs-material == 9.5.39", + "mkdocstrings[python] == 0.26.1", + "mkdocstrings-python == 1.11.1", "pymdownx-superfence-filter-lines == 0.1.0", ] dev-mypy = [ @@ -70,13 +70,13 @@ dev-noxfile = ["nox == 2024.4.15", "frequenz-repo-config[lib] == 0.10.0"] dev-pylint = [ # For checking the noxfile, docs/ script, and tests "frequenz-channels[dev-mkdocs,dev-noxfile,dev-pytest]", - "pylint == 3.2.7", + "pylint == 3.3.1", ] dev-pytest = [ "async-solipsism == 0.7", "frequenz-repo-config[extra-lint-examples] == 0.10.0", - "hypothesis == 6.111.2", - "pytest == 8.3.2", + "hypothesis == 6.112.2", + "pytest == 8.3.3", "pytest-asyncio == 
0.24.0", "pytest-mock == 3.14.0", ] diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index 2ea05156..d08efc4e 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -425,7 +425,6 @@ def consume(self) -> _T: Raises: ReceiverStoppedError: If the receiver stopped producing messages. - ReceiverError: If there is some problem with the receiver. """ if ( # pylint: disable=protected-access self._next is _Empty and self._channel._closed diff --git a/src/frequenz/channels/_merge.py b/src/frequenz/channels/_merge.py index 6be43025..3a306dbb 100644 --- a/src/frequenz/channels/_merge.py +++ b/src/frequenz/channels/_merge.py @@ -179,7 +179,6 @@ def consume(self) -> ReceiverMessageT_co: Raises: ReceiverStoppedError: If the receiver stopped producing messages. - ReceiverError: If there is some problem with the receiver. """ if not self._results and not self._pending: raise ReceiverStoppedError(self) diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index c01c8102..53862a45 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -155,16 +155,24 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Generic, Self +from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload from ._exceptions import Error from ._generic import MappedMessageT_co, ReceiverMessageT_co +if TYPE_CHECKING: + from ._select import Selected + +FilteredMessageT_co = TypeVar("FilteredMessageT_co", covariant=True) +"""Type variable for the filtered message type.""" + class Receiver(ABC, Generic[ReceiverMessageT_co]): """An endpoint to receive messages.""" - async def __anext__(self) -> ReceiverMessageT_co: + # We need the noqa here because ReceiverError can be raised by ready() and consume() + # implementations. 
+ async def __anext__(self) -> ReceiverMessageT_co: # noqa: DOC503 """Await the next message in the async iteration over received messages. Returns: @@ -215,7 +223,9 @@ def __aiter__(self) -> Self: """ return self - async def receive(self) -> ReceiverMessageT_co: + # We need the noqa here because ReceiverError can be raised by consume() + # implementations. + async def receive(self) -> ReceiverMessageT_co: # noqa: DOC503 """Receive a message. Returns: @@ -226,19 +236,18 @@ async def receive(self) -> ReceiverMessageT_co: ReceiverError: If there is some problem with the receiver. """ try: - received = await self.__anext__() # pylint: disable=unnecessary-dunder-call + received = await anext(self) except StopAsyncIteration as exc: # If we already had a cause and it was the receiver was stopped, # then reuse that error, as StopAsyncIteration is just an artifact # introduced by __anext__. if ( isinstance(exc.__cause__, ReceiverStoppedError) - # pylint is not smart enough to figure out we checked above - # this is a ReceiverStoppedError and thus it does have - # a receiver member - and exc.__cause__.receiver is self # pylint: disable=no-member + and exc.__cause__.receiver is self ): - raise exc.__cause__ + # This is a false positive, we are actually checking __cause__ is a + # ReceiverStoppedError which is an exception. + raise exc.__cause__ # pylint: disable=raising-non-exception raise ReceiverStoppedError(self) from exc return received @@ -261,11 +270,66 @@ def map( """ return _Mapper(receiver=self, mapping_function=mapping_function) + @overload + def filter( + self, + filter_function: Callable[ + [ReceiverMessageT_co], TypeGuard[FilteredMessageT_co] + ], + /, + ) -> Receiver[FilteredMessageT_co]: + """Apply a type guard on the messages on a receiver. + + Tip: + The returned receiver type won't have all the methods of the original + receiver. 
If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + filter_function: The function to be applied on incoming messages to + determine if they should be received. + + Returns: + A new receiver that only receives messages that pass the filter. + """ + ... # pylint: disable=unnecessary-ellipsis + + @overload def filter( self, filter_function: Callable[[ReceiverMessageT_co], bool], / ) -> Receiver[ReceiverMessageT_co]: """Apply a filter function on the messages on a receiver. + Tip: + The returned receiver type won't have all the methods of the original + receiver. If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + filter_function: The function to be applied on incoming messages to + determine if they should be received. + + Returns: + A new receiver that only receives messages that pass the filter. + """ + ... # pylint: disable=unnecessary-ellipsis + + def filter( + self, + filter_function: ( + Callable[[ReceiverMessageT_co], bool] + | Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]] + ), + /, + ) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]: + """Apply a filter function on the messages on a receiver. + + Note: + You can pass a [type guard][typing.TypeGuard] as the filter function to + narrow the type of the messages that pass the filter. + Tip: The returned receiver type won't have all the methods of the original receiver. 
If you need to access methods of the original receiver that are @@ -281,6 +345,30 @@ def filter( """ return _Filter(receiver=self, filter_function=filter_function) + def triggered( + self, selected: Selected[Any] + ) -> TypeGuard[Selected[ReceiverMessageT_co]]: + """Check whether this receiver was selected by [`select()`][frequenz.channels.select]. + + This method is used in conjunction with the + [`Selected`][frequenz.channels.Selected] class to determine which receiver was + selected in `select()` iteration. + + It also works as a [type guard][typing.TypeGuard] to narrow the type of the + `Selected` instance to the type of the receiver. + + Please see [`select()`][frequenz.channels.select] for an example. + + Args: + selected: The result of a `select()` iteration. + + Returns: + Whether this receiver was selected. + """ + if handled := selected._recv is self: # pylint: disable=protected-access + selected._handled = True # pylint: disable=protected-access + return handled + class ReceiverError(Error, Generic[ReceiverMessageT_co]): """An error that originated in a [Receiver][frequenz.channels.Receiver]. @@ -370,9 +458,7 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502 ReceiverStoppedError: If the receiver stopped producing messages. ReceiverError: If there is a problem with the receiver. """ - return self._mapping_function( - self._receiver.consume() - ) # pylint: disable=protected-access + return self._mapping_function(self._receiver.consume()) def __str__(self) -> str: """Return a string representation of the mapper.""" @@ -450,7 +536,6 @@ def consume(self) -> ReceiverMessageT_co: Raises: ReceiverStoppedError: If the receiver stopped producing messages. - ReceiverError: If there is a problem with the receiver. 
""" if self._recv_closed: raise ReceiverStoppedError(self) diff --git a/src/frequenz/channels/_select.py b/src/frequenz/channels/_select.py index e8b6e803..ccd669eb 100644 --- a/src/frequenz/channels/_select.py +++ b/src/frequenz/channels/_select.py @@ -196,8 +196,10 @@ def __init__(self, receiver: Receiver[ReceiverMessageT_co], /) -> None: self._handled: bool = False """Flag to indicate if this selected has been handled in the if-chain.""" + # We need the noqa here because pydoclint can't figure out raise self._exception + # actually raises an Exception. @property - def message(self) -> ReceiverMessageT_co: + def message(self) -> ReceiverMessageT_co: # noqa: DOC503 """The message that was received, if any. Returns: @@ -267,9 +269,7 @@ def selected_from( Returns: Whether the given receiver was selected. """ - if handled := selected._recv is receiver: # pylint: disable=protected-access - selected._handled = True # pylint: disable=protected-access - return handled + return receiver.triggered(selected) class SelectError(Error): @@ -339,7 +339,11 @@ def __init__(self, selected: Selected[ReceiverMessageT_co]) -> None: # https://github.com/python/mypy/issues/13597 -async def select(*receivers: Receiver[Any]) -> AsyncIterator[Selected[Any]]: +# We need the noqa here because BaseExceptionGroup can be raised indirectly by +# _stop_pending_tasks. +async def select( # noqa: DOC503 + *receivers: Receiver[Any], +) -> AsyncIterator[Selected[Any]]: """Iterate over the messages of all receivers as they receive new messages. 
This function is used to iterate over the messages of all receivers as they receive @@ -372,14 +376,14 @@ async def select(*receivers: Receiver[Any]) -> AsyncIterator[Selected[Any]]: import datetime from typing import assert_never - from frequenz.channels import ReceiverStoppedError, select, selected_from + from frequenz.channels import ReceiverStoppedError, select from frequenz.channels.timer import SkipMissedAndDrift, Timer, TriggerAllMissed timer1 = Timer(datetime.timedelta(seconds=1), TriggerAllMissed()) timer2 = Timer(datetime.timedelta(seconds=0.5), SkipMissedAndDrift()) async for selected in select(timer1, timer2): - if selected_from(selected, timer1): + if timer1.triggered(selected): # Beware: `selected.message` might raise an exception, you can always # check for exceptions with `selected.exception` first or use # a try-except block. You can also quickly check if the receiver was @@ -389,7 +393,7 @@ async def select(*receivers: Receiver[Any]) -> AsyncIterator[Selected[Any]]: continue print(f"timer1: now={datetime.datetime.now()} drift={selected.message}") timer2.stop() - elif selected_from(selected, timer2): + elif timer2.triggered(selected): # Explicitly handling of exceptions match selected.exception: case ReceiverStoppedError(): diff --git a/src/frequenz/channels/event.py b/src/frequenz/channels/event.py index 3f3b207a..0d599e89 100644 --- a/src/frequenz/channels/event.py +++ b/src/frequenz/channels/event.py @@ -16,10 +16,10 @@ import asyncio as _asyncio -from frequenz.channels import _receiver +from frequenz.channels._receiver import Receiver, ReceiverStoppedError -class Event(_receiver.Receiver[None]): +class Event(Receiver[None]): """A receiver that can be made ready directly. # Usage @@ -161,7 +161,7 @@ def consume(self) -> None: ReceiverStoppedError: If this receiver is stopped. 
""" if not self._is_set and self._is_stopped: - raise _receiver.ReceiverStoppedError(self) + raise ReceiverStoppedError(self) assert self._is_set, "calls to `consume()` must be follow a call to `ready()`" diff --git a/src/frequenz/channels/timer.py b/src/frequenz/channels/timer.py index b2f0338a..2785feea 100644 --- a/src/frequenz/channels/timer.py +++ b/src/frequenz/channels/timer.py @@ -466,7 +466,8 @@ class Timer(Receiver[timedelta]): depending on the chosen policy. """ - def __init__( # pylint: disable=too-many-arguments + # We need the noqa here because RuntimeError is raised indirectly. + def __init__( # noqa: DOC503 pylint: disable=too-many-arguments self, interval: timedelta, missed_tick_policy: MissedTickPolicy, @@ -586,7 +587,8 @@ def is_running(self) -> bool: """Whether the timer is running.""" return not self._stopped - def reset( + # We need the noqa here because RuntimeError is raised indirectly. + def reset( # noqa: DOC503 self, *, interval: timedelta | None = None, @@ -682,14 +684,20 @@ async def ready(self) -> bool: # noqa: DOC502 # could be reset while we are sleeping, in which case we need to recalculate # the time to the next tick and try again. 
while time_to_next_tick > 0: - await next( - asyncio.as_completed( - [ - asyncio.sleep(time_to_next_tick / 1_000_000), - self._reset_event.wait(), - ] - ) + _, pending = await asyncio.wait( + [ + asyncio.create_task(asyncio.sleep(time_to_next_tick / 1_000_000)), + asyncio.create_task(self._reset_event.wait()), + ], + return_when=asyncio.FIRST_COMPLETED, ) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self._reset_event.clear() now = self._now() time_to_next_tick = self._next_tick_time - now diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index f480d194..c8a2e9cf 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -6,6 +6,7 @@ import asyncio from dataclasses import dataclass +from typing import TypeGuard, assert_never import pytest @@ -248,6 +249,31 @@ async def test_broadcast_filter() -> None: assert (await receiver.receive()) == 15 +async def test_broadcast_filter_type_guard() -> None: + """Ensure filter type guard works.""" + chan = Broadcast[int | str](name="input-chan") + sender = chan.new_sender() + + def _is_int(num: int | str) -> TypeGuard[int]: + return isinstance(num, int) + + # filter out objects that are not integers. 
+ receiver = chan.new_receiver().filter(_is_int) + + await sender.send("hello") + await sender.send(8) + + message = await receiver.receive() + assert message == 8 + is_int = False + match message: + case int(): + is_int = True + case unexpected: + assert_never(unexpected) + assert is_int + + async def test_broadcast_receiver_drop() -> None: """Ensure deleted receivers get cleaned up.""" chan = Broadcast[int](name="input-chan") diff --git a/tests/test_file_watcher_integration.py b/tests/test_file_watcher_integration.py index 85c7f655..ef1846e6 100644 --- a/tests/test_file_watcher_integration.py +++ b/tests/test_file_watcher_integration.py @@ -10,7 +10,7 @@ import pytest from frequenz.channels import ReceiverStoppedError, select, selected_from -from frequenz.channels.file_watcher import Event, EventType, FileWatcher +from frequenz.channels.file_watcher import EventType, FileWatcher from frequenz.channels.timer import SkipMissedAndDrift, Timer @@ -26,7 +26,9 @@ async def test_file_watcher(tmp_path: pathlib.Path) -> None: number_of_writes = 0 expected_number_of_writes = 3 - file_watcher = FileWatcher(paths=[str(tmp_path)]) + file_watcher = FileWatcher( + paths=[str(tmp_path)], polling_interval=timedelta(seconds=0.05) + ) timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift()) async for selected in select(file_watcher, timer): @@ -34,11 +36,18 @@ async def test_file_watcher(tmp_path: pathlib.Path) -> None: filename.write_text(f"{selected.message}") elif selected_from(selected, file_watcher): event_type = EventType.CREATE if number_of_writes == 0 else EventType.MODIFY - assert selected.message == Event(type=event_type, path=filename) - number_of_writes += 1 - # After receiving a write 3 times, unsubscribe from the writes channel - if number_of_writes == expected_number_of_writes: - break + event = selected.message + # If we receive updates for the directory itself, they should be only + # modifications, we only check that because we can have ordering issues if + # 
we also try to check the order compared to events in the file. + if event.path == tmp_path: + assert event.type == EventType.MODIFY + elif event.path == filename: + assert event.type == event_type + number_of_writes += 1 + # After receiving a write 3 times, unsubscribe from the writes channel + if number_of_writes == expected_number_of_writes: + break assert number_of_writes == expected_number_of_writes @@ -58,6 +67,7 @@ async def test_file_watcher_deletes(tmp_path: pathlib.Path) -> None: paths=[str(tmp_path)], event_types={EventType.DELETE}, force_polling=False, + polling_interval=timedelta(seconds=0.05), ) write_timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift()) deletion_timer = Timer(timedelta(seconds=0.25), SkipMissedAndDrift())