diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index 299c8938..93115ab7 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -1,7 +1,17 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Any, Dict, List, Type +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Generator, + List, + Optional, + Type, +) from pydantic import ValidationError @@ -21,16 +31,12 @@ class EventObserver: Base events observer """ - def __init__(self): + def __init__(self) -> None: self.handlers: List[HandlerObject] = [] - def register(self, callback: HandlerType, *filters: FilterType): + def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType: """ Register callback with filters - - :param callback: - :param filters: - :return: """ self.handlers.append( HandlerObject( @@ -39,7 +45,11 @@ class EventObserver: ) return callback - async def trigger(self, *args, **kwargs): + async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: + """ + Propagate event to handlers. + Handler will be called when all its filters is pass. + """ for handler in self.handlers: kwargs_copy = copy.copy(kwargs) result, data = await handler.check(*args, **kwargs) @@ -50,8 +60,12 @@ class EventObserver: except SkipHandler: continue - def __call__(self, *args: FilterType): - def wrapper(callback: CallbackType): + def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]: + """ + Decorator for registering event handlers + """ + + def wrapper(callback: CallbackType) -> CallbackType: self.register(callback, *args) return callback @@ -63,7 +77,7 @@ class TelegramEventObserver(EventObserver): Event observer for Telegram events """ - def __init__(self, router: Router, event_name: str): + def __init__(self, router: Router, event_name: str) -> None: super().__init__() self.router: Router = router @@ -71,16 +85,25 @@ class TelegramEventObserver(EventObserver): self.filters: List[Type[BaseFilter]] = [] def bind_filter(self, bound_filter: Type[BaseFilter]) -> None: + """ + Register filter class in factory + + :param bound_filter: + """ if not issubclass(bound_filter, BaseFilter): raise TypeError( "bound_filter() argument 'bound_filter' must be subclass of BaseFilter" ) self.filters.append(bound_filter) - def _resolve_filters_chain(self): - registry: List[FilterType] = [] + def _resolve_filters_chain(self) -> Generator[Type[BaseFilter], None, None]: + """ + Get all bounded filters from current observer and from the parents + with the same event type without duplicates + """ + registry: List[Type[BaseFilter]] = [] - router = self.router + router: Optional[Router] = self.router while router: observer = router.observers[self.event_name] router = router.parent_router @@ -91,16 +114,10 @@ class TelegramEventObserver(EventObserver): yield filter_ registry.append(filter_) - def register(self, callback: HandlerType, *filters: FilterType, **bound_filters: Any): - resolved_filters = self.resolve_filters(bound_filters) - return super().register(callback, *filters, *resolved_filters) - - async def trigger(self, *args, **kwargs): - async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs): - yield result - break - def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: + """ + Resolve keyword filters via filters factory + """ filters: List[BaseFilter] = [] if not full_config: return filters @@ -112,19 +129,44 @@ class TelegramEventObserver(EventObserver): except ValidationError: continue - # Clean full config to prevent to re-initialize another filter with the same configuration + # Clean full config to prevent to re-initialize another filter + # with the same configuration for key in f.__fields__: full_config.pop(key, None) filters.append(f) if full_config: - raise ValueError(f"Unknown filters: {set(full_config.keys())}") + raise ValueError(f"Unknown keyword filters: {set(full_config.keys())}") return filters - def __call__(self, *args: FilterType, **bound_filters): - def wrapper(callback: CallbackType): + def register( + self, callback: HandlerType, *filters: FilterType, **bound_filters: Any + ) -> HandlerType: + """ + Register event handler + """ + resolved_filters = self.resolve_filters(bound_filters) + return super().register(callback, *filters, *resolved_filters) + + async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: + """ + Propagate event to handlers and stops propagation on first match. + Handler will be called when all its filters is pass. + """ + async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs): + yield result + break + + def __call__( + self, *args: FilterType, **bound_filters: BaseFilter + ) -> Callable[[CallbackType], CallbackType]: + """ + Decorator for registering event handlers + """ + + def wrapper(callback: CallbackType) -> CallbackType: self.register(callback, *args, **bound_filters) return callback diff --git a/tests/test_dispatcher/test_event/test_observer.py b/tests/test_dispatcher/test_event/test_observer.py index a6978375..2f157850 100644 --- a/tests/test_dispatcher/test_event/test_observer.py +++ b/tests/test_dispatcher/test_event/test_observer.py @@ -177,15 +177,15 @@ class TestTelegramEventObserver: assert any(isinstance(item, MyFilter1) for item in resolved) # Unknown filter - with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"): + with pytest.raises(ValueError, match="Unknown keyword filters: {'@bad'}"): assert observer.resolve_filters({"@bad": "very"}) # Unknown filter - with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"): + with pytest.raises(ValueError, match="Unknown keyword filters: {'@bad'}"): assert observer.resolve_filters({"test": "ok", "@bad": "very"}) # Bad argument type - with pytest.raises(ValueError, match="Unknown filters: {'test'}"): + with pytest.raises(ValueError, match="Unknown keyword filters: {'test'}"): assert observer.resolve_filters({"test": ...}) def test_register(self):