mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Bound filters resolving rework, filters with default argument (#727)
* bound filters resolving rework, filters with default argument * bound filters resolving rework, filters with default argument * Update 727.misc * clarification of the comment about skipping filter * fix data transfer from parent to included routers filters * fix checking containing value in generator * Update docs/dispatcher/filters/index.rst Co-authored-by: Alex Root Junior <jroot.junior@gmail.com> * Update 727.misc * reformat * better iterable types Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>
This commit is contained in:
parent
f97367b3ee
commit
42cba8976f
4 changed files with 170 additions and 18 deletions
3
CHANGES/727.misc
Normal file
3
CHANGES/727.misc
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
Rework filters resolving:
|
||||||
|
* Automatically apply Bound Filters with default values to handlers
|
||||||
|
* Fix data transfer from parent to included routers filters
|
||||||
|
|
@ -2,7 +2,18 @@ from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
|
@ -51,7 +62,7 @@ class TelegramEventObserver:
|
||||||
:param filters: positional filters
|
:param filters: positional filters
|
||||||
:param bound_filters: keyword filters
|
:param bound_filters: keyword filters
|
||||||
"""
|
"""
|
||||||
resolved_filters = self.resolve_filters(bound_filters)
|
resolved_filters = self.resolve_filters(filters, bound_filters)
|
||||||
if self._handler.filters is None:
|
if self._handler.filters is None:
|
||||||
self._handler.filters = []
|
self._handler.filters = []
|
||||||
self._handler.filters.extend(
|
self._handler.filters.extend(
|
||||||
|
|
@ -77,7 +88,7 @@ class TelegramEventObserver:
|
||||||
"""
|
"""
|
||||||
registry: List[Type[BaseFilter]] = []
|
registry: List[Type[BaseFilter]] = []
|
||||||
|
|
||||||
for router in self.router.chain:
|
for router in reversed(tuple(self.router.chain)):
|
||||||
observer = router.observers[self.event_name]
|
observer = router.observers[self.event_name]
|
||||||
|
|
||||||
for filter_ in observer.filters:
|
for filter_ in observer.filters:
|
||||||
|
|
@ -95,22 +106,46 @@ class TelegramEventObserver:
|
||||||
if outer:
|
if outer:
|
||||||
middlewares.extend(self.outer_middlewares)
|
middlewares.extend(self.outer_middlewares)
|
||||||
else:
|
else:
|
||||||
for router in reversed(list(self.router.chain_head)):
|
for router in reversed(tuple(self.router.chain_head)):
|
||||||
observer = router.observers[self.event_name]
|
observer = router.observers[self.event_name]
|
||||||
middlewares.extend(observer.middlewares)
|
middlewares.extend(observer.middlewares)
|
||||||
|
|
||||||
return middlewares
|
return middlewares
|
||||||
|
|
||||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
def resolve_filters(
|
||||||
|
self,
|
||||||
|
filters: Tuple[FilterType, ...],
|
||||||
|
full_config: Dict[str, Any],
|
||||||
|
ignore_default: bool = True,
|
||||||
|
) -> List[BaseFilter]:
|
||||||
"""
|
"""
|
||||||
Resolve keyword filters via filters factory
|
Resolve keyword filters via filters factory
|
||||||
|
|
||||||
|
:param filters: positional filters
|
||||||
|
:param full_config: keyword arguments to initialize bounded filters for router/handler
|
||||||
|
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
|
||||||
"""
|
"""
|
||||||
filters: List[BaseFilter] = []
|
bound_filters: List[BaseFilter] = []
|
||||||
if not full_config:
|
|
||||||
return filters
|
if ignore_default and not full_config:
|
||||||
|
return bound_filters
|
||||||
|
|
||||||
|
filter_types = set(type(f) for f in filters)
|
||||||
|
|
||||||
validation_errors = []
|
validation_errors = []
|
||||||
for bound_filter in self._resolve_filters_chain():
|
for bound_filter in self._resolve_filters_chain():
|
||||||
|
# skip filter if filter was used as positional filter:
|
||||||
|
if bound_filter in filter_types:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip filter with no fields in full_config
|
||||||
|
if ignore_default:
|
||||||
|
full_config_keys = set(full_config.keys())
|
||||||
|
filter_fields = set(bound_filter.__fields__.keys())
|
||||||
|
|
||||||
|
if not full_config_keys.intersection(filter_fields):
|
||||||
|
continue
|
||||||
|
|
||||||
# Try to initialize filter.
|
# Try to initialize filter.
|
||||||
try:
|
try:
|
||||||
f = bound_filter(**full_config)
|
f = bound_filter(**full_config)
|
||||||
|
|
@ -123,7 +158,7 @@ class TelegramEventObserver:
|
||||||
for key in f.__fields__:
|
for key in f.__fields__:
|
||||||
full_config.pop(key, None)
|
full_config.pop(key, None)
|
||||||
|
|
||||||
filters.append(f)
|
bound_filters.append(f)
|
||||||
|
|
||||||
if full_config:
|
if full_config:
|
||||||
possible_cases = []
|
possible_cases = []
|
||||||
|
|
@ -137,7 +172,7 @@ class TelegramEventObserver:
|
||||||
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
|
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
|
||||||
)
|
)
|
||||||
|
|
||||||
return filters
|
return bound_filters
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
||||||
|
|
@ -145,7 +180,7 @@ class TelegramEventObserver:
|
||||||
"""
|
"""
|
||||||
Register event handler
|
Register event handler
|
||||||
"""
|
"""
|
||||||
resolved_filters = self.resolve_filters(bound_filters)
|
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
|
||||||
self.handlers.append(
|
self.handlers.append(
|
||||||
HandlerObject(
|
HandlerObject(
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
|
|
||||||
|
|
@ -75,3 +75,30 @@ For example if you need to make simple text filter:
|
||||||
Bound filters is always recursive propagates to the nested routers but will be available
|
Bound filters is always recursive propagates to the nested routers but will be available
|
||||||
in nested routers only after attaching routers so that's mean you will need to
|
in nested routers only after attaching routers so that's mean you will need to
|
||||||
include routers before registering handlers.
|
include routers before registering handlers.
|
||||||
|
|
||||||
|
Resolving filters with default value
|
||||||
|
====================================
|
||||||
|
|
||||||
|
Bound Filters with only default arguments will be automatically applied with default values
|
||||||
|
to each handler in the router and nested routers to which this filter is bound.
|
||||||
|
|
||||||
|
For example, although we do not specify :code:`chat_type` in the handler filters,
|
||||||
|
but since the filter has a default value, the filter will be applied to the handler
|
||||||
|
with a default value :code:`private`:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class ChatType(BaseFilter):
|
||||||
|
chat_type: str = "private"
|
||||||
|
|
||||||
|
async def __call__(self, message: Message , event_chat: Chat) -> bool:
|
||||||
|
if event_chat:
|
||||||
|
return event_chat.type == chat_type
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
router.message.bind_filter(ChatType)
|
||||||
|
|
||||||
|
@router.message()
|
||||||
|
async def my_handler(message: Message): ...
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
|
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -45,6 +45,20 @@ class MyFilter3(MyFilter1):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OptionalFilter(BaseFilter):
|
||||||
|
optional: Optional[str]
|
||||||
|
|
||||||
|
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultFilter(BaseFilter):
|
||||||
|
default: str = "Default"
|
||||||
|
|
||||||
|
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class TestTelegramEventObserver:
|
class TestTelegramEventObserver:
|
||||||
def test_bind_filter(self):
|
def test_bind_filter(self):
|
||||||
event_observer = TelegramEventObserver(Router(), "test")
|
event_observer = TelegramEventObserver(Router(), "test")
|
||||||
|
|
@ -85,26 +99,98 @@ class TestTelegramEventObserver:
|
||||||
assert MyFilter2 in filters_chain3
|
assert MyFilter2 in filters_chain3
|
||||||
assert MyFilter3 in filters_chain3
|
assert MyFilter3 in filters_chain3
|
||||||
|
|
||||||
|
async def test_resolve_filters_data_from_parent_router(self):
|
||||||
|
class FilterSet(BaseFilter):
|
||||||
|
set_filter: bool
|
||||||
|
|
||||||
|
async def __call__(self, message: Message) -> dict:
|
||||||
|
return {"test": "hello world"}
|
||||||
|
|
||||||
|
class FilterGet(BaseFilter):
|
||||||
|
get_filter: bool
|
||||||
|
|
||||||
|
async def __call__(self, message: Message, **data) -> bool:
|
||||||
|
assert "test" in data
|
||||||
|
return True
|
||||||
|
|
||||||
|
router1 = Router(use_builtin_filters=False)
|
||||||
|
router2 = Router(use_builtin_filters=False)
|
||||||
|
router1.include_router(router2)
|
||||||
|
|
||||||
|
router1.message.bind_filter(FilterSet)
|
||||||
|
router2.message.bind_filter(FilterGet)
|
||||||
|
|
||||||
|
@router2.message(set_filter=True, get_filter=True)
|
||||||
|
def handler_test(msg: Message, test: str):
|
||||||
|
assert test == "hello world"
|
||||||
|
|
||||||
|
await router1.propagate_event(
|
||||||
|
"message",
|
||||||
|
Message(message_id=1, date=datetime.datetime.now(), chat=Chat(id=1, type="private")),
|
||||||
|
)
|
||||||
|
|
||||||
def test_resolve_filters(self):
|
def test_resolve_filters(self):
|
||||||
router = Router(use_builtin_filters=False)
|
router = Router(use_builtin_filters=False)
|
||||||
observer = router.message
|
observer = router.message
|
||||||
observer.bind_filter(MyFilter1)
|
observer.bind_filter(MyFilter1)
|
||||||
|
|
||||||
resolved = observer.resolve_filters({"test": "PASS"})
|
resolved = observer.resolve_filters((), {"test": "PASS"})
|
||||||
assert isinstance(resolved, list)
|
assert isinstance(resolved, list)
|
||||||
assert any(isinstance(item, MyFilter1) for item in resolved)
|
assert any(isinstance(item, MyFilter1) for item in resolved)
|
||||||
|
|
||||||
# Unknown filter
|
# Unknown filter
|
||||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||||
assert observer.resolve_filters({"@bad": "very"})
|
assert observer.resolve_filters((), {"@bad": "very"})
|
||||||
|
|
||||||
# Unknown filter
|
# Unknown filter
|
||||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||||
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
|
assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
|
||||||
|
|
||||||
# Bad argument type
|
# Bad argument type
|
||||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||||
assert observer.resolve_filters({"test": ...})
|
assert observer.resolve_filters((), {"test": ...})
|
||||||
|
|
||||||
|
# Disallow same filter using
|
||||||
|
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||||
|
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
|
||||||
|
|
||||||
|
def test_dont_autoresolve_optional_filters_for_router(self):
|
||||||
|
router = Router(use_builtin_filters=False)
|
||||||
|
observer = router.message
|
||||||
|
observer.bind_filter(MyFilter1)
|
||||||
|
observer.bind_filter(OptionalFilter)
|
||||||
|
observer.bind_filter(DefaultFilter)
|
||||||
|
|
||||||
|
observer.filter(test="test")
|
||||||
|
assert len(observer._handler.filters) == 1
|
||||||
|
|
||||||
|
def test_register_autoresolve_optional_filters(self):
|
||||||
|
router = Router(use_builtin_filters=False)
|
||||||
|
observer = router.message
|
||||||
|
observer.bind_filter(MyFilter1)
|
||||||
|
observer.bind_filter(OptionalFilter)
|
||||||
|
observer.bind_filter(DefaultFilter)
|
||||||
|
|
||||||
|
assert observer.register(my_handler) == my_handler
|
||||||
|
assert isinstance(observer.handlers[0], HandlerObject)
|
||||||
|
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||||
|
assert len(observer.handlers[0].filters) == 2
|
||||||
|
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||||
|
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
|
||||||
|
|
||||||
|
observer.register(my_handler, test="ok")
|
||||||
|
assert isinstance(observer.handlers[1], HandlerObject)
|
||||||
|
assert len(observer.handlers[1].filters) == 3
|
||||||
|
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
|
||||||
|
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
|
||||||
|
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
|
||||||
|
|
||||||
|
observer.register(my_handler, test="ok", optional="ok")
|
||||||
|
assert isinstance(observer.handlers[2], HandlerObject)
|
||||||
|
assert len(observer.handlers[2].filters) == 3
|
||||||
|
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
|
||||||
|
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
|
||||||
|
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
|
||||||
|
|
||||||
def test_register(self):
|
def test_register(self):
|
||||||
router = Router(use_builtin_filters=False)
|
router = Router(use_builtin_filters=False)
|
||||||
|
|
@ -125,10 +211,11 @@ class TestTelegramEventObserver:
|
||||||
assert isinstance(observer.handlers[2], HandlerObject)
|
assert isinstance(observer.handlers[2], HandlerObject)
|
||||||
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
|
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
|
||||||
|
|
||||||
observer.register(my_handler, f, test="PASS")
|
f2 = MyFilter2(test="ok")
|
||||||
|
observer.register(my_handler, f2, test="PASS")
|
||||||
assert isinstance(observer.handlers[3], HandlerObject)
|
assert isinstance(observer.handlers[3], HandlerObject)
|
||||||
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
|
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
|
||||||
assert f in callbacks
|
assert f2 in callbacks
|
||||||
assert MyFilter1(test="PASS") in callbacks
|
assert MyFilter1(test="PASS") in callbacks
|
||||||
|
|
||||||
def test_register_decorator(self):
|
def test_register_decorator(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue