mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-06 07:50:32 +00:00
Added possibility to combine filters or invert result (#895)
* Added possibility to combine filters or invert result
This commit is contained in:
parent
7bfc941a1e
commit
4fb77a3a2a
10 changed files with 184 additions and 40 deletions
7
CHANGES/894.feature.rst
Normal file
7
CHANGES/894.feature.rst
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
Added possibility to combine filters or invert result
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
Text(text="demo") | Command(commands=["demo"])
|
||||
MyFilter() & AnotherFilter()
|
||||
~StateFilter(state='my-state')
|
||||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from .handler import CallbackType, HandlerObject, HandlerType
|
||||
from .handler import CallbackType, HandlerObject
|
||||
|
||||
|
||||
class EventObserver:
|
||||
|
|
@ -26,7 +26,7 @@ class EventObserver:
|
|||
def __init__(self) -> None:
|
||||
self.handlers: List[HandlerObject] = []
|
||||
|
||||
def register(self, callback: HandlerType) -> None:
|
||||
def register(self, callback: CallbackType) -> None:
|
||||
"""
|
||||
Register callback with filters
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,24 +3,19 @@ import contextvars
|
|||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from magic_filter import MagicFilter
|
||||
|
||||
from aiogram.dispatcher.filters.base import BaseFilter
|
||||
from aiogram.dispatcher.flags.getter import extract_flags_from_object
|
||||
from aiogram.dispatcher.handler.base import BaseHandler
|
||||
|
||||
CallbackType = Callable[..., Awaitable[Any]]
|
||||
SyncFilter = Callable[..., Any]
|
||||
AsyncFilter = Callable[..., Awaitable[Any]]
|
||||
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter, MagicFilter]
|
||||
HandlerType = Union[FilterType, Type[BaseHandler]]
|
||||
CallbackType = Callable[..., Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallableMixin:
|
||||
callback: HandlerType
|
||||
callback: CallbackType
|
||||
awaitable: bool = field(init=False)
|
||||
spec: inspect.FullArgSpec = field(init=False)
|
||||
|
||||
|
|
@ -50,7 +45,7 @@ class CallableMixin:
|
|||
|
||||
@dataclass
|
||||
class FilterObject(CallableMixin):
|
||||
callback: FilterType
|
||||
callback: CallbackType
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# TODO: Make possibility to extract and explain magic from filter object.
|
||||
|
|
@ -63,7 +58,7 @@ class FilterObject(CallableMixin):
|
|||
|
||||
@dataclass
|
||||
class HandlerObject(CallableMixin):
|
||||
callback: HandlerType
|
||||
callback: CallbackType
|
||||
filters: Optional[List[FilterObject]] = None
|
||||
flags: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from ...exceptions import FiltersResolveError
|
|||
from ...types import TelegramObject
|
||||
from ..filters.base import BaseFilter
|
||||
from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
|
||||
from .handler import CallbackType, FilterObject, HandlerObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.dispatcher.router import Router
|
||||
|
|
@ -40,7 +40,7 @@ class TelegramEventObserver:
|
|||
# with dummy callback which never will be used
|
||||
self._handler = HandlerObject(callback=lambda: True, filters=[])
|
||||
|
||||
def filter(self, *filters: FilterType, **bound_filters: Any) -> None:
|
||||
def filter(self, *filters: CallbackType, **bound_filters: Any) -> None:
|
||||
"""
|
||||
Register filter for all handlers of this event observer
|
||||
|
||||
|
|
@ -51,7 +51,13 @@ class TelegramEventObserver:
|
|||
if self._handler.filters is None:
|
||||
self._handler.filters = []
|
||||
self._handler.filters.extend(
|
||||
[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)]
|
||||
[
|
||||
FilterObject(filter_) # type: ignore
|
||||
for filter_ in chain(
|
||||
resolved_filters,
|
||||
filters,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
||||
|
|
@ -96,7 +102,7 @@ class TelegramEventObserver:
|
|||
|
||||
def resolve_filters(
|
||||
self,
|
||||
filters: Tuple[FilterType, ...],
|
||||
filters: Tuple[CallbackType, ...],
|
||||
full_config: Dict[str, Any],
|
||||
ignore_default: bool = True,
|
||||
) -> List[BaseFilter]:
|
||||
|
|
@ -158,11 +164,11 @@ class TelegramEventObserver:
|
|||
|
||||
def register(
|
||||
self,
|
||||
callback: HandlerType,
|
||||
*filters: FilterType,
|
||||
callback: CallbackType,
|
||||
*filters: CallbackType,
|
||||
flags: Optional[Dict[str, Any]] = None,
|
||||
**bound_filters: Any,
|
||||
) -> HandlerType:
|
||||
) -> CallbackType:
|
||||
"""
|
||||
Register event handler
|
||||
"""
|
||||
|
|
@ -174,7 +180,13 @@ class TelegramEventObserver:
|
|||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback,
|
||||
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
|
||||
filters=[
|
||||
FilterObject(filter_) # type: ignore
|
||||
for filter_ in chain(
|
||||
resolved_filters,
|
||||
filters,
|
||||
)
|
||||
],
|
||||
flags=flags,
|
||||
)
|
||||
)
|
||||
|
|
@ -216,7 +228,7 @@ class TelegramEventObserver:
|
|||
return UNHANDLED
|
||||
|
||||
def __call__(
|
||||
self, *args: FilterType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any
|
||||
self, *args: CallbackType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any
|
||||
) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from .chat_member_updated import (
|
|||
from .command import Command, CommandObject
|
||||
from .content_types import ContentTypesFilter
|
||||
from .exception import ExceptionMessageFilter, ExceptionTypeFilter
|
||||
from .logic import and_f, invert_f, or_f
|
||||
from .magic_data import MagicData
|
||||
from .state import StateFilter
|
||||
from .text import Text
|
||||
|
|
@ -47,6 +48,9 @@ __all__ = (
|
|||
"IS_NOT_MEMBER",
|
||||
"JOIN_TRANSITION",
|
||||
"LEAVE_TRANSITION",
|
||||
"and_f",
|
||||
"or_f",
|
||||
"invert_f",
|
||||
)
|
||||
|
||||
_ALL_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (MagicData,)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aiogram.dispatcher.filters.logic import _LogicFilter
|
||||
|
||||
class BaseFilter(ABC, BaseModel):
|
||||
|
||||
class BaseFilter(BaseModel, ABC, _LogicFilter):
|
||||
"""
|
||||
If you want to register own filters like builtin filters you will need to write subclass
|
||||
of this class with overriding the :code:`__call__`
|
||||
|
|
|
|||
87
aiogram/dispatcher/filters/logic.py
Normal file
87
aiogram/dispatcher/filters/logic.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.dispatcher.event.handler import CallbackType, FilterObject
|
||||
|
||||
|
||||
class _LogicFilter:
|
||||
__call__: Callable[..., Awaitable[Union[bool, Dict[str, Any]]]]
|
||||
|
||||
def __and__(self, other: "CallbackType") -> "_AndFilter":
|
||||
return and_f(self, other)
|
||||
|
||||
def __or__(self, other: "CallbackType") -> "_OrFilter":
|
||||
return or_f(self, other)
|
||||
|
||||
def __invert__(self) -> "_InvertFilter":
|
||||
return invert_f(self)
|
||||
|
||||
def __await__(self): # type: ignore # pragma: no cover
|
||||
# Is needed only for inspection and this method is never be called
|
||||
return self.__call__
|
||||
|
||||
|
||||
class _InvertFilter(_LogicFilter):
|
||||
__slots__ = ("target",)
|
||||
|
||||
def __init__(self, target: "FilterObject") -> None:
|
||||
self.target = target
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return not bool(await self.target.call(*args, **kwargs))
|
||||
|
||||
|
||||
class _AndFilter(_LogicFilter):
|
||||
__slots__ = ("targets",)
|
||||
|
||||
def __init__(self, *targets: "FilterObject") -> None:
|
||||
self.targets = targets
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
final_result = {}
|
||||
|
||||
for target in self.targets:
|
||||
result = await target.call(*args, **kwargs)
|
||||
if not result:
|
||||
return False
|
||||
if isinstance(result, dict):
|
||||
final_result.update(result)
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
return True
|
||||
|
||||
|
||||
class _OrFilter(_LogicFilter):
|
||||
__slots__ = ("targets",)
|
||||
|
||||
def __init__(self, *targets: "FilterObject") -> None:
|
||||
self.targets = targets
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
for target in self.targets:
|
||||
result = await target.call(*args, **kwargs)
|
||||
if not result:
|
||||
continue
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return bool(result)
|
||||
return False
|
||||
|
||||
|
||||
def and_f(target1: "CallbackType", target2: "CallbackType") -> _AndFilter:
|
||||
from aiogram.dispatcher.event.handler import FilterObject
|
||||
|
||||
return _AndFilter(FilterObject(target1), FilterObject(target2))
|
||||
|
||||
|
||||
def or_f(target1: "CallbackType", target2: "CallbackType") -> _OrFilter:
|
||||
from aiogram.dispatcher.event.handler import FilterObject
|
||||
|
||||
return _OrFilter(FilterObject(target1), FilterObject(target2))
|
||||
|
||||
|
||||
def invert_f(target: "CallbackType") -> _InvertFilter:
|
||||
from aiogram.dispatcher.event.handler import FilterObject
|
||||
|
||||
return _InvertFilter(FilterObject(target))
|
||||
|
|
@ -2,7 +2,7 @@ import functools
|
|||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
|
||||
|
||||
from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType
|
||||
from aiogram.dispatcher.event.handler import HandlerType
|
||||
from aiogram.dispatcher.event.handler import CallbackType
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
|||
|
||||
@staticmethod
|
||||
def wrap_middlewares(
|
||||
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType
|
||||
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: CallbackType
|
||||
) -> NextMiddlewareType[MiddlewareEventType]:
|
||||
@functools.wraps(handler)
|
||||
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from ..utils.imports import import_module
|
|||
from ..utils.warnings import CodeHasNoEffect
|
||||
from .event.bases import REJECTED, UNHANDLED
|
||||
from .event.event import EventObserver
|
||||
from .event.handler import HandlerType
|
||||
from .event.handler import CallbackType
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from .filters import BUILTIN_FILTERS
|
||||
|
||||
|
|
@ -396,7 +396,7 @@ class Router:
|
|||
)
|
||||
return self.errors
|
||||
|
||||
def register_message(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_message(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_message(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.message.register(...)`",
|
||||
|
|
@ -405,7 +405,7 @@ class Router:
|
|||
)
|
||||
return self.message.register(*args, **kwargs)
|
||||
|
||||
def register_edited_message(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_edited_message(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_edited_message(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.edited_message.register(...)`",
|
||||
|
|
@ -414,7 +414,7 @@ class Router:
|
|||
)
|
||||
return self.edited_message.register(*args, **kwargs)
|
||||
|
||||
def register_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_channel_post(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_channel_post(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.channel_post.register(...)`",
|
||||
|
|
@ -423,7 +423,7 @@ class Router:
|
|||
)
|
||||
return self.channel_post.register(*args, **kwargs)
|
||||
|
||||
def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_edited_channel_post(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.edited_channel_post.register(...)`",
|
||||
|
|
@ -432,7 +432,7 @@ class Router:
|
|||
)
|
||||
return self.edited_channel_post.register(*args, **kwargs)
|
||||
|
||||
def register_inline_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_inline_query(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_inline_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.inline_query.register(...)`",
|
||||
|
|
@ -441,7 +441,7 @@ class Router:
|
|||
)
|
||||
return self.inline_query.register(*args, **kwargs)
|
||||
|
||||
def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_chosen_inline_result(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chosen_inline_result.register(...)`",
|
||||
|
|
@ -450,7 +450,7 @@ class Router:
|
|||
)
|
||||
return self.chosen_inline_result.register(*args, **kwargs)
|
||||
|
||||
def register_callback_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_callback_query(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_callback_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.callback_query.register(...)`",
|
||||
|
|
@ -459,7 +459,7 @@ class Router:
|
|||
)
|
||||
return self.callback_query.register(*args, **kwargs)
|
||||
|
||||
def register_shipping_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_shipping_query(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_shipping_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.shipping_query.register(...)`",
|
||||
|
|
@ -468,7 +468,7 @@ class Router:
|
|||
)
|
||||
return self.shipping_query.register(*args, **kwargs)
|
||||
|
||||
def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_pre_checkout_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.pre_checkout_query.register(...)`",
|
||||
|
|
@ -477,7 +477,7 @@ class Router:
|
|||
)
|
||||
return self.pre_checkout_query.register(*args, **kwargs)
|
||||
|
||||
def register_poll(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_poll(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_poll(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.poll.register(...)`",
|
||||
|
|
@ -486,7 +486,7 @@ class Router:
|
|||
)
|
||||
return self.poll.register(*args, **kwargs)
|
||||
|
||||
def register_poll_answer(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_poll_answer(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_poll_answer(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.poll_answer.register(...)`",
|
||||
|
|
@ -495,7 +495,7 @@ class Router:
|
|||
)
|
||||
return self.poll_answer.register(*args, **kwargs)
|
||||
|
||||
def register_my_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_my_chat_member(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_my_chat_member(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.my_chat_member.register(...)`",
|
||||
|
|
@ -504,7 +504,7 @@ class Router:
|
|||
)
|
||||
return self.my_chat_member.register(*args, **kwargs)
|
||||
|
||||
def register_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_chat_member(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_chat_member(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chat_member.register(...)`",
|
||||
|
|
@ -513,7 +513,7 @@ class Router:
|
|||
)
|
||||
return self.chat_member.register(*args, **kwargs)
|
||||
|
||||
def register_chat_join_request(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_chat_join_request(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_chat_join_request(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chat_join_request.register(...)`",
|
||||
|
|
@ -522,7 +522,7 @@ class Router:
|
|||
)
|
||||
return self.chat_join_request.register(*args, **kwargs)
|
||||
|
||||
def register_errors(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
def register_errors(self, *args: Any, **kwargs: Any) -> CallbackType:
|
||||
warnings.warn(
|
||||
"`Router.register_errors(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.errors.register(...)`",
|
||||
|
|
|
|||
37
tests/test_dispatcher/test_filters/test_logic.py
Normal file
37
tests/test_dispatcher/test_filters/test_logic.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.dispatcher.filters import Text, and_f, invert_f, or_f
|
||||
from aiogram.dispatcher.filters.logic import _AndFilter, _InvertFilter, _OrFilter
|
||||
|
||||
|
||||
class TestLogic:
|
||||
@pytest.mark.parametrize(
|
||||
"obj,case,result",
|
||||
[
|
||||
[True, and_f(lambda t: t is True, lambda t: t is True), True],
|
||||
[True, and_f(lambda t: t is True, lambda t: t is False), False],
|
||||
[True, and_f(lambda t: t is False, lambda t: t is False), False],
|
||||
[True, and_f(lambda t: {"t": t}, lambda t: t is False), False],
|
||||
[True, and_f(lambda t: {"t": t}, lambda t: t is True), {"t": True}],
|
||||
[True, or_f(lambda t: t is True, lambda t: t is True), True],
|
||||
[True, or_f(lambda t: t is True, lambda t: t is False), True],
|
||||
[True, or_f(lambda t: t is False, lambda t: t is False), False],
|
||||
[True, or_f(lambda t: t is False, lambda t: t is True), True],
|
||||
[True, or_f(lambda t: t is False, lambda t: {"t": t}), {"t": True}],
|
||||
[True, or_f(lambda t: {"t": t}, lambda t: {"a": 42}), {"t": True}],
|
||||
[True, invert_f(lambda t: t is False), True],
|
||||
],
|
||||
)
|
||||
async def test_logic(self, obj, case, result):
|
||||
assert await case(obj) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case,type_",
|
||||
[
|
||||
[Text(text="test") | Text(text="test"), _OrFilter],
|
||||
[Text(text="test") & Text(text="test"), _AndFilter],
|
||||
[~Text(text="test"), _InvertFilter],
|
||||
],
|
||||
)
|
||||
def test_dunder_methods(self, case, type_):
|
||||
assert isinstance(case, type_)
|
||||
Loading…
Add table
Add a link
Reference in a new issue