mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
Bound filters resolving rework, filters with default argument (#727)
* bound filters resolving rework, filters with default argument * bound filters resolving rework, filters with default argument * Update 727.misc * clarification of the comment about skipping filter * fix data transfer from parent to included routers filters * fix checking containing value in generator * Update docs/dispatcher/filters/index.rst Co-authored-by: Alex Root Junior <jroot.junior@gmail.com> * Update 727.misc * reformat * better iterable types Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>
This commit is contained in:
parent
f97367b3ee
commit
42cba8976f
4 changed files with 170 additions and 18 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
import functools
|
||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -45,6 +45,20 @@ class MyFilter3(MyFilter1):
|
|||
pass
|
||||
|
||||
|
||||
class OptionalFilter(BaseFilter):
|
||||
optional: Optional[str]
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return True
|
||||
|
||||
|
||||
class DefaultFilter(BaseFilter):
|
||||
default: str = "Default"
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return True
|
||||
|
||||
|
||||
class TestTelegramEventObserver:
|
||||
def test_bind_filter(self):
|
||||
event_observer = TelegramEventObserver(Router(), "test")
|
||||
|
|
@ -85,26 +99,98 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter2 in filters_chain3
|
||||
assert MyFilter3 in filters_chain3
|
||||
|
||||
async def test_resolve_filters_data_from_parent_router(self):
|
||||
class FilterSet(BaseFilter):
|
||||
set_filter: bool
|
||||
|
||||
async def __call__(self, message: Message) -> dict:
|
||||
return {"test": "hello world"}
|
||||
|
||||
class FilterGet(BaseFilter):
|
||||
get_filter: bool
|
||||
|
||||
async def __call__(self, message: Message, **data) -> bool:
|
||||
assert "test" in data
|
||||
return True
|
||||
|
||||
router1 = Router(use_builtin_filters=False)
|
||||
router2 = Router(use_builtin_filters=False)
|
||||
router1.include_router(router2)
|
||||
|
||||
router1.message.bind_filter(FilterSet)
|
||||
router2.message.bind_filter(FilterGet)
|
||||
|
||||
@router2.message(set_filter=True, get_filter=True)
|
||||
def handler_test(msg: Message, test: str):
|
||||
assert test == "hello world"
|
||||
|
||||
await router1.propagate_event(
|
||||
"message",
|
||||
Message(message_id=1, date=datetime.datetime.now(), chat=Chat(id=1, type="private")),
|
||||
)
|
||||
|
||||
def test_resolve_filters(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
|
||||
resolved = observer.resolve_filters({"test": "PASS"})
|
||||
resolved = observer.resolve_filters((), {"test": "PASS"})
|
||||
assert isinstance(resolved, list)
|
||||
assert any(isinstance(item, MyFilter1) for item in resolved)
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"@bad": "very"})
|
||||
assert observer.resolve_filters((), {"@bad": "very"})
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
|
||||
assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
|
||||
|
||||
# Bad argument type
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||
assert observer.resolve_filters({"test": ...})
|
||||
assert observer.resolve_filters((), {"test": ...})
|
||||
|
||||
# Disallow same filter using
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
|
||||
|
||||
def test_dont_autoresolve_optional_filters_for_router(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.bind_filter(OptionalFilter)
|
||||
observer.bind_filter(DefaultFilter)
|
||||
|
||||
observer.filter(test="test")
|
||||
assert len(observer._handler.filters) == 1
|
||||
|
||||
def test_register_autoresolve_optional_filters(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.bind_filter(OptionalFilter)
|
||||
observer.bind_filter(DefaultFilter)
|
||||
|
||||
assert observer.register(my_handler) == my_handler
|
||||
assert isinstance(observer.handlers[0], HandlerObject)
|
||||
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||
assert len(observer.handlers[0].filters) == 2
|
||||
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
|
||||
|
||||
observer.register(my_handler, test="ok")
|
||||
assert isinstance(observer.handlers[1], HandlerObject)
|
||||
assert len(observer.handlers[1].filters) == 3
|
||||
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
|
||||
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
|
||||
|
||||
observer.register(my_handler, test="ok", optional="ok")
|
||||
assert isinstance(observer.handlers[2], HandlerObject)
|
||||
assert len(observer.handlers[2].filters) == 3
|
||||
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
|
||||
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
|
||||
|
||||
def test_register(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
|
|
@ -125,10 +211,11 @@ class TestTelegramEventObserver:
|
|||
assert isinstance(observer.handlers[2], HandlerObject)
|
||||
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
|
||||
|
||||
observer.register(my_handler, f, test="PASS")
|
||||
f2 = MyFilter2(test="ok")
|
||||
observer.register(my_handler, f2, test="PASS")
|
||||
assert isinstance(observer.handlers[3], HandlerObject)
|
||||
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
|
||||
assert f in callbacks
|
||||
assert f2 in callbacks
|
||||
assert MyFilter1(test="PASS") in callbacks
|
||||
|
||||
def test_register_decorator(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue