Added possibility to combine filters or invert result (#895)

* Added possibility to combine filters or invert result
This commit is contained in:
Alex Root Junior 2022-04-24 04:19:19 +03:00 committed by GitHub
parent 7bfc941a1e
commit 4fb77a3a2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 184 additions and 40 deletions

7
CHANGES/894.feature.rst Normal file
View file

@ -0,0 +1,7 @@
Added possibility to combine filters or invert result
Example:
.. code-block:: python
Text(text="demo") | Command(commands=["demo"])
MyFilter() & AnotherFilter()
~StateFilter(state='my-state')

View file

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Callable, List
from .handler import CallbackType, HandlerObject, HandlerType
from .handler import CallbackType, HandlerObject
class EventObserver:
@ -26,7 +26,7 @@ class EventObserver:
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType) -> None:
def register(self, callback: CallbackType) -> None:
"""
Register callback with filters
"""

View file

@ -3,24 +3,19 @@ import contextvars
import inspect
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple
from magic_filter import MagicFilter
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.flags.getter import extract_flags_from_object
from aiogram.dispatcher.handler.base import BaseHandler
CallbackType = Callable[..., Awaitable[Any]]
SyncFilter = Callable[..., Any]
AsyncFilter = Callable[..., Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter, MagicFilter]
HandlerType = Union[FilterType, Type[BaseHandler]]
CallbackType = Callable[..., Any]
@dataclass
class CallableMixin:
callback: HandlerType
callback: CallbackType
awaitable: bool = field(init=False)
spec: inspect.FullArgSpec = field(init=False)
@ -50,7 +45,7 @@ class CallableMixin:
@dataclass
class FilterObject(CallableMixin):
callback: FilterType
callback: CallbackType
def __post_init__(self) -> None:
# TODO: Make possibility to extract and explain magic from filter object.
@ -63,7 +58,7 @@ class FilterObject(CallableMixin):
@dataclass
class HandlerObject(CallableMixin):
callback: HandlerType
callback: CallbackType
filters: Optional[List[FilterObject]] = None
flags: Dict[str, Any] = field(default_factory=dict)

View file

@ -12,7 +12,7 @@ from ...exceptions import FiltersResolveError
from ...types import TelegramObject
from ..filters.base import BaseFilter
from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
from .handler import CallbackType, FilterObject, HandlerObject
if TYPE_CHECKING:
from aiogram.dispatcher.router import Router
@ -40,7 +40,7 @@ class TelegramEventObserver:
# with dummy callback which never will be used
self._handler = HandlerObject(callback=lambda: True, filters=[])
def filter(self, *filters: FilterType, **bound_filters: Any) -> None:
def filter(self, *filters: CallbackType, **bound_filters: Any) -> None:
"""
Register filter for all handlers of this event observer
@ -51,7 +51,13 @@ class TelegramEventObserver:
if self._handler.filters is None:
self._handler.filters = []
self._handler.filters.extend(
[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)]
[
FilterObject(filter_) # type: ignore
for filter_ in chain(
resolved_filters,
filters,
)
]
)
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
@ -96,7 +102,7 @@ class TelegramEventObserver:
def resolve_filters(
self,
filters: Tuple[FilterType, ...],
filters: Tuple[CallbackType, ...],
full_config: Dict[str, Any],
ignore_default: bool = True,
) -> List[BaseFilter]:
@ -158,11 +164,11 @@ class TelegramEventObserver:
def register(
self,
callback: HandlerType,
*filters: FilterType,
callback: CallbackType,
*filters: CallbackType,
flags: Optional[Dict[str, Any]] = None,
**bound_filters: Any,
) -> HandlerType:
) -> CallbackType:
"""
Register event handler
"""
@ -174,7 +180,13 @@ class TelegramEventObserver:
self.handlers.append(
HandlerObject(
callback=callback,
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
filters=[
FilterObject(filter_) # type: ignore
for filter_ in chain(
resolved_filters,
filters,
)
],
flags=flags,
)
)
@ -216,7 +228,7 @@ class TelegramEventObserver:
return UNHANDLED
def __call__(
self, *args: FilterType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any
self, *args: CallbackType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any
) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers

View file

@ -19,6 +19,7 @@ from .chat_member_updated import (
from .command import Command, CommandObject
from .content_types import ContentTypesFilter
from .exception import ExceptionMessageFilter, ExceptionTypeFilter
from .logic import and_f, invert_f, or_f
from .magic_data import MagicData
from .state import StateFilter
from .text import Text
@ -47,6 +48,9 @@ __all__ = (
"IS_NOT_MEMBER",
"JOIN_TRANSITION",
"LEAVE_TRANSITION",
"and_f",
"or_f",
"invert_f",
)
_ALL_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (MagicData,)

View file

@ -3,8 +3,10 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
from pydantic import BaseModel
from aiogram.dispatcher.filters.logic import _LogicFilter
class BaseFilter(ABC, BaseModel):
class BaseFilter(BaseModel, ABC, _LogicFilter):
"""
If you want to register own filters like builtin filters you will need to write subclass
of this class with overriding the :code:`__call__`

View file

@ -0,0 +1,87 @@
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
if TYPE_CHECKING:
from aiogram.dispatcher.event.handler import CallbackType, FilterObject
class _LogicFilter:
__call__: Callable[..., Awaitable[Union[bool, Dict[str, Any]]]]
def __and__(self, other: "CallbackType") -> "_AndFilter":
return and_f(self, other)
def __or__(self, other: "CallbackType") -> "_OrFilter":
return or_f(self, other)
def __invert__(self) -> "_InvertFilter":
return invert_f(self)
def __await__(self): # type: ignore # pragma: no cover
# Is needed only for inspection and this method is never be called
return self.__call__
class _InvertFilter(_LogicFilter):
__slots__ = ("target",)
def __init__(self, target: "FilterObject") -> None:
self.target = target
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return not bool(await self.target.call(*args, **kwargs))
class _AndFilter(_LogicFilter):
__slots__ = ("targets",)
def __init__(self, *targets: "FilterObject") -> None:
self.targets = targets
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
final_result = {}
for target in self.targets:
result = await target.call(*args, **kwargs)
if not result:
return False
if isinstance(result, dict):
final_result.update(result)
if final_result:
return final_result
return True
class _OrFilter(_LogicFilter):
__slots__ = ("targets",)
def __init__(self, *targets: "FilterObject") -> None:
self.targets = targets
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
for target in self.targets:
result = await target.call(*args, **kwargs)
if not result:
continue
if isinstance(result, dict):
return result
return bool(result)
return False
def and_f(target1: "CallbackType", target2: "CallbackType") -> _AndFilter:
from aiogram.dispatcher.event.handler import FilterObject
return _AndFilter(FilterObject(target1), FilterObject(target2))
def or_f(target1: "CallbackType", target2: "CallbackType") -> _OrFilter:
from aiogram.dispatcher.event.handler import FilterObject
return _OrFilter(FilterObject(target1), FilterObject(target2))
def invert_f(target: "CallbackType") -> _InvertFilter:
from aiogram.dispatcher.event.handler import FilterObject
return _InvertFilter(FilterObject(target))

View file

@ -2,7 +2,7 @@ import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType
from aiogram.dispatcher.event.handler import HandlerType
from aiogram.dispatcher.event.handler import CallbackType
from aiogram.types import TelegramObject
@ -49,7 +49,7 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
@staticmethod
def wrap_middlewares(
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: CallbackType
) -> NextMiddlewareType[MiddlewareEventType]:
@functools.wraps(handler)
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:

View file

@ -8,7 +8,7 @@ from ..utils.imports import import_module
from ..utils.warnings import CodeHasNoEffect
from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver
from .event.handler import HandlerType
from .event.handler import CallbackType
from .event.telegram import TelegramEventObserver
from .filters import BUILTIN_FILTERS
@ -396,7 +396,7 @@ class Router:
)
return self.errors
def register_message(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_message(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_message(...)` is deprecated and will be removed in version 3.2 "
"use `Router.message.register(...)`",
@ -405,7 +405,7 @@ class Router:
)
return self.message.register(*args, **kwargs)
def register_edited_message(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_edited_message(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_edited_message(...)` is deprecated and will be removed in version 3.2 "
"use `Router.edited_message.register(...)`",
@ -414,7 +414,7 @@ class Router:
)
return self.edited_message.register(*args, **kwargs)
def register_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_channel_post(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_channel_post(...)` is deprecated and will be removed in version 3.2 "
"use `Router.channel_post.register(...)`",
@ -423,7 +423,7 @@ class Router:
)
return self.channel_post.register(*args, **kwargs)
def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_edited_channel_post(...)` is deprecated and will be removed in version 3.2 "
"use `Router.edited_channel_post.register(...)`",
@ -432,7 +432,7 @@ class Router:
)
return self.edited_channel_post.register(*args, **kwargs)
def register_inline_query(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_inline_query(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_inline_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.inline_query.register(...)`",
@ -441,7 +441,7 @@ class Router:
)
return self.inline_query.register(*args, **kwargs)
def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_chosen_inline_result(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chosen_inline_result.register(...)`",
@ -450,7 +450,7 @@ class Router:
)
return self.chosen_inline_result.register(*args, **kwargs)
def register_callback_query(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_callback_query(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_callback_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.callback_query.register(...)`",
@ -459,7 +459,7 @@ class Router:
)
return self.callback_query.register(*args, **kwargs)
def register_shipping_query(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_shipping_query(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_shipping_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.shipping_query.register(...)`",
@ -468,7 +468,7 @@ class Router:
)
return self.shipping_query.register(*args, **kwargs)
def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_pre_checkout_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.pre_checkout_query.register(...)`",
@ -477,7 +477,7 @@ class Router:
)
return self.pre_checkout_query.register(*args, **kwargs)
def register_poll(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_poll(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_poll(...)` is deprecated and will be removed in version 3.2 "
"use `Router.poll.register(...)`",
@ -486,7 +486,7 @@ class Router:
)
return self.poll.register(*args, **kwargs)
def register_poll_answer(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_poll_answer(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_poll_answer(...)` is deprecated and will be removed in version 3.2 "
"use `Router.poll_answer.register(...)`",
@ -495,7 +495,7 @@ class Router:
)
return self.poll_answer.register(*args, **kwargs)
def register_my_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_my_chat_member(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_my_chat_member(...)` is deprecated and will be removed in version 3.2 "
"use `Router.my_chat_member.register(...)`",
@ -504,7 +504,7 @@ class Router:
)
return self.my_chat_member.register(*args, **kwargs)
def register_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_chat_member(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_chat_member(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chat_member.register(...)`",
@ -513,7 +513,7 @@ class Router:
)
return self.chat_member.register(*args, **kwargs)
def register_chat_join_request(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_chat_join_request(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_chat_join_request(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chat_join_request.register(...)`",
@ -522,7 +522,7 @@ class Router:
)
return self.chat_join_request.register(*args, **kwargs)
def register_errors(self, *args: Any, **kwargs: Any) -> HandlerType:
def register_errors(self, *args: Any, **kwargs: Any) -> CallbackType:
warnings.warn(
"`Router.register_errors(...)` is deprecated and will be removed in version 3.2 "
"use `Router.errors.register(...)`",

View file

@ -0,0 +1,37 @@
import pytest
from aiogram.dispatcher.filters import Text, and_f, invert_f, or_f
from aiogram.dispatcher.filters.logic import _AndFilter, _InvertFilter, _OrFilter
class TestLogic:
@pytest.mark.parametrize(
"obj,case,result",
[
[True, and_f(lambda t: t is True, lambda t: t is True), True],
[True, and_f(lambda t: t is True, lambda t: t is False), False],
[True, and_f(lambda t: t is False, lambda t: t is False), False],
[True, and_f(lambda t: {"t": t}, lambda t: t is False), False],
[True, and_f(lambda t: {"t": t}, lambda t: t is True), {"t": True}],
[True, or_f(lambda t: t is True, lambda t: t is True), True],
[True, or_f(lambda t: t is True, lambda t: t is False), True],
[True, or_f(lambda t: t is False, lambda t: t is False), False],
[True, or_f(lambda t: t is False, lambda t: t is True), True],
[True, or_f(lambda t: t is False, lambda t: {"t": t}), {"t": True}],
[True, or_f(lambda t: {"t": t}, lambda t: {"a": 42}), {"t": True}],
[True, invert_f(lambda t: t is False), True],
],
)
async def test_logic(self, obj, case, result):
assert await case(obj) == result
@pytest.mark.parametrize(
"case,type_",
[
[Text(text="test") | Text(text="test"), _OrFilter],
[Text(text="test") & Text(text="test"), _AndFilter],
[~Text(text="test"), _InvertFilter],
],
)
def test_dunder_methods(self, case, type_):
assert isinstance(case, type_)