mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
Add text filter and mechanism for registering builtin filters
This commit is contained in:
parent
e37395b161
commit
40b6a61e70
9 changed files with 398 additions and 8 deletions
|
|
@ -0,0 +1,20 @@
|
||||||
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
|
from .base import BaseFilter
|
||||||
|
from .text import Text
|
||||||
|
|
||||||
|
__all__ = ("BUILTIN_FILTERS", "BaseFilter", "Text")
|
||||||
|
|
||||||
|
BUILTIN_FILTERS: Dict[str, Union[Tuple[BaseFilter], Tuple]] = {
|
||||||
|
"update": (),
|
||||||
|
"message": (Text,),
|
||||||
|
"edited_message": (Text,),
|
||||||
|
"channel_post": (Text,),
|
||||||
|
"edited_channel_post": (Text,),
|
||||||
|
"inline_query": (Text,),
|
||||||
|
"chosen_inline_result": (),
|
||||||
|
"callback_query": (Text,),
|
||||||
|
"shipping_query": (),
|
||||||
|
"pre_checkout_query": (),
|
||||||
|
"poll": (),
|
||||||
|
}
|
||||||
80
aiogram/dispatcher/filters/text.py
Normal file
80
aiogram/dispatcher/filters/text.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
|
||||||
|
from aiogram.dispatcher.filters import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
|
class Text(BaseFilter):
|
||||||
|
text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||||
|
text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||||
|
text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||||
|
text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||||
|
text_ignore_case: bool = False
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
# Validate that only one text filter type is presented
|
||||||
|
used_args = set(
|
||||||
|
key for key, value in values.items() if key != "text_ignore_case" and value is not None
|
||||||
|
)
|
||||||
|
if len(used_args) < 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}"
|
||||||
|
)
|
||||||
|
if len(used_args) > 1:
|
||||||
|
raise ValueError(f"Arguments {used_args} cannot be used together")
|
||||||
|
|
||||||
|
# Convert single value to list
|
||||||
|
for arg in used_args:
|
||||||
|
if isinstance(values[arg], str):
|
||||||
|
values[arg] = [values[arg]]
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]
|
||||||
|
) -> Union[bool, Dict[str, Any]]:
|
||||||
|
if isinstance(obj, Message):
|
||||||
|
text = obj.text or obj.caption or ""
|
||||||
|
if not text and obj.poll:
|
||||||
|
text = obj.poll.question
|
||||||
|
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||||
|
text = obj.data
|
||||||
|
elif isinstance(obj, InlineQuery):
|
||||||
|
text = obj.query
|
||||||
|
elif isinstance(obj, Poll):
|
||||||
|
text = obj.question
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return False
|
||||||
|
if self.text_ignore_case:
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
|
if self.text is not None:
|
||||||
|
equals = list(map(self.prepare_text, self.text))
|
||||||
|
return text in equals
|
||||||
|
|
||||||
|
if self.text_contains is not None:
|
||||||
|
contains = list(map(self.prepare_text, self.text_contains))
|
||||||
|
return all(map(text.__contains__, contains))
|
||||||
|
|
||||||
|
if self.text_startswith is not None:
|
||||||
|
startswith = list(map(self.prepare_text, self.text_startswith))
|
||||||
|
return any(map(text.startswith, startswith))
|
||||||
|
|
||||||
|
if self.text_endswith is not None:
|
||||||
|
endswith = list(map(self.prepare_text, self.text_endswith))
|
||||||
|
return any(map(text.endswith, endswith))
|
||||||
|
|
||||||
|
# Impossible because the validator prevents this situation
|
||||||
|
return False # pragma: no cover
|
||||||
|
|
||||||
|
def prepare_text(self, text: str):
|
||||||
|
if self.text_ignore_case:
|
||||||
|
return str(text).lower()
|
||||||
|
else:
|
||||||
|
return str(text)
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from ..api.types import Chat, Update, User
|
from ..api.types import Chat, Update, User
|
||||||
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
|
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
|
||||||
|
from .filters import BUILTIN_FILTERS
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
|
|
@ -38,7 +39,8 @@ class Router:
|
||||||
self.startup = EventObserver()
|
self.startup = EventObserver()
|
||||||
self.shutdown = EventObserver()
|
self.shutdown = EventObserver()
|
||||||
|
|
||||||
self.observers = {
|
self.observers: Dict[str, TelegramEventObserver] = {
|
||||||
|
"update": self.update_handler,
|
||||||
"message": self.message_handler,
|
"message": self.message_handler,
|
||||||
"edited_message": self.edited_message_handler,
|
"edited_message": self.edited_message_handler,
|
||||||
"channel_post": self.channel_post_handler,
|
"channel_post": self.channel_post_handler,
|
||||||
|
|
@ -52,6 +54,9 @@ class Router:
|
||||||
}
|
}
|
||||||
|
|
||||||
self.update_handler.register(self._listen_update)
|
self.update_handler.register(self._listen_update)
|
||||||
|
for name, observer in self.observers.items():
|
||||||
|
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
|
||||||
|
observer.bind_filter(builtin_filter)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parent_router(self) -> Optional[Router]:
|
def parent_router(self) -> Optional[Router]:
|
||||||
|
|
|
||||||
47
docs/dispatcher/filters/text.md
Normal file
47
docs/dispatcher/filters/text.md
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Text
|
||||||
|
|
||||||
|
This filter can be used for filter text [Message](../../api/types/message.md),
|
||||||
|
any [CallbackQuery](../../api/types/callback_query.md) with `data`,
|
||||||
|
[InlineQuery](../../api/types/inline_query.md) or
|
||||||
|
[Poll](../../api/types/poll.md) question.
|
||||||
|
|
||||||
|
Can be imported:
|
||||||
|
|
||||||
|
- `#!python3 from aiogram.dispatcher.filters.text import Text`
|
||||||
|
- `#!python3 from aiogram.dispatcher.filters import Text`
|
||||||
|
- `#!python3 from aiogram.filters import Text`
|
||||||
|
|
||||||
|
## Specification
|
||||||
|
|
||||||
|
| Argument | Type | Description |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values |
|
||||||
|
| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values |
|
||||||
|
| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values |
|
||||||
|
| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values |
|
||||||
|
| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) |
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
|
||||||
|
Only one of `text`, `text_contains`, `text_startswith` or `text_endswith` argument can be used at once.
|
||||||
|
Any of that arguments can be string, list, set or tuple of strings.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
1. Text equals with the specified value: `#!python3 Text(text="text") # value == 'text'`
|
||||||
|
1. Text starts with the specified value: `#!python3 Text(text_startswith="text") # value.startswith('text')`
|
||||||
|
1. Text ends with the specified value: `#!python3 Text(text_endswith="text") # value.endswith('text')`
|
||||||
|
1. Text contains the specified value: `#!python3 Text(text_endswith="text") # value in 'text'`
|
||||||
|
1. Any of previous listed filters can be list, set or tuple of strings that's mean any of listed value should be equals/startswith/endswith/contains: `#!python3 Text(text=["text", "spam"])`
|
||||||
|
1. Ignore case can be combined with any previous listed filter: `#!python3 Text(text="Text", text_ignore_case=True) # value.lower() == 'text'.lower()`
|
||||||
|
|
||||||
|
## Allowed handlers
|
||||||
|
|
||||||
|
Allowed update types for this filter:
|
||||||
|
|
||||||
|
- `message`
|
||||||
|
- `edited_message`
|
||||||
|
- `channel_post`
|
||||||
|
- `edited_channel_post`
|
||||||
|
- `inline_query`
|
||||||
|
- `callback_query`
|
||||||
|
|
@ -39,7 +39,7 @@ In this handler can be bounded filters which can be used as keyword arguments in
|
||||||
|
|
||||||
### Registering bound filters
|
### Registering bound filters
|
||||||
|
|
||||||
Bound filter should be subclass of [BaseFilter](filters/base_filter.md)
|
Bound filter should be subclass of [BaseFilter](filters/base.md)
|
||||||
|
|
||||||
`#!python3 <observer>.bind_filter(MyFilter)`
|
`#!python3 <observer>.bind_filter(MyFilter)`
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -224,7 +224,8 @@ nav:
|
||||||
- dispatcher/observer.md
|
- dispatcher/observer.md
|
||||||
- Filters:
|
- Filters:
|
||||||
- dispatcher/filters/index.md
|
- dispatcher/filters/index.md
|
||||||
- dispatcher/filters/base_filter.md
|
- dispatcher/filters/base.md
|
||||||
|
- dispatcher/filters/text.md
|
||||||
|
|
||||||
- Build reports:
|
- Build reports:
|
||||||
- reports.md
|
- reports.md
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,7 @@ class TestTelegramEventObserver:
|
||||||
|
|
||||||
event_observer.bind_filter(MyFilter)
|
event_observer.bind_filter(MyFilter)
|
||||||
assert event_observer.filters
|
assert event_observer.filters
|
||||||
assert event_observer.filters[0] == MyFilter
|
assert MyFilter in event_observer.filters
|
||||||
|
|
||||||
def test_resolve_filters_chain(self):
|
def test_resolve_filters_chain(self):
|
||||||
router1 = Router()
|
router1 = Router()
|
||||||
|
|
@ -157,9 +157,13 @@ class TestTelegramEventObserver:
|
||||||
filters_chain2 = list(router2.message_handler._resolve_filters_chain())
|
filters_chain2 = list(router2.message_handler._resolve_filters_chain())
|
||||||
filters_chain3 = list(router3.message_handler._resolve_filters_chain())
|
filters_chain3 = list(router3.message_handler._resolve_filters_chain())
|
||||||
|
|
||||||
assert filters_chain1 == [MyFilter1, MyFilter2]
|
assert MyFilter1 in filters_chain1
|
||||||
assert filters_chain2 == [MyFilter2, MyFilter1]
|
assert MyFilter1 in filters_chain2
|
||||||
assert filters_chain3 == [MyFilter3, MyFilter2, MyFilter1]
|
assert MyFilter1 in filters_chain3
|
||||||
|
assert MyFilter2 in filters_chain1
|
||||||
|
assert MyFilter2 in filters_chain2
|
||||||
|
assert MyFilter2 in filters_chain3
|
||||||
|
assert MyFilter3 in filters_chain3
|
||||||
|
|
||||||
def test_resolve_filters(self):
|
def test_resolve_filters(self):
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|
|
||||||
233
tests/test_dispatcher/test_filters/test_text.py
Normal file
233
tests/test_dispatcher/test_filters/test_text.py
Normal file
|
|
@ -0,0 +1,233 @@
|
||||||
|
import datetime
|
||||||
|
from itertools import permutations
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from aiogram.api.types import CallbackQuery, Chat, InlineQuery, Message, Poll, PollOption, User
|
||||||
|
from aiogram.dispatcher.filters import BUILTIN_FILTERS
|
||||||
|
from aiogram.dispatcher.filters.text import Text
|
||||||
|
|
||||||
|
|
||||||
|
class TestText:
|
||||||
|
def test_default_for_observer(self):
|
||||||
|
registered_for = {
|
||||||
|
update_type for update_type, filters in BUILTIN_FILTERS.items() if Text in filters
|
||||||
|
}
|
||||||
|
assert registered_for == {
|
||||||
|
"message",
|
||||||
|
"edited_message",
|
||||||
|
"channel_post",
|
||||||
|
"edited_channel_post",
|
||||||
|
"inline_query",
|
||||||
|
"callback_query",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_validator_not_enough_arguments(self):
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Text()
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Text(text_ignore_case=True)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"first,last",
|
||||||
|
permutations(["text", "text_contains", "text_startswith", "text_endswith"], 2),
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("ignore_case", [True, False])
|
||||||
|
def test_validator_too_few_arguments(self, first, last, ignore_case):
|
||||||
|
kwargs = {first: "test", last: "test"}
|
||||||
|
if ignore_case:
|
||||||
|
kwargs["text_ignore_case"] = True
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Text(**kwargs)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("input_type", [str, list, tuple, set])
|
||||||
|
def test_validator_convert_to_list(self, argument: str, input_type: Type):
|
||||||
|
text = Text(**{argument: input_type("test")})
|
||||||
|
assert hasattr(text, argument)
|
||||||
|
assert isinstance(getattr(text, argument), list)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"argument,ignore_case,input_value,update_type,result",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
False,
|
||||||
|
"test",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
False,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
False,
|
||||||
|
"test",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
caption="test",
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
False,
|
||||||
|
"test",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
text="test",
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
True,
|
||||||
|
"TEst",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
text="tesT",
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
False,
|
||||||
|
"TEst",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
text="tesT",
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
False,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text_startswith",
|
||||||
|
False,
|
||||||
|
"test",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
text="test case",
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text_endswith",
|
||||||
|
False,
|
||||||
|
"case",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
text="test case",
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text_contains",
|
||||||
|
False,
|
||||||
|
" ",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
text="test case",
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text_startswith",
|
||||||
|
True,
|
||||||
|
"question",
|
||||||
|
Message(
|
||||||
|
message_id=42,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
poll=Poll(
|
||||||
|
id="poll id",
|
||||||
|
question="Question?",
|
||||||
|
options=[PollOption(text="A", voter_count=0)],
|
||||||
|
is_closed=False,
|
||||||
|
),
|
||||||
|
chat=Chat(id=42, type="private"),
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text_startswith",
|
||||||
|
True,
|
||||||
|
"callback:",
|
||||||
|
CallbackQuery(
|
||||||
|
id="query id",
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
chat_instance="instance",
|
||||||
|
data="callback:data",
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text_startswith",
|
||||||
|
True,
|
||||||
|
"query",
|
||||||
|
InlineQuery(
|
||||||
|
id="query id",
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||||
|
query="query line",
|
||||||
|
offset="offset",
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
True,
|
||||||
|
"question",
|
||||||
|
Poll(
|
||||||
|
id="poll id",
|
||||||
|
question="Question",
|
||||||
|
options=[PollOption(text="A", voter_count=0)],
|
||||||
|
is_closed=False,
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
True,
|
||||||
|
["question", "another question"],
|
||||||
|
Poll(
|
||||||
|
id="poll id",
|
||||||
|
question="Another question",
|
||||||
|
options=[PollOption(text="A", voter_count=0)],
|
||||||
|
is_closed=False,
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
["text", True, ["question", "another question"], object(), False],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_text(self, argument, ignore_case, input_value, result, update_type):
|
||||||
|
text = Text(**{argument: input_value}, text_ignore_case=ignore_case)
|
||||||
|
assert await text(obj=update_type) is result
|
||||||
Loading…
Add table
Add a link
Reference in a new issue