Merge remote-tracking branch 'origin/dev-3.x' into dev-3.x

This commit is contained in:
jrootjunior 2020-01-16 10:21:09 +02:00
commit da24ca6b07
16 changed files with 187 additions and 143 deletions

View file

@ -1,7 +1,6 @@
from __future__ import annotations
import copy
from typing import Any, Callable, Dict, Optional, TypeVar, cast
from typing import Callable, Optional, TypeVar, cast
from aiohttp import ClientSession, FormData
@ -60,17 +59,3 @@ class AiohttpSession(BaseSession):
async def __aenter__(self) -> AiohttpSession:
await self.create_session()
return self
def __deepcopy__(self: T, memo: Optional[Dict[int, Any]] = None) -> T:
if memo is None: # pragma: no cover
# This block was never be called
memo = {}
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for key, value in self.__dict__.items():
# aiohttp ClientSession cannot be copied.
copied_value = copy.deepcopy(value, memo=memo) if key != "_session" else None
setattr(result, key, copied_value)
return result

View file

@ -1,7 +1,7 @@
import inspect
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
SyncFilter = Callable[[Any], Any]
AsyncFilter = Callable[[Any], Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[CallbackType, BaseHandler]
HandlerType = Union[FilterType, BaseHandler]
@dataclass
@ -47,7 +47,7 @@ class FilterObject(CallableMixin):
@dataclass
class HandlerObject(CallableMixin):
callback: HandlerType
filters: List[FilterObject]
filters: Optional[List[FilterObject]] = None
def __post_init__(self):
super(HandlerObject, self).__post_init__()
@ -56,6 +56,8 @@ class HandlerObject(CallableMixin):
self.awaitable = True
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
if not self.filters:
return True, {}
for event_filter in self.filters:
check = await event_filter.call(*args, **kwargs)
if not check:

View file

@ -1,6 +1,6 @@
from __future__ import annotations
import copy
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
@ -34,15 +34,11 @@ class EventObserver:
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
def register(self, callback: HandlerType) -> HandlerType:
"""
Register callback with filters
"""
self.handlers.append(
HandlerObject(
callback=callback, filters=[FilterObject(filter_) for filter_ in filters]
)
)
self.handlers.append(HandlerObject(callback=callback))
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
@ -51,22 +47,18 @@ class EventObserver:
Handler will be called when all its filters is pass.
"""
for handler in self.handlers:
kwargs_copy = copy.copy(kwargs)
result, data = await handler.check(*args, **kwargs)
if result:
kwargs_copy.update(data)
try:
yield await handler.call(*args, **kwargs_copy)
except SkipHandler:
continue
try:
yield await handler.call(*args, **kwargs)
except SkipHandler:
continue
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
def __call__(self) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback, *args)
self.register(callback)
return callback
return wrapper
@ -148,16 +140,28 @@ class TelegramEventObserver(EventObserver):
Register event handler
"""
resolved_filters = self.resolve_filters(bound_filters)
return super().register(callback, *filters, *resolved_filters)
self.handlers.append(
HandlerObject(
callback=callback,
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
)
)
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
"""
Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass.
"""
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
yield result
break
for handler in self.handlers:
result, data = await handler.check(*args, **kwargs)
if result:
kwargs.update(data)
try:
yield await handler.call(*args, **kwargs)
except SkipHandler:
continue
break
def __call__(
self, *args: FilterType, **bound_filters: BaseFilter

View file

@ -10,7 +10,7 @@ class BaseFilter(ABC, BaseModel):
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
__call__: Any
pass
else: # pragma: no cover
@abstractmethod

View file

@ -2,7 +2,9 @@ from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Match, Optional, Pattern, Union
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast
from pydantic import validator
from aiogram import Bot
from aiogram.api.types import Message
@ -12,11 +14,19 @@ CommandPatterType = Union[str, re.Pattern] # type: ignore
class Command(BaseFilter):
commands: List[CommandPatterType]
commands: Union[Sequence[CommandPatterType], CommandPatterType]
commands_prefix: str = "/"
commands_ignore_case: bool = False
commands_ignore_mention: bool = False
@validator("commands", always=True)
def _validate_commands(
cls, value: Union[Sequence[CommandPatterType], CommandPatterType]
) -> Sequence[CommandPatterType]:
if isinstance(value, (str, re.Pattern)):
value = [value]
return value
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
if not message.text:
return False
@ -50,7 +60,7 @@ class Command(BaseFilter):
return False
# Validate command
for allowed_command in self.commands:
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
# Command can be presented as regexp pattern or raw string
# then need to validate that in different ways
if isinstance(allowed_command, Pattern): # Regexp

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union
from pydantic import validator
@ -8,12 +8,16 @@ from .base import BaseFilter
class ContentTypesFilter(BaseFilter):
content_types: Optional[List[str]] = None
content_types: Optional[Union[Sequence[str], str]] = None
@validator("content_types", always=True)
def _validate_content_types(cls, value: Optional[List[str]]) -> Optional[List[str]]:
@validator("content_types")
def _validate_content_types(
cls, value: Optional[Union[Sequence[str], str]]
) -> Optional[Sequence[str]]:
if not value:
value = [ContentType.TEXT]
if isinstance(value, str):
value = [value]
allowed_content_types = set(ContentType.all())
bad_content_types = set(value) - allowed_content_types
if bad_content_types:
@ -23,5 +27,5 @@ class ContentTypesFilter(BaseFilter):
async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]:
if not self.content_types: # pragma: no cover
# Is impossible but needed for valid typechecking
return False
self.content_types = [ContentType.TEXT]
return ContentType.ANY in self.content_types or message.content_type in self.content_types

View file

@ -1,10 +1,12 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Union
from pydantic import root_validator
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
from aiogram.dispatcher.filters import BaseFilter
TextType = str
class Text(BaseFilter):
"""
@ -12,14 +14,14 @@ class Text(BaseFilter):
InlineQuery or Poll question.
"""
text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
text: Optional[Union[Sequence[TextType], TextType]] = None
text_contains: Optional[Union[Sequence[TextType], TextType]] = None
text_startswith: Optional[Union[Sequence[TextType], TextType]] = None
text_endswith: Optional[Union[Sequence[TextType], TextType]] = None
text_ignore_case: bool = False
@root_validator
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
# Validate that only one text filter type is presented
used_args = set(
key for key, value in values.items() if key != "text_ignore_case" and value is not None

View file

@ -7,7 +7,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex
## Specification
| Argument | Type | Description |
| --- | --- | --- |
| `commands` | `#!python3 List[CommandPatterType]` | List of commands (string or compiled regexp patterns) |
| `commands` | `#!python3 Union[Sequence[Union[str, re.Pattern]], Union[str, re.Pattern]]` | List of commands (string or compiled regexp patterns) |
| `commands_prefix` | `#!python3 str` | Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes, for example: `"/!"` will work with commands prefixed by `"/"` or `"!"` (Default: `"/"`). |
| `commands_ignore_case` | `#!python3 bool` | Ignore case (Does not work with regexp, use flags instead. Default: `False`) |
| `commands_ignore_mention` | `#!python3 bool` | Ignore bot mention. By default bot can not handle commands intended for other bots (Default: `False`) |
@ -15,7 +15,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex
## Usage
1. Filter single variant of commands: `#!python3 Command(commands=["start"])`
1. Filter single variant of commands: `#!python3 Command(commands=["start"])` or `#!python3 Command(commands="start")`
1. Handle command by regexp pattern: `#!python3 Command(commands=[re.compile(r"item_(\d+)")])`
1. Match command by multiple variants: `#!python3 Command(commands=["item", re.compile(r"item_(\d+)")])`
1. Handle commands in public chats intended for other bots: `#!python3 Command(commands=["command"], commands)`

View file

@ -19,11 +19,11 @@ Or used from filters factory by passing corresponding arguments to handler regis
| Argument | Type | Description |
| --- | --- | --- |
| `content_types` | `#!python3 Optional[List[str]]` | List of allowed content types |
| `content_types` | `#!python3 Optional[Union[Sequence[str], str]]` | List of allowed content types |
## Usage
1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])`
1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])` or `#!python3 ContentTypesFilter(content_types="sticker")`
1. Multiple content types: `#!python3 ContentTypesFilter(content_types=["sticker", "photo"])`
1. Recommended: With usage of `ContentType` helper: `#!python3 ContentTypesFilter(content_types=[ContentType.PHOTO])`
1. Any content type: `#!python3 ContentTypesFilter(content_types=[ContentType.ANY])`

View file

@ -18,10 +18,10 @@ Or used from filters factory by passing corresponding arguments to handler regis
| Argument | Type | Description |
| --- | --- | --- |
| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values |
| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values |
| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values |
| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values |
| `text` | `#!python3 Optional[Union[Sequence[str], str]]` | Text equals value or one of values |
| `text_contains` | `#!python3 Optional[Union[Sequence[str], str]]` | Text contains value or one of values |
| `text_startswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text starts with value or one of values |
| `text_endswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text ends with value or one of values |
| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) |
!!! warning

View file

@ -14,12 +14,11 @@ Reference: `#!python3 aiogram.dispatcher.event.observer.EventObserver`
That is base observer for all events.
### Base registering method
Method: `<observer>.register(callback, filter1, filter2, ...)`
Method: `<observer>.register()`
| Argument | Type | Description |
| --- | --- | --- |
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set |
Will return original callback.
@ -28,14 +27,15 @@ Will return original callback.
Usage:
```python3
@<observer>(filter1, filter2, ...)
@<observer>()
async def handler(*args, **kwargs):
pass
```
## TelegramEventObserver
Is subclass of [EventObserver](#eventobserver) with some differences.
In this handler can be bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers.
Here you can register handler with filters or bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers.
This observer will stops event propagation when first handler is pass.
### Registering bound filters
@ -44,7 +44,7 @@ Bound filter should be subclass of [BaseFilter](filters/index.md)
`#!python3 <observer>.bind_filter(MyFilter)`
### Registering handlers
Method: `EventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)`
Method: `TelegramEventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)`
In this method is added bound filters keywords interface.
| Argument | Type | Description |
@ -52,3 +52,13 @@ In this method is added bound filters keywords interface.
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set |
| `**bound_filters` | `#!python3 Any` | Bound filters |
### Decorator-style registering event handler with filters
Usage:
```python3
@<observer>(filter1, filter2, ...)
async def handler(*args, **kwargs):
pass
```

View file

@ -1,4 +1,3 @@
import copy
from typing import AsyncContextManager
import aiohttp
@ -123,11 +122,3 @@ class TestAiohttpSession:
assert session == ctx
mocked_close.awaited_once()
mocked_create_session.awaited_once()
@pytest.mark.asyncio
async def test_deepcopy(self):
# Session should be copied without aiohttp.ClientSession
async with AiohttpSession() as session:
cloned_session = copy.deepcopy(session)
assert cloned_session != session
assert cloned_session._session is None

View file

@ -39,68 +39,38 @@ class MyFilter3(MyFilter1):
class TestEventObserver:
@pytest.mark.parametrize(
"count,handler,filters",
(
pytest.param(5, my_handler, []),
pytest.param(3, my_handler, [lambda event: True]),
pytest.param(
2,
my_handler,
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
),
),
)
def test_register_filters(self, count, handler, filters):
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer.register(wrapped_handler, *filters)
observer.register(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
assert not registered_handler.filters
@pytest.mark.parametrize(
"count,handler,filters",
(
pytest.param(5, my_handler, []),
pytest.param(3, my_handler, [lambda event: True]),
pytest.param(
2,
my_handler,
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
),
),
)
def test_register_filters_via_decorator(self, count, handler, filters):
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters_via_decorator(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer(*filters)(wrapped_handler)
observer()(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
@pytest.mark.asyncio
async def test_trigger_rejected(self):
observer = EventObserver()
observer.register(my_handler, lambda event: False)
results = [result async for result in observer.trigger(42)]
assert results == []
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger_accepted_bool(self):
observer = EventObserver()
observer.register(my_handler, lambda event: True)
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42]
@ -108,23 +78,12 @@ class TestEventObserver:
@pytest.mark.asyncio
async def test_trigger_with_skip(self):
observer = EventObserver()
observer.register(skip_my_handler, lambda event: True)
observer.register(my_handler, lambda event: False)
observer.register(my_handler, lambda event: True)
observer.register(skip_my_handler)
observer.register(my_handler)
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42]
@pytest.mark.asyncio
async def test_trigger_right_context_in_handlers(self):
observer = EventObserver()
observer.register(
pipe_handler, lambda event: {"a": 1}, lambda event: False
) # {"a": 1} should not be in result
observer.register(pipe_handler, lambda event: {"b": 2})
results = [result async for result in observer.trigger(42)]
assert results == [((42,), {"b": 2})]
assert results == [42, 42]
class TestTelegramEventObserver:
@ -144,9 +103,9 @@ class TestTelegramEventObserver:
assert MyFilter in event_observer.filters
def test_resolve_filters_chain(self):
router1 = Router()
router2 = Router()
router3 = Router()
router1 = Router(use_builtin_filters=False)
router2 = Router(use_builtin_filters=False)
router3 = Router(use_builtin_filters=False)
router1.include_router(router2)
router2.include_router(router3)
@ -168,7 +127,7 @@ class TestTelegramEventObserver:
assert MyFilter3 in filters_chain3
def test_resolve_filters(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.bind_filter(MyFilter1)
@ -189,7 +148,7 @@ class TestTelegramEventObserver:
assert observer.resolve_filters({"test": ...})
def test_register(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.bind_filter(MyFilter1)
@ -214,7 +173,7 @@ class TestTelegramEventObserver:
assert MyFilter1(test="PASS") in callbacks
def test_register_decorator(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
@observer()
@ -226,7 +185,7 @@ class TestTelegramEventObserver:
@pytest.mark.asyncio
async def test_trigger(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.bind_filter(MyFilter1)
observer.register(my_handler, test="ok")
@ -241,3 +200,38 @@ class TestTelegramEventObserver:
results = [result async for result in observer.trigger(message)]
assert results == [message]
@pytest.mark.parametrize(
"count,handler,filters",
(
[5, my_handler, []],
[3, my_handler, [lambda event: True]],
[2, my_handler, [lambda event: True, lambda event: False, lambda event: {"ok": True}]],
),
)
def test_register_filters_via_decorator(self, count, handler, filters):
router = Router(use_builtin_filters=False)
observer = router.message_handler
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer(*filters)(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
#
@pytest.mark.asyncio
async def test_trigger_right_context_in_handlers(self):
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.register(
pipe_handler, lambda event: {"a": 1}, lambda event: False
) # {"a": 1} should not be in result
observer.register(pipe_handler, lambda event: {"b": 2})
results = [result async for result in observer.trigger(42)]
assert results == [((42,), {"b": 2})]

View file

@ -11,6 +11,13 @@ from tests.mocked_bot import MockedBot
class TestCommandFilter:
def test_convert_to_list(self):
cmd = Command(commands="start")
assert cmd.commands
assert isinstance(cmd.commands, list)
assert cmd.commands[0] == "start"
assert cmd == Command(commands=["start"])
@pytest.mark.asyncio
async def test_parse_command(self, bot: MockedBot):
# TODO: parametrize

View file

@ -1,18 +1,37 @@
from dataclasses import dataclass
from typing import cast
import pytest
from pydantic import ValidationError
from aiogram.api.types import ContentType, Message
from aiogram.dispatcher.filters import ContentTypesFilter
@dataclass
class MinimalMessage:
content_type: str
class TestContentTypesFilter:
def test_validator_empty(self):
@pytest.mark.asyncio
async def test_validator_empty(self):
filter_ = ContentTypesFilter()
assert filter_.content_types == ["text"]
assert not filter_.content_types
await filter_(cast(Message, MinimalMessage(ContentType.TEXT)))
assert filter_.content_types == [ContentType.TEXT]
def test_validator_empty_list(self):
filter_ = ContentTypesFilter(content_types=[])
assert filter_.content_types == ["text"]
def test_convert_to_list(self):
filter_ = ContentTypesFilter(content_types="text")
assert filter_.content_types
assert isinstance(filter_.content_types, list)
assert filter_.content_types[0] == "text"
assert filter_ == ContentTypesFilter(content_types=["text"])
@pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]])
def test_validator_with_values(self, values):
filter_ = ContentTypesFilter(content_types=values)
@ -22,3 +41,19 @@ class TestContentTypesFilter:
def test_validator_with_bad_values(self, values):
with pytest.raises(ValidationError):
ContentTypesFilter(content_types=values)
@pytest.mark.parametrize(
"values,content_type,result",
[
[[], ContentType.TEXT, True],
[[ContentType.TEXT], ContentType.TEXT, True],
[[ContentType.PHOTO], ContentType.TEXT, False],
[[ContentType.ANY], ContentType.TEXT, True],
[[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
[[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
],
)
@pytest.mark.asyncio
async def test_call(self, values, content_type, result):
filter_ = ContentTypesFilter(content_types=values)
assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result

View file

@ -1,6 +1,6 @@
import datetime
from itertools import permutations
from typing import Type
from typing import Sequence, Type
import pytest
from pydantic import ValidationError
@ -46,11 +46,11 @@ class TestText:
@pytest.mark.parametrize(
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
)
@pytest.mark.parametrize("input_type", [str, list, tuple, set])
@pytest.mark.parametrize("input_type", [str, list, tuple])
def test_validator_convert_to_list(self, argument: str, input_type: Type):
text = Text(**{argument: input_type("test")})
assert hasattr(text, argument)
assert isinstance(getattr(text, argument), list)
assert isinstance(getattr(text, argument), Sequence)
@pytest.mark.parametrize(
"argument,ignore_case,input_value,update_type,result",