mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Add text filter and mechanism for registering builtin filters
This commit is contained in:
parent
e37395b161
commit
40b6a61e70
9 changed files with 398 additions and 8 deletions
|
|
@ -0,0 +1,20 @@
|
|||
from typing import Dict, Tuple, Union
|
||||
|
||||
from .base import BaseFilter
|
||||
from .text import Text
|
||||
|
||||
__all__ = ("BUILTIN_FILTERS", "BaseFilter", "Text")
|
||||
|
||||
BUILTIN_FILTERS: Dict[str, Union[Tuple[BaseFilter], Tuple]] = {
|
||||
"update": (),
|
||||
"message": (Text,),
|
||||
"edited_message": (Text,),
|
||||
"channel_post": (Text,),
|
||||
"edited_channel_post": (Text,),
|
||||
"inline_query": (Text,),
|
||||
"chosen_inline_result": (),
|
||||
"callback_query": (Text,),
|
||||
"shipping_query": (),
|
||||
"pre_checkout_query": (),
|
||||
"poll": (),
|
||||
}
|
||||
80
aiogram/dispatcher/filters/text.py
Normal file
80
aiogram/dispatcher/filters/text.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
|
||||
from aiogram.dispatcher.filters import BaseFilter
|
||||
|
||||
|
||||
class Text(BaseFilter):
|
||||
text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text_ignore_case: bool = False
|
||||
|
||||
@root_validator
|
||||
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Validate that only one text filter type is presented
|
||||
used_args = set(
|
||||
key for key, value in values.items() if key != "text_ignore_case" and value is not None
|
||||
)
|
||||
if len(used_args) < 1:
|
||||
raise ValueError(
|
||||
"Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}"
|
||||
)
|
||||
if len(used_args) > 1:
|
||||
raise ValueError(f"Arguments {used_args} cannot be used together")
|
||||
|
||||
# Convert single value to list
|
||||
for arg in used_args:
|
||||
if isinstance(values[arg], str):
|
||||
values[arg] = [values[arg]]
|
||||
|
||||
return values
|
||||
|
||||
async def __call__(
|
||||
self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]
|
||||
) -> Union[bool, Dict[str, Any]]:
|
||||
if isinstance(obj, Message):
|
||||
text = obj.text or obj.caption or ""
|
||||
if not text and obj.poll:
|
||||
text = obj.poll.question
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
text = obj.data
|
||||
elif isinstance(obj, InlineQuery):
|
||||
text = obj.query
|
||||
elif isinstance(obj, Poll):
|
||||
text = obj.question
|
||||
else:
|
||||
return False
|
||||
|
||||
if not text:
|
||||
return False
|
||||
if self.text_ignore_case:
|
||||
text = text.lower()
|
||||
|
||||
if self.text is not None:
|
||||
equals = list(map(self.prepare_text, self.text))
|
||||
return text in equals
|
||||
|
||||
if self.text_contains is not None:
|
||||
contains = list(map(self.prepare_text, self.text_contains))
|
||||
return all(map(text.__contains__, contains))
|
||||
|
||||
if self.text_startswith is not None:
|
||||
startswith = list(map(self.prepare_text, self.text_startswith))
|
||||
return any(map(text.startswith, startswith))
|
||||
|
||||
if self.text_endswith is not None:
|
||||
endswith = list(map(self.prepare_text, self.text_endswith))
|
||||
return any(map(text.endswith, endswith))
|
||||
|
||||
# Impossible because the validator prevents this situation
|
||||
return False # pragma: no cover
|
||||
|
||||
def prepare_text(self, text: str):
|
||||
if self.text_ignore_case:
|
||||
return str(text).lower()
|
||||
else:
|
||||
return str(text)
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..api.types import Chat, Update, User
|
||||
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
|
||||
from .filters import BUILTIN_FILTERS
|
||||
|
||||
|
||||
class Router:
|
||||
|
|
@ -38,7 +39,8 @@ class Router:
|
|||
self.startup = EventObserver()
|
||||
self.shutdown = EventObserver()
|
||||
|
||||
self.observers = {
|
||||
self.observers: Dict[str, TelegramEventObserver] = {
|
||||
"update": self.update_handler,
|
||||
"message": self.message_handler,
|
||||
"edited_message": self.edited_message_handler,
|
||||
"channel_post": self.channel_post_handler,
|
||||
|
|
@ -52,6 +54,9 @@ class Router:
|
|||
}
|
||||
|
||||
self.update_handler.register(self._listen_update)
|
||||
for name, observer in self.observers.items():
|
||||
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
|
||||
observer.bind_filter(builtin_filter)
|
||||
|
||||
@property
|
||||
def parent_router(self) -> Optional[Router]:
|
||||
|
|
|
|||
47
docs/dispatcher/filters/text.md
Normal file
47
docs/dispatcher/filters/text.md
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
# Text
|
||||
|
||||
This filter can be used for filter text [Message](../../api/types/message.md),
|
||||
any [CallbackQuery](../../api/types/callback_query.md) with `data`,
|
||||
[InlineQuery](../../api/types/inline_query.md) or
|
||||
[Poll](../../api/types/poll.md) question.
|
||||
|
||||
Can be imported:
|
||||
|
||||
- `#!python3 from aiogram.dispatcher.filters.text import Text`
|
||||
- `#!python3 from aiogram.dispatcher.filters import Text`
|
||||
- `#!python3 from aiogram.filters import Text`
|
||||
|
||||
## Specification
|
||||
|
||||
| Argument | Type | Description |
|
||||
| --- | --- | --- |
|
||||
| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values |
|
||||
| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values |
|
||||
| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values |
|
||||
| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values |
|
||||
| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) |
|
||||
|
||||
!!! warning
|
||||
|
||||
Only one of `text`, `text_contains`, `text_startswith` or `text_endswith` argument can be used at once.
|
||||
Any of that arguments can be string, list, set or tuple of strings.
|
||||
|
||||
## Usage
|
||||
|
||||
1. Text equals with the specified value: `#!python3 Text(text="text") # value == 'text'`
|
||||
1. Text starts with the specified value: `#!python3 Text(text_startswith="text") # value.startswith('text')`
|
||||
1. Text ends with the specified value: `#!python3 Text(text_endswith="text") # value.endswith('text')`
|
||||
1. Text contains the specified value: `#!python3 Text(text_endswith="text") # value in 'text'`
|
||||
1. Any of previous listed filters can be list, set or tuple of strings that's mean any of listed value should be equals/startswith/endswith/contains: `#!python3 Text(text=["text", "spam"])`
|
||||
1. Ignore case can be combined with any previous listed filter: `#!python3 Text(text="Text", text_ignore_case=True) # value.lower() == 'text'.lower()`
|
||||
|
||||
## Allowed handlers
|
||||
|
||||
Allowed update types for this filter:
|
||||
|
||||
- `message`
|
||||
- `edited_message`
|
||||
- `channel_post`
|
||||
- `edited_channel_post`
|
||||
- `inline_query`
|
||||
- `callback_query`
|
||||
|
|
@ -39,7 +39,7 @@ In this handler can be bounded filters which can be used as keyword arguments in
|
|||
|
||||
### Registering bound filters
|
||||
|
||||
Bound filter should be subclass of [BaseFilter](filters/base_filter.md)
|
||||
Bound filter should be subclass of [BaseFilter](filters/base.md)
|
||||
|
||||
`#!python3 <observer>.bind_filter(MyFilter)`
|
||||
|
||||
|
|
|
|||
|
|
@ -224,7 +224,8 @@ nav:
|
|||
- dispatcher/observer.md
|
||||
- Filters:
|
||||
- dispatcher/filters/index.md
|
||||
- dispatcher/filters/base_filter.md
|
||||
- dispatcher/filters/base.md
|
||||
- dispatcher/filters/text.md
|
||||
|
||||
- Build reports:
|
||||
- reports.md
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ class TestTelegramEventObserver:
|
|||
|
||||
event_observer.bind_filter(MyFilter)
|
||||
assert event_observer.filters
|
||||
assert event_observer.filters[0] == MyFilter
|
||||
assert MyFilter in event_observer.filters
|
||||
|
||||
def test_resolve_filters_chain(self):
|
||||
router1 = Router()
|
||||
|
|
@ -157,9 +157,13 @@ class TestTelegramEventObserver:
|
|||
filters_chain2 = list(router2.message_handler._resolve_filters_chain())
|
||||
filters_chain3 = list(router3.message_handler._resolve_filters_chain())
|
||||
|
||||
assert filters_chain1 == [MyFilter1, MyFilter2]
|
||||
assert filters_chain2 == [MyFilter2, MyFilter1]
|
||||
assert filters_chain3 == [MyFilter3, MyFilter2, MyFilter1]
|
||||
assert MyFilter1 in filters_chain1
|
||||
assert MyFilter1 in filters_chain2
|
||||
assert MyFilter1 in filters_chain3
|
||||
assert MyFilter2 in filters_chain1
|
||||
assert MyFilter2 in filters_chain2
|
||||
assert MyFilter2 in filters_chain3
|
||||
assert MyFilter3 in filters_chain3
|
||||
|
||||
def test_resolve_filters(self):
|
||||
router = Router()
|
||||
|
|
|
|||
233
tests/test_dispatcher/test_filters/test_text.py
Normal file
233
tests/test_dispatcher/test_filters/test_text.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
import datetime
|
||||
from itertools import permutations
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiogram.api.types import CallbackQuery, Chat, InlineQuery, Message, Poll, PollOption, User
|
||||
from aiogram.dispatcher.filters import BUILTIN_FILTERS
|
||||
from aiogram.dispatcher.filters.text import Text
|
||||
|
||||
|
||||
class TestText:
|
||||
def test_default_for_observer(self):
|
||||
registered_for = {
|
||||
update_type for update_type, filters in BUILTIN_FILTERS.items() if Text in filters
|
||||
}
|
||||
assert registered_for == {
|
||||
"message",
|
||||
"edited_message",
|
||||
"channel_post",
|
||||
"edited_channel_post",
|
||||
"inline_query",
|
||||
"callback_query",
|
||||
}
|
||||
|
||||
def test_validator_not_enough_arguments(self):
|
||||
with pytest.raises(ValidationError):
|
||||
Text()
|
||||
with pytest.raises(ValidationError):
|
||||
Text(text_ignore_case=True)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"first,last",
|
||||
permutations(["text", "text_contains", "text_startswith", "text_endswith"], 2),
|
||||
)
|
||||
@pytest.mark.parametrize("ignore_case", [True, False])
|
||||
def test_validator_too_few_arguments(self, first, last, ignore_case):
|
||||
kwargs = {first: "test", last: "test"}
|
||||
if ignore_case:
|
||||
kwargs["text_ignore_case"] = True
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Text(**kwargs)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
|
||||
)
|
||||
@pytest.mark.parametrize("input_type", [str, list, tuple, set])
|
||||
def test_validator_convert_to_list(self, argument: str, input_type: Type):
|
||||
text = Text(**{argument: input_type("test")})
|
||||
assert hasattr(text, argument)
|
||||
assert isinstance(getattr(text, argument), list)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"argument,ignore_case,input_value,update_type,result",
|
||||
[
|
||||
[
|
||||
"text",
|
||||
False,
|
||||
"test",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
False,
|
||||
],
|
||||
[
|
||||
"text",
|
||||
False,
|
||||
"test",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
caption="test",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text",
|
||||
False,
|
||||
"test",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
text="test",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text",
|
||||
True,
|
||||
"TEst",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
text="tesT",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text",
|
||||
False,
|
||||
"TEst",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
text="tesT",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
False,
|
||||
],
|
||||
[
|
||||
"text_startswith",
|
||||
False,
|
||||
"test",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
text="test case",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text_endswith",
|
||||
False,
|
||||
"case",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
text="test case",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text_contains",
|
||||
False,
|
||||
" ",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
text="test case",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text_startswith",
|
||||
True,
|
||||
"question",
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
poll=Poll(
|
||||
id="poll id",
|
||||
question="Question?",
|
||||
options=[PollOption(text="A", voter_count=0)],
|
||||
is_closed=False,
|
||||
),
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text_startswith",
|
||||
True,
|
||||
"callback:",
|
||||
CallbackQuery(
|
||||
id="query id",
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
chat_instance="instance",
|
||||
data="callback:data",
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text_startswith",
|
||||
True,
|
||||
"query",
|
||||
InlineQuery(
|
||||
id="query id",
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
query="query line",
|
||||
offset="offset",
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text",
|
||||
True,
|
||||
"question",
|
||||
Poll(
|
||||
id="poll id",
|
||||
question="Question",
|
||||
options=[PollOption(text="A", voter_count=0)],
|
||||
is_closed=False,
|
||||
),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"text",
|
||||
True,
|
||||
["question", "another question"],
|
||||
Poll(
|
||||
id="poll id",
|
||||
question="Another question",
|
||||
options=[PollOption(text="A", voter_count=0)],
|
||||
is_closed=False,
|
||||
),
|
||||
True,
|
||||
],
|
||||
["text", True, ["question", "another question"], object(), False],
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_text(self, argument, ignore_case, input_value, result, update_type):
|
||||
text = Text(**{argument: input_value}, text_ignore_case=ignore_case)
|
||||
assert await text(obj=update_type) is result
|
||||
Loading…
Add table
Add a link
Reference in a new issue