mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Bound filters resolving rework, filters with default argument (#727)
* bound filters resolving rework, filters with default argument * bound filters resolving rework, filters with default argument * Update 727.misc * clarification of the comment about skipping filter * fix data transfer from parent to included routers filters * fix checking containing value in generator * Update docs/dispatcher/filters/index.rst Co-authored-by: Alex Root Junior <jroot.junior@gmail.com> * Update 727.misc * reformat * better iterable types Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>
This commit is contained in:
parent
f97367b3ee
commit
42cba8976f
4 changed files with 170 additions and 18 deletions
3
CHANGES/727.misc
Normal file
3
CHANGES/727.misc
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
Rework filters resolving:
|
||||
* Automatically apply Bound Filters with default values to handlers
|
||||
* Fix data transfer from parent to included routers filters
|
||||
|
|
@ -2,7 +2,18 @@ from __future__ import annotations
|
|||
|
||||
import functools
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -51,7 +62,7 @@ class TelegramEventObserver:
|
|||
:param filters: positional filters
|
||||
:param bound_filters: keyword filters
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
resolved_filters = self.resolve_filters(filters, bound_filters)
|
||||
if self._handler.filters is None:
|
||||
self._handler.filters = []
|
||||
self._handler.filters.extend(
|
||||
|
|
@ -77,7 +88,7 @@ class TelegramEventObserver:
|
|||
"""
|
||||
registry: List[Type[BaseFilter]] = []
|
||||
|
||||
for router in self.router.chain:
|
||||
for router in reversed(tuple(self.router.chain)):
|
||||
observer = router.observers[self.event_name]
|
||||
|
||||
for filter_ in observer.filters:
|
||||
|
|
@ -95,22 +106,46 @@ class TelegramEventObserver:
|
|||
if outer:
|
||||
middlewares.extend(self.outer_middlewares)
|
||||
else:
|
||||
for router in reversed(list(self.router.chain_head)):
|
||||
for router in reversed(tuple(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middlewares)
|
||||
|
||||
return middlewares
|
||||
|
||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||
def resolve_filters(
|
||||
self,
|
||||
filters: Tuple[FilterType, ...],
|
||||
full_config: Dict[str, Any],
|
||||
ignore_default: bool = True,
|
||||
) -> List[BaseFilter]:
|
||||
"""
|
||||
Resolve keyword filters via filters factory
|
||||
|
||||
:param filters: positional filters
|
||||
:param full_config: keyword arguments to initialize bounded filters for router/handler
|
||||
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
|
||||
"""
|
||||
filters: List[BaseFilter] = []
|
||||
if not full_config:
|
||||
return filters
|
||||
bound_filters: List[BaseFilter] = []
|
||||
|
||||
if ignore_default and not full_config:
|
||||
return bound_filters
|
||||
|
||||
filter_types = set(type(f) for f in filters)
|
||||
|
||||
validation_errors = []
|
||||
for bound_filter in self._resolve_filters_chain():
|
||||
# skip filter if filter was used as positional filter:
|
||||
if bound_filter in filter_types:
|
||||
continue
|
||||
|
||||
# skip filter with no fields in full_config
|
||||
if ignore_default:
|
||||
full_config_keys = set(full_config.keys())
|
||||
filter_fields = set(bound_filter.__fields__.keys())
|
||||
|
||||
if not full_config_keys.intersection(filter_fields):
|
||||
continue
|
||||
|
||||
# Try to initialize filter.
|
||||
try:
|
||||
f = bound_filter(**full_config)
|
||||
|
|
@ -123,7 +158,7 @@ class TelegramEventObserver:
|
|||
for key in f.__fields__:
|
||||
full_config.pop(key, None)
|
||||
|
||||
filters.append(f)
|
||||
bound_filters.append(f)
|
||||
|
||||
if full_config:
|
||||
possible_cases = []
|
||||
|
|
@ -137,7 +172,7 @@ class TelegramEventObserver:
|
|||
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
|
||||
)
|
||||
|
||||
return filters
|
||||
return bound_filters
|
||||
|
||||
def register(
|
||||
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
||||
|
|
@ -145,7 +180,7 @@ class TelegramEventObserver:
|
|||
"""
|
||||
Register event handler
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -75,3 +75,30 @@ For example if you need to make simple text filter:
|
|||
Bound filters is always recursive propagates to the nested routers but will be available
|
||||
in nested routers only after attaching routers so that's mean you will need to
|
||||
include routers before registering handlers.
|
||||
|
||||
Resolving filters with default value
|
||||
====================================
|
||||
|
||||
Bound Filters with only default arguments will be automatically applied with default values
|
||||
to each handler in the router and nested routers to which this filter is bound.
|
||||
|
||||
For example, although we do not specify :code:`chat_type` in the handler filters,
|
||||
but since the filter has a default value, the filter will be applied to the handler
|
||||
with a default value :code:`private`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class ChatType(BaseFilter):
|
||||
chat_type: str = "private"
|
||||
|
||||
async def __call__(self, message: Message , event_chat: Chat) -> bool:
|
||||
if event_chat:
|
||||
return event_chat.type == chat_type
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
router.message.bind_filter(ChatType)
|
||||
|
||||
@router.message()
|
||||
async def my_handler(message: Message): ...
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
import functools
|
||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -45,6 +45,20 @@ class MyFilter3(MyFilter1):
|
|||
pass
|
||||
|
||||
|
||||
class OptionalFilter(BaseFilter):
|
||||
optional: Optional[str]
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return True
|
||||
|
||||
|
||||
class DefaultFilter(BaseFilter):
|
||||
default: str = "Default"
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return True
|
||||
|
||||
|
||||
class TestTelegramEventObserver:
|
||||
def test_bind_filter(self):
|
||||
event_observer = TelegramEventObserver(Router(), "test")
|
||||
|
|
@ -85,26 +99,98 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter2 in filters_chain3
|
||||
assert MyFilter3 in filters_chain3
|
||||
|
||||
async def test_resolve_filters_data_from_parent_router(self):
|
||||
class FilterSet(BaseFilter):
|
||||
set_filter: bool
|
||||
|
||||
async def __call__(self, message: Message) -> dict:
|
||||
return {"test": "hello world"}
|
||||
|
||||
class FilterGet(BaseFilter):
|
||||
get_filter: bool
|
||||
|
||||
async def __call__(self, message: Message, **data) -> bool:
|
||||
assert "test" in data
|
||||
return True
|
||||
|
||||
router1 = Router(use_builtin_filters=False)
|
||||
router2 = Router(use_builtin_filters=False)
|
||||
router1.include_router(router2)
|
||||
|
||||
router1.message.bind_filter(FilterSet)
|
||||
router2.message.bind_filter(FilterGet)
|
||||
|
||||
@router2.message(set_filter=True, get_filter=True)
|
||||
def handler_test(msg: Message, test: str):
|
||||
assert test == "hello world"
|
||||
|
||||
await router1.propagate_event(
|
||||
"message",
|
||||
Message(message_id=1, date=datetime.datetime.now(), chat=Chat(id=1, type="private")),
|
||||
)
|
||||
|
||||
def test_resolve_filters(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
|
||||
resolved = observer.resolve_filters({"test": "PASS"})
|
||||
resolved = observer.resolve_filters((), {"test": "PASS"})
|
||||
assert isinstance(resolved, list)
|
||||
assert any(isinstance(item, MyFilter1) for item in resolved)
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"@bad": "very"})
|
||||
assert observer.resolve_filters((), {"@bad": "very"})
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
|
||||
assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
|
||||
|
||||
# Bad argument type
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||
assert observer.resolve_filters({"test": ...})
|
||||
assert observer.resolve_filters((), {"test": ...})
|
||||
|
||||
# Disallow same filter using
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
|
||||
|
||||
def test_dont_autoresolve_optional_filters_for_router(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.bind_filter(OptionalFilter)
|
||||
observer.bind_filter(DefaultFilter)
|
||||
|
||||
observer.filter(test="test")
|
||||
assert len(observer._handler.filters) == 1
|
||||
|
||||
def test_register_autoresolve_optional_filters(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.bind_filter(OptionalFilter)
|
||||
observer.bind_filter(DefaultFilter)
|
||||
|
||||
assert observer.register(my_handler) == my_handler
|
||||
assert isinstance(observer.handlers[0], HandlerObject)
|
||||
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||
assert len(observer.handlers[0].filters) == 2
|
||||
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
|
||||
|
||||
observer.register(my_handler, test="ok")
|
||||
assert isinstance(observer.handlers[1], HandlerObject)
|
||||
assert len(observer.handlers[1].filters) == 3
|
||||
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
|
||||
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
|
||||
|
||||
observer.register(my_handler, test="ok", optional="ok")
|
||||
assert isinstance(observer.handlers[2], HandlerObject)
|
||||
assert len(observer.handlers[2].filters) == 3
|
||||
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
|
||||
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
|
||||
|
||||
def test_register(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
|
|
@ -125,10 +211,11 @@ class TestTelegramEventObserver:
|
|||
assert isinstance(observer.handlers[2], HandlerObject)
|
||||
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
|
||||
|
||||
observer.register(my_handler, f, test="PASS")
|
||||
f2 = MyFilter2(test="ok")
|
||||
observer.register(my_handler, f2, test="PASS")
|
||||
assert isinstance(observer.handlers[3], HandlerObject)
|
||||
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
|
||||
assert f in callbacks
|
||||
assert f2 in callbacks
|
||||
assert MyFilter1(test="PASS") in callbacks
|
||||
|
||||
def test_register_decorator(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue