Add text filter and mechanism for registering builtin filters

This commit is contained in:
Alex Root Junior 2019-11-29 23:16:11 +02:00
parent e37395b161
commit 40b6a61e70
9 changed files with 398 additions and 8 deletions

View file

@ -0,0 +1,20 @@
from typing import Dict, Tuple, Union
from .base import BaseFilter
from .text import Text
__all__ = ("BUILTIN_FILTERS", "BaseFilter", "Text")
BUILTIN_FILTERS: Dict[str, Union[Tuple[BaseFilter], Tuple]] = {
"update": (),
"message": (Text,),
"edited_message": (Text,),
"channel_post": (Text,),
"edited_channel_post": (Text,),
"inline_query": (Text,),
"chosen_inline_result": (),
"callback_query": (Text,),
"shipping_query": (),
"pre_checkout_query": (),
"poll": (),
}

View file

@ -0,0 +1,80 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import root_validator
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
from aiogram.dispatcher.filters import BaseFilter
class Text(BaseFilter):
text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_ignore_case: bool = False
@root_validator
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
# Validate that only one text filter type is presented
used_args = set(
key for key, value in values.items() if key != "text_ignore_case" and value is not None
)
if len(used_args) < 1:
raise ValueError(
"Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}"
)
if len(used_args) > 1:
raise ValueError(f"Arguments {used_args} cannot be used together")
# Convert single value to list
for arg in used_args:
if isinstance(values[arg], str):
values[arg] = [values[arg]]
return values
async def __call__(
self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]
) -> Union[bool, Dict[str, Any]]:
if isinstance(obj, Message):
text = obj.text or obj.caption or ""
if not text and obj.poll:
text = obj.poll.question
elif isinstance(obj, CallbackQuery) and obj.data:
text = obj.data
elif isinstance(obj, InlineQuery):
text = obj.query
elif isinstance(obj, Poll):
text = obj.question
else:
return False
if not text:
return False
if self.text_ignore_case:
text = text.lower()
if self.text is not None:
equals = list(map(self.prepare_text, self.text))
return text in equals
if self.text_contains is not None:
contains = list(map(self.prepare_text, self.text_contains))
return all(map(text.__contains__, contains))
if self.text_startswith is not None:
startswith = list(map(self.prepare_text, self.text_startswith))
return any(map(text.startswith, startswith))
if self.text_endswith is not None:
endswith = list(map(self.prepare_text, self.text_endswith))
return any(map(text.endswith, endswith))
# Impossible because the validator prevents this situation
return False # pragma: no cover
def prepare_text(self, text: str):
if self.text_ignore_case:
return str(text).lower()
else:
return str(text)

View file

@ -1,9 +1,10 @@
from __future__ import annotations
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from ..api.types import Chat, Update, User
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
from .filters import BUILTIN_FILTERS
class Router:
@ -38,7 +39,8 @@ class Router:
self.startup = EventObserver()
self.shutdown = EventObserver()
self.observers = {
self.observers: Dict[str, TelegramEventObserver] = {
"update": self.update_handler,
"message": self.message_handler,
"edited_message": self.edited_message_handler,
"channel_post": self.channel_post_handler,
@ -52,6 +54,9 @@ class Router:
}
self.update_handler.register(self._listen_update)
for name, observer in self.observers.items():
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
observer.bind_filter(builtin_filter)
@property
def parent_router(self) -> Optional[Router]:

View file

@ -0,0 +1,47 @@
# Text
This filter can be used for filter text [Message](../../api/types/message.md),
any [CallbackQuery](../../api/types/callback_query.md) with `data`,
[InlineQuery](../../api/types/inline_query.md) or
[Poll](../../api/types/poll.md) question.
Can be imported:
- `#!python3 from aiogram.dispatcher.filters.text import Text`
- `#!python3 from aiogram.dispatcher.filters import Text`
- `#!python3 from aiogram.filters import Text`
## Specification
| Argument | Type | Description |
| --- | --- | --- |
| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values |
| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values |
| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values |
| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values |
| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) |
!!! warning
Only one of `text`, `text_contains`, `text_startswith` or `text_endswith` argument can be used at once.
Any of that arguments can be string, list, set or tuple of strings.
## Usage
1. Text equals with the specified value: `#!python3 Text(text="text") # value == 'text'`
1. Text starts with the specified value: `#!python3 Text(text_startswith="text") # value.startswith('text')`
1. Text ends with the specified value: `#!python3 Text(text_endswith="text") # value.endswith('text')`
1. Text contains the specified value: `#!python3 Text(text_endswith="text") # value in 'text'`
1. Any of previous listed filters can be list, set or tuple of strings that's mean any of listed value should be equals/startswith/endswith/contains: `#!python3 Text(text=["text", "spam"])`
1. Ignore case can be combined with any previous listed filter: `#!python3 Text(text="Text", text_ignore_case=True) # value.lower() == 'text'.lower()`
## Allowed handlers
Allowed update types for this filter:
- `message`
- `edited_message`
- `channel_post`
- `edited_channel_post`
- `inline_query`
- `callback_query`

View file

@ -39,7 +39,7 @@ In this handler can be bounded filters which can be used as keyword arguments in
### Registering bound filters
Bound filter should be subclass of [BaseFilter](filters/base_filter.md)
Bound filter should be subclass of [BaseFilter](filters/base.md)
`#!python3 <observer>.bind_filter(MyFilter)`

View file

@ -224,7 +224,8 @@ nav:
- dispatcher/observer.md
- Filters:
- dispatcher/filters/index.md
- dispatcher/filters/base_filter.md
- dispatcher/filters/base.md
- dispatcher/filters/text.md
- Build reports:
- reports.md

View file

@ -139,7 +139,7 @@ class TestTelegramEventObserver:
event_observer.bind_filter(MyFilter)
assert event_observer.filters
assert event_observer.filters[0] == MyFilter
assert MyFilter in event_observer.filters
def test_resolve_filters_chain(self):
router1 = Router()
@ -157,9 +157,13 @@ class TestTelegramEventObserver:
filters_chain2 = list(router2.message_handler._resolve_filters_chain())
filters_chain3 = list(router3.message_handler._resolve_filters_chain())
assert filters_chain1 == [MyFilter1, MyFilter2]
assert filters_chain2 == [MyFilter2, MyFilter1]
assert filters_chain3 == [MyFilter3, MyFilter2, MyFilter1]
assert MyFilter1 in filters_chain1
assert MyFilter1 in filters_chain2
assert MyFilter1 in filters_chain3
assert MyFilter2 in filters_chain1
assert MyFilter2 in filters_chain2
assert MyFilter2 in filters_chain3
assert MyFilter3 in filters_chain3
def test_resolve_filters(self):
router = Router()

View file

@ -0,0 +1,233 @@
import datetime
from itertools import permutations
from typing import Type
import pytest
from pydantic import ValidationError
from aiogram.api.types import CallbackQuery, Chat, InlineQuery, Message, Poll, PollOption, User
from aiogram.dispatcher.filters import BUILTIN_FILTERS
from aiogram.dispatcher.filters.text import Text
class TestText:
def test_default_for_observer(self):
registered_for = {
update_type for update_type, filters in BUILTIN_FILTERS.items() if Text in filters
}
assert registered_for == {
"message",
"edited_message",
"channel_post",
"edited_channel_post",
"inline_query",
"callback_query",
}
def test_validator_not_enough_arguments(self):
with pytest.raises(ValidationError):
Text()
with pytest.raises(ValidationError):
Text(text_ignore_case=True)
@pytest.mark.parametrize(
"first,last",
permutations(["text", "text_contains", "text_startswith", "text_endswith"], 2),
)
@pytest.mark.parametrize("ignore_case", [True, False])
def test_validator_too_few_arguments(self, first, last, ignore_case):
kwargs = {first: "test", last: "test"}
if ignore_case:
kwargs["text_ignore_case"] = True
with pytest.raises(ValidationError):
Text(**kwargs)
@pytest.mark.parametrize(
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
)
@pytest.mark.parametrize("input_type", [str, list, tuple, set])
def test_validator_convert_to_list(self, argument: str, input_type: Type):
text = Text(**{argument: input_type("test")})
assert hasattr(text, argument)
assert isinstance(getattr(text, argument), list)
@pytest.mark.parametrize(
"argument,ignore_case,input_value,update_type,result",
[
[
"text",
False,
"test",
Message(
message_id=42,
date=datetime.datetime.now(),
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
False,
],
[
"text",
False,
"test",
Message(
message_id=42,
date=datetime.datetime.now(),
caption="test",
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
True,
],
[
"text",
False,
"test",
Message(
message_id=42,
date=datetime.datetime.now(),
text="test",
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
True,
],
[
"text",
True,
"TEst",
Message(
message_id=42,
date=datetime.datetime.now(),
text="tesT",
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
True,
],
[
"text",
False,
"TEst",
Message(
message_id=42,
date=datetime.datetime.now(),
text="tesT",
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
False,
],
[
"text_startswith",
False,
"test",
Message(
message_id=42,
date=datetime.datetime.now(),
text="test case",
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
True,
],
[
"text_endswith",
False,
"case",
Message(
message_id=42,
date=datetime.datetime.now(),
text="test case",
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
True,
],
[
"text_contains",
False,
" ",
Message(
message_id=42,
date=datetime.datetime.now(),
text="test case",
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
True,
],
[
"text_startswith",
True,
"question",
Message(
message_id=42,
date=datetime.datetime.now(),
poll=Poll(
id="poll id",
question="Question?",
options=[PollOption(text="A", voter_count=0)],
is_closed=False,
),
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
True,
],
[
"text_startswith",
True,
"callback:",
CallbackQuery(
id="query id",
from_user=User(id=42, is_bot=False, first_name="Test"),
chat_instance="instance",
data="callback:data",
),
True,
],
[
"text_startswith",
True,
"query",
InlineQuery(
id="query id",
from_user=User(id=42, is_bot=False, first_name="Test"),
query="query line",
offset="offset",
),
True,
],
[
"text",
True,
"question",
Poll(
id="poll id",
question="Question",
options=[PollOption(text="A", voter_count=0)],
is_closed=False,
),
True,
],
[
"text",
True,
["question", "another question"],
Poll(
id="poll id",
question="Another question",
options=[PollOption(text="A", voter_count=0)],
is_closed=False,
),
True,
],
["text", True, ["question", "another question"], object(), False],
],
)
@pytest.mark.asyncio
async def test_check_text(self, argument, ignore_case, input_value, result, update_type):
text = Text(**{argument: input_value}, text_ignore_case=ignore_case)
assert await text(obj=update_type) is result