diff --git a/CHANGES/727.misc b/CHANGES/727.misc new file mode 100644 index 00000000..e595b91b --- /dev/null +++ b/CHANGES/727.misc @@ -0,0 +1,3 @@ +Rework filters resolving: +* Automatically apply Bound Filters with default values to handlers +* Fix data transfer from parent to included routers filters diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index 386d2fa4..fd9ba63a 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -2,7 +2,18 @@ from __future__ import annotations import functools from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + Union, +) from pydantic import ValidationError @@ -51,7 +62,7 @@ class TelegramEventObserver: :param filters: positional filters :param bound_filters: keyword filters """ - resolved_filters = self.resolve_filters(bound_filters) + resolved_filters = self.resolve_filters(filters, bound_filters) if self._handler.filters is None: self._handler.filters = [] self._handler.filters.extend( @@ -77,7 +88,7 @@ class TelegramEventObserver: """ registry: List[Type[BaseFilter]] = [] - for router in self.router.chain: + for router in reversed(tuple(self.router.chain)): observer = router.observers[self.event_name] for filter_ in observer.filters: @@ -95,22 +106,46 @@ class TelegramEventObserver: if outer: middlewares.extend(self.outer_middlewares) else: - for router in reversed(list(self.router.chain_head)): + for router in reversed(tuple(self.router.chain_head)): observer = router.observers[self.event_name] middlewares.extend(observer.middlewares) return middlewares - def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: + def resolve_filters( + self, + filters: Tuple[FilterType, ...], + full_config: Dict[str, Any], + ignore_default: bool = True, + ) -> List[BaseFilter]: """ Resolve keyword filters via filters factory + + :param filters: positional filters + :param full_config: keyword arguments to initialize bounded filters for router/handler + :param ignore_default: ignore to resolving filters with only default arguments that are not in full_config """ - filters: List[BaseFilter] = [] - if not full_config: - return filters + bound_filters: List[BaseFilter] = [] + + if ignore_default and not full_config: + return bound_filters + + filter_types = set(type(f) for f in filters) validation_errors = [] for bound_filter in self._resolve_filters_chain(): + # skip filter if filter was used as positional filter: + if bound_filter in filter_types: + continue + + # skip filter with no fields in full_config + if ignore_default: + full_config_keys = set(full_config.keys()) + filter_fields = set(bound_filter.__fields__.keys()) + + if not full_config_keys.intersection(filter_fields): + continue + # Try to initialize filter. try: f = bound_filter(**full_config) @@ -123,7 +158,7 @@ class TelegramEventObserver: for key in f.__fields__: full_config.pop(key, None) - filters.append(f) + bound_filters.append(f) if full_config: possible_cases = [] @@ -137,7 +172,7 @@ class TelegramEventObserver: unresolved_fields=set(full_config.keys()), possible_cases=possible_cases ) - return filters + return bound_filters def register( self, callback: HandlerType, *filters: FilterType, **bound_filters: Any @@ -145,7 +180,7 @@ class TelegramEventObserver: """ Register event handler """ - resolved_filters = self.resolve_filters(bound_filters) + resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False) self.handlers.append( HandlerObject( callback=callback, diff --git a/docs/dispatcher/filters/index.rst b/docs/dispatcher/filters/index.rst index 5a841c1f..0ce013ab 100644 --- a/docs/dispatcher/filters/index.rst +++ b/docs/dispatcher/filters/index.rst @@ -75,3 +75,30 @@ For example if you need to make simple text filter: Bound filters is always recursive propagates to the nested routers but will be available in nested routers only after attaching routers so that's mean you will need to include routers before registering handlers. + +Resolving filters with default value +==================================== + +Bound Filters with only default arguments will be automatically applied with default values +to each handler in the router and nested routers to which this filter is bound. + +For example, although we do not specify :code:`chat_type` in the handler filters, +but since the filter has a default value, the filter will be applied to the handler +with a default value :code:`private`: + +.. code-block:: python + + class ChatType(BaseFilter): + chat_type: str = "private" + + async def __call__(self, message: Message , event_chat: Chat) -> bool: + if event_chat: + return event_chat.type == chat_type + else: + return False + + + router.message.bind_filter(ChatType) + + @router.message() + async def my_handler(message: Message): ... diff --git a/tests/test_dispatcher/test_event/test_telegram.py b/tests/test_dispatcher/test_event/test_telegram.py index 563ffa9e..4d0e11e1 100644 --- a/tests/test_dispatcher/test_event/test_telegram.py +++ b/tests/test_dispatcher/test_event/test_telegram.py @@ -1,6 +1,6 @@ import datetime import functools -from typing import Any, Awaitable, Callable, Dict, NoReturn, Union +from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union import pytest @@ -45,6 +45,20 @@ class MyFilter3(MyFilter1): pass +class OptionalFilter(BaseFilter): + optional: Optional[str] + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return True + + +class DefaultFilter(BaseFilter): + default: str = "Default" + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return True + + class TestTelegramEventObserver: def test_bind_filter(self): event_observer = TelegramEventObserver(Router(), "test") @@ -85,26 +99,98 @@ class TestTelegramEventObserver: assert MyFilter2 in filters_chain3 assert MyFilter3 in filters_chain3 + async def test_resolve_filters_data_from_parent_router(self): + class FilterSet(BaseFilter): + set_filter: bool + + async def __call__(self, message: Message) -> dict: + return {"test": "hello world"} + + class FilterGet(BaseFilter): + get_filter: bool + + async def __call__(self, message: Message, **data) -> bool: + assert "test" in data + return True + + router1 = Router(use_builtin_filters=False) + router2 = Router(use_builtin_filters=False) + router1.include_router(router2) + + router1.message.bind_filter(FilterSet) + router2.message.bind_filter(FilterGet) + + @router2.message(set_filter=True, get_filter=True) + def handler_test(msg: Message, test: str): + assert test == "hello world" + + await router1.propagate_event( + "message", + Message(message_id=1, date=datetime.datetime.now(), chat=Chat(id=1, type="private")), + ) + def test_resolve_filters(self): router = Router(use_builtin_filters=False) observer = router.message observer.bind_filter(MyFilter1) - resolved = observer.resolve_filters({"test": "PASS"}) + resolved = observer.resolve_filters((), {"test": "PASS"}) assert isinstance(resolved, list) assert any(isinstance(item, MyFilter1) for item in resolved) # Unknown filter with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"): - assert observer.resolve_filters({"@bad": "very"}) + assert observer.resolve_filters((), {"@bad": "very"}) # Unknown filter with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"): - assert observer.resolve_filters({"test": "ok", "@bad": "very"}) + assert observer.resolve_filters((), {"test": "ok", "@bad": "very"}) # Bad argument type with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"): - assert observer.resolve_filters({"test": ...}) + assert observer.resolve_filters((), {"test": ...}) + + # Disallow same filter using + with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"): + observer.resolve_filters((MyFilter1(test="test"),), {"test": ...}) + + def test_dont_autoresolve_optional_filters_for_router(self): + router = Router(use_builtin_filters=False) + observer = router.message + observer.bind_filter(MyFilter1) + observer.bind_filter(OptionalFilter) + observer.bind_filter(DefaultFilter) + + observer.filter(test="test") + assert len(observer._handler.filters) == 1 + + def test_register_autoresolve_optional_filters(self): + router = Router(use_builtin_filters=False) + observer = router.message + observer.bind_filter(MyFilter1) + observer.bind_filter(OptionalFilter) + observer.bind_filter(DefaultFilter) + + assert observer.register(my_handler) == my_handler + assert isinstance(observer.handlers[0], HandlerObject) + assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter) + assert len(observer.handlers[0].filters) == 2 + assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter) + assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter) + + observer.register(my_handler, test="ok") + assert isinstance(observer.handlers[1], HandlerObject) + assert len(observer.handlers[1].filters) == 3 + assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1) + assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter) + assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter) + + observer.register(my_handler, test="ok", optional="ok") + assert isinstance(observer.handlers[2], HandlerObject) + assert len(observer.handlers[2].filters) == 3 + assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1) + assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter) + assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter) def test_register(self): router = Router(use_builtin_filters=False) @@ -125,10 +211,11 @@ class TestTelegramEventObserver: assert isinstance(observer.handlers[2], HandlerObject) assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters) - observer.register(my_handler, f, test="PASS") + f2 = MyFilter2(test="ok") + observer.register(my_handler, f2, test="PASS") assert isinstance(observer.handlers[3], HandlerObject) callbacks = [filter_.callback for filter_ in observer.handlers[3].filters] - assert f in callbacks + assert f2 in callbacks assert MyFilter1(test="PASS") in callbacks def test_register_decorator(self):