Bound filters resolving rework, filters with default argument (#727)

* bound filters resolving rework, filters with default argument

* bound filters resolving rework, filters with default argument

* Update 727.misc

* clarification of the comment about skipping filter

* fix data transfer from parent to included routers filters

* fix checking containing value in generator

* Update docs/dispatcher/filters/index.rst

Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>

* Update 727.misc

* reformat

* better iterable types

Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>
This commit is contained in:
darksidecat 2021-10-12 22:29:57 +03:00 committed by GitHub
parent f97367b3ee
commit 42cba8976f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 170 additions and 18 deletions

View file

@ -1,6 +1,6 @@
import datetime
import functools
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
import pytest
@ -45,6 +45,20 @@ class MyFilter3(MyFilter1):
pass
class OptionalFilter(BaseFilter):
optional: Optional[str]
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class DefaultFilter(BaseFilter):
default: str = "Default"
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class TestTelegramEventObserver:
def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test")
@ -85,26 +99,98 @@ class TestTelegramEventObserver:
assert MyFilter2 in filters_chain3
assert MyFilter3 in filters_chain3
async def test_resolve_filters_data_from_parent_router(self):
class FilterSet(BaseFilter):
set_filter: bool
async def __call__(self, message: Message) -> dict:
return {"test": "hello world"}
class FilterGet(BaseFilter):
get_filter: bool
async def __call__(self, message: Message, **data) -> bool:
assert "test" in data
return True
router1 = Router(use_builtin_filters=False)
router2 = Router(use_builtin_filters=False)
router1.include_router(router2)
router1.message.bind_filter(FilterSet)
router2.message.bind_filter(FilterGet)
@router2.message(set_filter=True, get_filter=True)
def handler_test(msg: Message, test: str):
assert test == "hello world"
await router1.propagate_event(
"message",
Message(message_id=1, date=datetime.datetime.now(), chat=Chat(id=1, type="private")),
)
def test_resolve_filters(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
resolved = observer.resolve_filters({"test": "PASS"})
resolved = observer.resolve_filters((), {"test": "PASS"})
assert isinstance(resolved, list)
assert any(isinstance(item, MyFilter1) for item in resolved)
# Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters({"@bad": "very"})
assert observer.resolve_filters((), {"@bad": "very"})
# Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
# Bad argument type
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
assert observer.resolve_filters({"test": ...})
assert observer.resolve_filters((), {"test": ...})
# Disallow same filter using
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
def test_dont_autoresolve_optional_filters_for_router(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
observer.filter(test="test")
assert len(observer._handler.filters) == 1
def test_register_autoresolve_optional_filters(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
assert observer.register(my_handler) == my_handler
assert isinstance(observer.handlers[0], HandlerObject)
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert len(observer.handlers[0].filters) == 2
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
observer.register(my_handler, test="ok")
assert isinstance(observer.handlers[1], HandlerObject)
assert len(observer.handlers[1].filters) == 3
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
observer.register(my_handler, test="ok", optional="ok")
assert isinstance(observer.handlers[2], HandlerObject)
assert len(observer.handlers[2].filters) == 3
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
def test_register(self):
router = Router(use_builtin_filters=False)
@ -125,10 +211,11 @@ class TestTelegramEventObserver:
assert isinstance(observer.handlers[2], HandlerObject)
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
observer.register(my_handler, f, test="PASS")
f2 = MyFilter2(test="ok")
observer.register(my_handler, f2, test="PASS")
assert isinstance(observer.handlers[3], HandlerObject)
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
assert f in callbacks
assert f2 in callbacks
assert MyFilter1(test="PASS") in callbacks
def test_register_decorator(self):