From 40b6a61e7020471efa862913dc50989e415622ec Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Fri, 29 Nov 2019 23:16:11 +0200 Subject: [PATCH] Add text filter and mechanism for registering builtin filters --- aiogram/dispatcher/filters/__init__.py | 20 ++ aiogram/dispatcher/filters/text.py | 80 ++++++ aiogram/dispatcher/router.py | 9 +- .../filters/{base_filter.md => base.md} | 0 docs/dispatcher/filters/text.md | 47 ++++ docs/dispatcher/observer.md | 2 +- mkdocs.yml | 3 +- .../test_event/test_observer.py | 12 +- .../test_dispatcher/test_filters/test_text.py | 233 ++++++++++++++++++ 9 files changed, 398 insertions(+), 8 deletions(-) create mode 100644 aiogram/dispatcher/filters/text.py rename docs/dispatcher/filters/{base_filter.md => base.md} (100%) create mode 100644 docs/dispatcher/filters/text.md create mode 100644 tests/test_dispatcher/test_filters/test_text.py diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index e69de29b..d0205309 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -0,0 +1,20 @@ +from typing import Dict, Tuple, Union + +from .base import BaseFilter +from .text import Text + +__all__ = ("BUILTIN_FILTERS", "BaseFilter", "Text") + +BUILTIN_FILTERS: Dict[str, Union[Tuple[BaseFilter], Tuple]] = { + "update": (), + "message": (Text,), + "edited_message": (Text,), + "channel_post": (Text,), + "edited_channel_post": (Text,), + "inline_query": (Text,), + "chosen_inline_result": (), + "callback_query": (Text,), + "shipping_query": (), + "pre_checkout_query": (), + "poll": (), +} diff --git a/aiogram/dispatcher/filters/text.py b/aiogram/dispatcher/filters/text.py new file mode 100644 index 00000000..2de77319 --- /dev/null +++ b/aiogram/dispatcher/filters/text.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from pydantic import root_validator + +from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll +from aiogram.dispatcher.filters import BaseFilter + + +class Text(BaseFilter): + text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None + text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None + text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None + text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None + text_ignore_case: bool = False + + @root_validator + def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: + # Validate that only one text filter type is presented + used_args = set( + key for key, value in values.items() if key != "text_ignore_case" and value is not None + ) + if len(used_args) < 1: + raise ValueError( + "Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}" + ) + if len(used_args) > 1: + raise ValueError(f"Arguments {used_args} cannot be used together") + + # Convert single value to list + for arg in used_args: + if isinstance(values[arg], str): + values[arg] = [values[arg]] + + return values + + async def __call__( + self, obj: Union[Message, CallbackQuery, InlineQuery, Poll] + ) -> Union[bool, Dict[str, Any]]: + if isinstance(obj, Message): + text = obj.text or obj.caption or "" + if not text and obj.poll: + text = obj.poll.question + elif isinstance(obj, CallbackQuery) and obj.data: + text = obj.data + elif isinstance(obj, InlineQuery): + text = obj.query + elif isinstance(obj, Poll): + text = obj.question + else: + return False + + if not text: + return False + if self.text_ignore_case: + text = text.lower() + + if self.text is not None: + equals = list(map(self.prepare_text, self.text)) + return text in equals + + if self.text_contains is not None: + contains = list(map(self.prepare_text, self.text_contains)) + return all(map(text.__contains__, contains)) + + if self.text_startswith is not None: + startswith = list(map(self.prepare_text, self.text_startswith)) + return any(map(text.startswith, startswith)) + + if self.text_endswith is not None: + endswith = list(map(self.prepare_text, self.text_endswith)) + return any(map(text.endswith, endswith)) + + # Impossible because the validator prevents this situation + return False # pragma: no cover + + def prepare_text(self, text: str): + if self.text_ignore_case: + return str(text).lower() + else: + return str(text) diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index 70481595..d8c73f4c 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from ..api.types import Chat, Update, User from .event.observer import EventObserver, SkipHandler, TelegramEventObserver +from .filters import BUILTIN_FILTERS class Router: @@ -38,7 +39,8 @@ class Router: self.startup = EventObserver() self.shutdown = EventObserver() - self.observers = { + self.observers: Dict[str, TelegramEventObserver] = { + "update": self.update_handler, "message": self.message_handler, "edited_message": self.edited_message_handler, "channel_post": self.channel_post_handler, @@ -52,6 +54,9 @@ class Router: } self.update_handler.register(self._listen_update) + for name, observer in self.observers.items(): + for builtin_filter in BUILTIN_FILTERS.get(name, ()): + observer.bind_filter(builtin_filter) @property def parent_router(self) -> Optional[Router]: diff --git a/docs/dispatcher/filters/base_filter.md b/docs/dispatcher/filters/base.md similarity index 100% rename from docs/dispatcher/filters/base_filter.md rename to docs/dispatcher/filters/base.md diff --git a/docs/dispatcher/filters/text.md b/docs/dispatcher/filters/text.md new file mode 100644 index 00000000..1149be73 --- /dev/null +++ b/docs/dispatcher/filters/text.md @@ -0,0 +1,47 @@ +# Text + +This filter can be used for filter text [Message](../../api/types/message.md), +any [CallbackQuery](../../api/types/callback_query.md) with `data`, +[InlineQuery](../../api/types/inline_query.md) or +[Poll](../../api/types/poll.md) question. + +Can be imported: + +- `#!python3 from aiogram.dispatcher.filters.text import Text` +- `#!python3 from aiogram.dispatcher.filters import Text` +- `#!python3 from aiogram.filters import Text` + +## Specification + +| Argument | Type | Description | +| --- | --- | --- | +| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values | +| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values | +| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values | +| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values | +| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) | + +!!! warning + + Only one of `text`, `text_contains`, `text_startswith` or `text_endswith` argument can be used at once. + Any of that arguments can be string, list, set or tuple of strings. + +## Usage + +1. Text equals with the specified value: `#!python3 Text(text="text") # value == 'text'` +1. Text starts with the specified value: `#!python3 Text(text_startswith="text") # value.startswith('text')` +1. Text ends with the specified value: `#!python3 Text(text_endswith="text") # value.endswith('text')` +1. Text contains the specified value: `#!python3 Text(text_endswith="text") # value in 'text'` +1. Any of previous listed filters can be list, set or tuple of strings that's mean any of listed value should be equals/startswith/endswith/contains: `#!python3 Text(text=["text", "spam"])` +1. Ignore case can be combined with any previous listed filter: `#!python3 Text(text="Text", text_ignore_case=True) # value.lower() == 'text'.lower()` + +## Allowed handlers + +Allowed update types for this filter: + +- `message` +- `edited_message` +- `channel_post` +- `edited_channel_post` +- `inline_query` +- `callback_query` diff --git a/docs/dispatcher/observer.md b/docs/dispatcher/observer.md index fc7ff170..d4cc7796 100644 --- a/docs/dispatcher/observer.md +++ b/docs/dispatcher/observer.md @@ -39,7 +39,7 @@ In this handler can be bounded filters which can be used as keyword arguments in ### Registering bound filters -Bound filter should be subclass of [BaseFilter](filters/base_filter.md) +Bound filter should be subclass of [BaseFilter](filters/base.md) `#!python3 .bind_filter(MyFilter)` diff --git a/mkdocs.yml b/mkdocs.yml index 06186614..969c70ac 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -224,7 +224,8 @@ nav: - dispatcher/observer.md - Filters: - dispatcher/filters/index.md - - dispatcher/filters/base_filter.md + - dispatcher/filters/base.md + - dispatcher/filters/text.md - Build reports: - reports.md diff --git a/tests/test_dispatcher/test_event/test_observer.py b/tests/test_dispatcher/test_event/test_observer.py index ec563d9d..73327c62 100644 --- a/tests/test_dispatcher/test_event/test_observer.py +++ b/tests/test_dispatcher/test_event/test_observer.py @@ -139,7 +139,7 @@ class TestTelegramEventObserver: event_observer.bind_filter(MyFilter) assert event_observer.filters - assert event_observer.filters[0] == MyFilter + assert MyFilter in event_observer.filters def test_resolve_filters_chain(self): router1 = Router() @@ -157,9 +157,13 @@ class TestTelegramEventObserver: filters_chain2 = list(router2.message_handler._resolve_filters_chain()) filters_chain3 = list(router3.message_handler._resolve_filters_chain()) - assert filters_chain1 == [MyFilter1, MyFilter2] - assert filters_chain2 == [MyFilter2, MyFilter1] - assert filters_chain3 == [MyFilter3, MyFilter2, MyFilter1] + assert MyFilter1 in filters_chain1 + assert MyFilter1 in filters_chain2 + assert MyFilter1 in filters_chain3 + assert MyFilter2 in filters_chain1 + assert MyFilter2 in filters_chain2 + assert MyFilter2 in filters_chain3 + assert MyFilter3 in filters_chain3 def test_resolve_filters(self): router = Router() diff --git a/tests/test_dispatcher/test_filters/test_text.py b/tests/test_dispatcher/test_filters/test_text.py new file mode 100644 index 00000000..df1d26c1 --- /dev/null +++ b/tests/test_dispatcher/test_filters/test_text.py @@ -0,0 +1,233 @@ +import datetime +from itertools import permutations +from typing import Type + +import pytest +from pydantic import ValidationError + +from aiogram.api.types import CallbackQuery, Chat, InlineQuery, Message, Poll, PollOption, User +from aiogram.dispatcher.filters import BUILTIN_FILTERS +from aiogram.dispatcher.filters.text import Text + + +class TestText: + def test_default_for_observer(self): + registered_for = { + update_type for update_type, filters in BUILTIN_FILTERS.items() if Text in filters + } + assert registered_for == { + "message", + "edited_message", + "channel_post", + "edited_channel_post", + "inline_query", + "callback_query", + } + + def test_validator_not_enough_arguments(self): + with pytest.raises(ValidationError): + Text() + with pytest.raises(ValidationError): + Text(text_ignore_case=True) + + @pytest.mark.parametrize( + "first,last", + permutations(["text", "text_contains", "text_startswith", "text_endswith"], 2), + ) + @pytest.mark.parametrize("ignore_case", [True, False]) + def test_validator_too_few_arguments(self, first, last, ignore_case): + kwargs = {first: "test", last: "test"} + if ignore_case: + kwargs["text_ignore_case"] = True + + with pytest.raises(ValidationError): + Text(**kwargs) + + @pytest.mark.parametrize( + "argument", ["text", "text_contains", "text_startswith", "text_endswith"] + ) + @pytest.mark.parametrize("input_type", [str, list, tuple, set]) + def test_validator_convert_to_list(self, argument: str, input_type: Type): + text = Text(**{argument: input_type("test")}) + assert hasattr(text, argument) + assert isinstance(getattr(text, argument), list) + + @pytest.mark.parametrize( + "argument,ignore_case,input_value,update_type,result", + [ + [ + "text", + False, + "test", + Message( + message_id=42, + date=datetime.datetime.now(), + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + False, + ], + [ + "text", + False, + "test", + Message( + message_id=42, + date=datetime.datetime.now(), + caption="test", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + True, + ], + [ + "text", + False, + "test", + Message( + message_id=42, + date=datetime.datetime.now(), + text="test", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + True, + ], + [ + "text", + True, + "TEst", + Message( + message_id=42, + date=datetime.datetime.now(), + text="tesT", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + True, + ], + [ + "text", + False, + "TEst", + Message( + message_id=42, + date=datetime.datetime.now(), + text="tesT", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + False, + ], + [ + "text_startswith", + False, + "test", + Message( + message_id=42, + date=datetime.datetime.now(), + text="test case", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + True, + ], + [ + "text_endswith", + False, + "case", + Message( + message_id=42, + date=datetime.datetime.now(), + text="test case", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + True, + ], + [ + "text_contains", + False, + " ", + Message( + message_id=42, + date=datetime.datetime.now(), + text="test case", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + True, + ], + [ + "text_startswith", + True, + "question", + Message( + message_id=42, + date=datetime.datetime.now(), + poll=Poll( + id="poll id", + question="Question?", + options=[PollOption(text="A", voter_count=0)], + is_closed=False, + ), + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + True, + ], + [ + "text_startswith", + True, + "callback:", + CallbackQuery( + id="query id", + from_user=User(id=42, is_bot=False, first_name="Test"), + chat_instance="instance", + data="callback:data", + ), + True, + ], + [ + "text_startswith", + True, + "query", + InlineQuery( + id="query id", + from_user=User(id=42, is_bot=False, first_name="Test"), + query="query line", + offset="offset", + ), + True, + ], + [ + "text", + True, + "question", + Poll( + id="poll id", + question="Question", + options=[PollOption(text="A", voter_count=0)], + is_closed=False, + ), + True, + ], + [ + "text", + True, + ["question", "another question"], + Poll( + id="poll id", + question="Another question", + options=[PollOption(text="A", voter_count=0)], + is_closed=False, + ), + True, + ], + ["text", True, ["question", "another question"], object(), False], + ], + ) + @pytest.mark.asyncio + async def test_check_text(self, argument, ignore_case, input_value, result, update_type): + text = Text(**{argument: input_value}, text_ignore_case=ignore_case) + assert await text(obj=update_type) is result