Bound filters resolving rework, filters with default argument (#727)

* bound filters resolving rework, filters with default argument

* bound filters resolving rework, filters with default argument

* Update 727.misc

* clarification of the comment about skipping filter

* fix data transfer from parent to included routers filters

* fix checking containing value in generator

* Update docs/dispatcher/filters/index.rst

Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>

* Update 727.misc

* reformat

* better iterable types

Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>
This commit is contained in:
darksidecat 2021-10-12 22:29:57 +03:00 committed by GitHub
parent f97367b3ee
commit 42cba8976f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 170 additions and 18 deletions

3
CHANGES/727.misc Normal file
View file

@ -0,0 +1,3 @@
Rework filters resolving:
* Automatically apply Bound Filters with default values to handlers
* Fix data transfer from parent to included routers filters

View file

@ -2,7 +2,18 @@ from __future__ import annotations
import functools import functools
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)
from pydantic import ValidationError from pydantic import ValidationError
@ -51,7 +62,7 @@ class TelegramEventObserver:
:param filters: positional filters :param filters: positional filters
:param bound_filters: keyword filters :param bound_filters: keyword filters
""" """
resolved_filters = self.resolve_filters(bound_filters) resolved_filters = self.resolve_filters(filters, bound_filters)
if self._handler.filters is None: if self._handler.filters is None:
self._handler.filters = [] self._handler.filters = []
self._handler.filters.extend( self._handler.filters.extend(
@ -77,7 +88,7 @@ class TelegramEventObserver:
""" """
registry: List[Type[BaseFilter]] = [] registry: List[Type[BaseFilter]] = []
for router in self.router.chain: for router in reversed(tuple(self.router.chain)):
observer = router.observers[self.event_name] observer = router.observers[self.event_name]
for filter_ in observer.filters: for filter_ in observer.filters:
@ -95,22 +106,46 @@ class TelegramEventObserver:
if outer: if outer:
middlewares.extend(self.outer_middlewares) middlewares.extend(self.outer_middlewares)
else: else:
for router in reversed(list(self.router.chain_head)): for router in reversed(tuple(self.router.chain_head)):
observer = router.observers[self.event_name] observer = router.observers[self.event_name]
middlewares.extend(observer.middlewares) middlewares.extend(observer.middlewares)
return middlewares return middlewares
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: def resolve_filters(
self,
filters: Tuple[FilterType, ...],
full_config: Dict[str, Any],
ignore_default: bool = True,
) -> List[BaseFilter]:
""" """
Resolve keyword filters via filters factory Resolve keyword filters via filters factory
:param filters: positional filters
:param full_config: keyword arguments to initialize bounded filters for router/handler
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
""" """
filters: List[BaseFilter] = [] bound_filters: List[BaseFilter] = []
if not full_config:
return filters if ignore_default and not full_config:
return bound_filters
filter_types = set(type(f) for f in filters)
validation_errors = [] validation_errors = []
for bound_filter in self._resolve_filters_chain(): for bound_filter in self._resolve_filters_chain():
# skip filter if filter was used as positional filter:
if bound_filter in filter_types:
continue
# skip filter with no fields in full_config
if ignore_default:
full_config_keys = set(full_config.keys())
filter_fields = set(bound_filter.__fields__.keys())
if not full_config_keys.intersection(filter_fields):
continue
# Try to initialize filter. # Try to initialize filter.
try: try:
f = bound_filter(**full_config) f = bound_filter(**full_config)
@ -123,7 +158,7 @@ class TelegramEventObserver:
for key in f.__fields__: for key in f.__fields__:
full_config.pop(key, None) full_config.pop(key, None)
filters.append(f) bound_filters.append(f)
if full_config: if full_config:
possible_cases = [] possible_cases = []
@ -137,7 +172,7 @@ class TelegramEventObserver:
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
) )
return filters return bound_filters
def register( def register(
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
@ -145,7 +180,7 @@ class TelegramEventObserver:
""" """
Register event handler Register event handler
""" """
resolved_filters = self.resolve_filters(bound_filters) resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
self.handlers.append( self.handlers.append(
HandlerObject( HandlerObject(
callback=callback, callback=callback,

View file

@ -75,3 +75,30 @@ For example if you need to make simple text filter:
Bound filters is always recursive propagates to the nested routers but will be available Bound filters is always recursive propagates to the nested routers but will be available
in nested routers only after attaching routers so that's mean you will need to in nested routers only after attaching routers so that's mean you will need to
include routers before registering handlers. include routers before registering handlers.
Resolving filters with default value
====================================
Bound Filters with only default arguments will be automatically applied with default values
to each handler in the router and nested routers to which this filter is bound.
For example, although we do not specify :code:`chat_type` in the handler filters,
but since the filter has a default value, the filter will be applied to the handler
with a default value :code:`private`:
.. code-block:: python
class ChatType(BaseFilter):
chat_type: str = "private"
async def __call__(self, message: Message , event_chat: Chat) -> bool:
if event_chat:
return event_chat.type == chat_type
else:
return False
router.message.bind_filter(ChatType)
@router.message()
async def my_handler(message: Message): ...

View file

@ -1,6 +1,6 @@
import datetime import datetime
import functools import functools
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
import pytest import pytest
@ -45,6 +45,20 @@ class MyFilter3(MyFilter1):
pass pass
class OptionalFilter(BaseFilter):
optional: Optional[str]
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class DefaultFilter(BaseFilter):
default: str = "Default"
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class TestTelegramEventObserver: class TestTelegramEventObserver:
def test_bind_filter(self): def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test") event_observer = TelegramEventObserver(Router(), "test")
@ -85,26 +99,98 @@ class TestTelegramEventObserver:
assert MyFilter2 in filters_chain3 assert MyFilter2 in filters_chain3
assert MyFilter3 in filters_chain3 assert MyFilter3 in filters_chain3
async def test_resolve_filters_data_from_parent_router(self):
class FilterSet(BaseFilter):
set_filter: bool
async def __call__(self, message: Message) -> dict:
return {"test": "hello world"}
class FilterGet(BaseFilter):
get_filter: bool
async def __call__(self, message: Message, **data) -> bool:
assert "test" in data
return True
router1 = Router(use_builtin_filters=False)
router2 = Router(use_builtin_filters=False)
router1.include_router(router2)
router1.message.bind_filter(FilterSet)
router2.message.bind_filter(FilterGet)
@router2.message(set_filter=True, get_filter=True)
def handler_test(msg: Message, test: str):
assert test == "hello world"
await router1.propagate_event(
"message",
Message(message_id=1, date=datetime.datetime.now(), chat=Chat(id=1, type="private")),
)
def test_resolve_filters(self): def test_resolve_filters(self):
router = Router(use_builtin_filters=False) router = Router(use_builtin_filters=False)
observer = router.message observer = router.message
observer.bind_filter(MyFilter1) observer.bind_filter(MyFilter1)
resolved = observer.resolve_filters({"test": "PASS"}) resolved = observer.resolve_filters((), {"test": "PASS"})
assert isinstance(resolved, list) assert isinstance(resolved, list)
assert any(isinstance(item, MyFilter1) for item in resolved) assert any(isinstance(item, MyFilter1) for item in resolved)
# Unknown filter # Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"): with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters({"@bad": "very"}) assert observer.resolve_filters((), {"@bad": "very"})
# Unknown filter # Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"): with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters({"test": "ok", "@bad": "very"}) assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
# Bad argument type # Bad argument type
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"): with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
assert observer.resolve_filters({"test": ...}) assert observer.resolve_filters((), {"test": ...})
# Disallow same filter using
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
def test_dont_autoresolve_optional_filters_for_router(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
observer.filter(test="test")
assert len(observer._handler.filters) == 1
def test_register_autoresolve_optional_filters(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
assert observer.register(my_handler) == my_handler
assert isinstance(observer.handlers[0], HandlerObject)
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert len(observer.handlers[0].filters) == 2
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
observer.register(my_handler, test="ok")
assert isinstance(observer.handlers[1], HandlerObject)
assert len(observer.handlers[1].filters) == 3
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
observer.register(my_handler, test="ok", optional="ok")
assert isinstance(observer.handlers[2], HandlerObject)
assert len(observer.handlers[2].filters) == 3
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
def test_register(self): def test_register(self):
router = Router(use_builtin_filters=False) router = Router(use_builtin_filters=False)
@ -125,10 +211,11 @@ class TestTelegramEventObserver:
assert isinstance(observer.handlers[2], HandlerObject) assert isinstance(observer.handlers[2], HandlerObject)
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters) assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
observer.register(my_handler, f, test="PASS") f2 = MyFilter2(test="ok")
observer.register(my_handler, f2, test="PASS")
assert isinstance(observer.handlers[3], HandlerObject) assert isinstance(observer.handlers[3], HandlerObject)
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters] callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
assert f in callbacks assert f2 in callbacks
assert MyFilter1(test="PASS") in callbacks assert MyFilter1(test="PASS") in callbacks
def test_register_decorator(self): def test_register_decorator(self):