mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-09 01:15:31 +00:00
Merge remote-tracking branch 'origin/dev-3.x' into dev-3.x
This commit is contained in:
commit
da24ca6b07
16 changed files with 187 additions and 143 deletions
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar, cast
|
||||
from typing import Callable, Optional, TypeVar, cast
|
||||
|
||||
from aiohttp import ClientSession, FormData
|
||||
|
||||
|
|
@ -60,17 +59,3 @@ class AiohttpSession(BaseSession):
|
|||
async def __aenter__(self) -> AiohttpSession:
|
||||
await self.create_session()
|
||||
return self
|
||||
|
||||
def __deepcopy__(self: T, memo: Optional[Dict[int, Any]] = None) -> T:
|
||||
if memo is None: # pragma: no cover
|
||||
# This block was never be called
|
||||
memo = {}
|
||||
|
||||
cls = self.__class__
|
||||
result = cls.__new__(cls)
|
||||
memo[id(self)] = result
|
||||
for key, value in self.__dict__.items():
|
||||
# aiohttp ClientSession cannot be copied.
|
||||
copied_value = copy.deepcopy(value, memo=memo) if key != "_session" else None
|
||||
setattr(result, key, copied_value)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from aiogram.dispatcher.filters.base import BaseFilter
|
||||
from aiogram.dispatcher.handler.base import BaseHandler
|
||||
|
|
@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
|
|||
SyncFilter = Callable[[Any], Any]
|
||||
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
||||
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
||||
HandlerType = Union[CallbackType, BaseHandler]
|
||||
HandlerType = Union[FilterType, BaseHandler]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -47,7 +47,7 @@ class FilterObject(CallableMixin):
|
|||
@dataclass
|
||||
class HandlerObject(CallableMixin):
|
||||
callback: HandlerType
|
||||
filters: List[FilterObject]
|
||||
filters: Optional[List[FilterObject]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
super(HandlerObject, self).__post_init__()
|
||||
|
|
@ -56,6 +56,8 @@ class HandlerObject(CallableMixin):
|
|||
self.awaitable = True
|
||||
|
||||
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
||||
if not self.filters:
|
||||
return True, {}
|
||||
for event_filter in self.filters:
|
||||
check = await event_filter.call(*args, **kwargs)
|
||||
if not check:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
|
|
@ -34,15 +34,11 @@ class EventObserver:
|
|||
def __init__(self) -> None:
|
||||
self.handlers: List[HandlerObject] = []
|
||||
|
||||
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
|
||||
def register(self, callback: HandlerType) -> HandlerType:
|
||||
"""
|
||||
Register callback with filters
|
||||
"""
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback, filters=[FilterObject(filter_) for filter_ in filters]
|
||||
)
|
||||
)
|
||||
self.handlers.append(HandlerObject(callback=callback))
|
||||
return callback
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
|
|
@ -51,22 +47,18 @@ class EventObserver:
|
|||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
kwargs_copy = copy.copy(kwargs)
|
||||
result, data = await handler.check(*args, **kwargs)
|
||||
if result:
|
||||
kwargs_copy.update(data)
|
||||
try:
|
||||
yield await handler.call(*args, **kwargs_copy)
|
||||
yield await handler.call(*args, **kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
|
||||
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
|
||||
def __call__(self) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
"""
|
||||
|
||||
def wrapper(callback: CallbackType) -> CallbackType:
|
||||
self.register(callback, *args)
|
||||
self.register(callback)
|
||||
return callback
|
||||
|
||||
return wrapper
|
||||
|
|
@ -148,15 +140,27 @@ class TelegramEventObserver(EventObserver):
|
|||
Register event handler
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
return super().register(callback, *filters, *resolved_filters)
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback,
|
||||
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
|
||||
)
|
||||
)
|
||||
return callback
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Propagate event to handlers and stops propagation on first match.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
||||
yield result
|
||||
for handler in self.handlers:
|
||||
result, data = await handler.check(*args, **kwargs)
|
||||
if result:
|
||||
kwargs.update(data)
|
||||
try:
|
||||
yield await handler.call(*args, **kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
break
|
||||
|
||||
def __call__(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class BaseFilter(ABC, BaseModel):
|
|||
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
|
||||
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
|
||||
|
||||
__call__: Any
|
||||
pass
|
||||
else: # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from __future__ import annotations
|
|||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Match, Optional, Pattern, Union
|
||||
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast
|
||||
|
||||
from pydantic import validator
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.api.types import Message
|
||||
|
|
@ -12,11 +14,19 @@ CommandPatterType = Union[str, re.Pattern] # type: ignore
|
|||
|
||||
|
||||
class Command(BaseFilter):
|
||||
commands: List[CommandPatterType]
|
||||
commands: Union[Sequence[CommandPatterType], CommandPatterType]
|
||||
commands_prefix: str = "/"
|
||||
commands_ignore_case: bool = False
|
||||
commands_ignore_mention: bool = False
|
||||
|
||||
@validator("commands", always=True)
|
||||
def _validate_commands(
|
||||
cls, value: Union[Sequence[CommandPatterType], CommandPatterType]
|
||||
) -> Sequence[CommandPatterType]:
|
||||
if isinstance(value, (str, re.Pattern)):
|
||||
value = [value]
|
||||
return value
|
||||
|
||||
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
|
||||
if not message.text:
|
||||
return False
|
||||
|
|
@ -50,7 +60,7 @@ class Command(BaseFilter):
|
|||
return False
|
||||
|
||||
# Validate command
|
||||
for allowed_command in self.commands:
|
||||
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
|
||||
# Command can be presented as regexp pattern or raw string
|
||||
# then need to validate that in different ways
|
||||
if isinstance(allowed_command, Pattern): # Regexp
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
|
||||
from pydantic import validator
|
||||
|
||||
|
|
@ -8,12 +8,16 @@ from .base import BaseFilter
|
|||
|
||||
|
||||
class ContentTypesFilter(BaseFilter):
|
||||
content_types: Optional[List[str]] = None
|
||||
content_types: Optional[Union[Sequence[str], str]] = None
|
||||
|
||||
@validator("content_types", always=True)
|
||||
def _validate_content_types(cls, value: Optional[List[str]]) -> Optional[List[str]]:
|
||||
@validator("content_types")
|
||||
def _validate_content_types(
|
||||
cls, value: Optional[Union[Sequence[str], str]]
|
||||
) -> Optional[Sequence[str]]:
|
||||
if not value:
|
||||
value = [ContentType.TEXT]
|
||||
if isinstance(value, str):
|
||||
value = [value]
|
||||
allowed_content_types = set(ContentType.all())
|
||||
bad_content_types = set(value) - allowed_content_types
|
||||
if bad_content_types:
|
||||
|
|
@ -23,5 +27,5 @@ class ContentTypesFilter(BaseFilter):
|
|||
async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]:
|
||||
if not self.content_types: # pragma: no cover
|
||||
# Is impossible but needed for valid typechecking
|
||||
return False
|
||||
self.content_types = [ContentType.TEXT]
|
||||
return ContentType.ANY in self.content_types or message.content_type in self.content_types
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
|
||||
from aiogram.dispatcher.filters import BaseFilter
|
||||
|
||||
TextType = str
|
||||
|
||||
|
||||
class Text(BaseFilter):
|
||||
"""
|
||||
|
|
@ -12,14 +14,14 @@ class Text(BaseFilter):
|
|||
InlineQuery or Poll question.
|
||||
"""
|
||||
|
||||
text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
||||
text: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
text_contains: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
text_startswith: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
text_endswith: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
text_ignore_case: bool = False
|
||||
|
||||
@root_validator
|
||||
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Validate that only one text filter type is presented
|
||||
used_args = set(
|
||||
key for key, value in values.items() if key != "text_ignore_case" and value is not None
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex
|
|||
## Specification
|
||||
| Argument | Type | Description |
|
||||
| --- | --- | --- |
|
||||
| `commands` | `#!python3 List[CommandPatterType]` | List of commands (string or compiled regexp patterns) |
|
||||
| `commands` | `#!python3 Union[Sequence[Union[str, re.Pattern]], Union[str, re.Pattern]]` | List of commands (string or compiled regexp patterns) |
|
||||
| `commands_prefix` | `#!python3 str` | Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes, for example: `"/!"` will work with commands prefixed by `"/"` or `"!"` (Default: `"/"`). |
|
||||
| `commands_ignore_case` | `#!python3 bool` | Ignore case (Does not work with regexp, use flags instead. Default: `False`) |
|
||||
| `commands_ignore_mention` | `#!python3 bool` | Ignore bot mention. By default bot can not handle commands intended for other bots (Default: `False`) |
|
||||
|
|
@ -15,7 +15,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex
|
|||
|
||||
## Usage
|
||||
|
||||
1. Filter single variant of commands: `#!python3 Command(commands=["start"])`
|
||||
1. Filter single variant of commands: `#!python3 Command(commands=["start"])` or `#!python3 Command(commands="start")`
|
||||
1. Handle command by regexp pattern: `#!python3 Command(commands=[re.compile(r"item_(\d+)")])`
|
||||
1. Match command by multiple variants: `#!python3 Command(commands=["item", re.compile(r"item_(\d+)")])`
|
||||
1. Handle commands in public chats intended for other bots: `#!python3 Command(commands=["command"], commands)`
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ Or used from filters factory by passing corresponding arguments to handler regis
|
|||
|
||||
| Argument | Type | Description |
|
||||
| --- | --- | --- |
|
||||
| `content_types` | `#!python3 Optional[List[str]]` | List of allowed content types |
|
||||
| `content_types` | `#!python3 Optional[Union[Sequence[str], str]]` | List of allowed content types |
|
||||
|
||||
## Usage
|
||||
|
||||
1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])`
|
||||
1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])` or `#!python3 ContentTypesFilter(content_types="sticker")`
|
||||
1. Multiple content types: `#!python3 ContentTypesFilter(content_types=["sticker", "photo"])`
|
||||
1. Recommended: With usage of `ContentType` helper: `#!python3 ContentTypesFilter(content_types=[ContentType.PHOTO])`
|
||||
1. Any content type: `#!python3 ContentTypesFilter(content_types=[ContentType.ANY])`
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ Or used from filters factory by passing corresponding arguments to handler regis
|
|||
|
||||
| Argument | Type | Description |
|
||||
| --- | --- | --- |
|
||||
| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values |
|
||||
| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values |
|
||||
| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values |
|
||||
| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values |
|
||||
| `text` | `#!python3 Optional[Union[Sequence[str], str]]` | Text equals value or one of values |
|
||||
| `text_contains` | `#!python3 Optional[Union[Sequence[str], str]]` | Text contains value or one of values |
|
||||
| `text_startswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text starts with value or one of values |
|
||||
| `text_endswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text ends with value or one of values |
|
||||
| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) |
|
||||
|
||||
!!! warning
|
||||
|
|
|
|||
|
|
@ -14,12 +14,11 @@ Reference: `#!python3 aiogram.dispatcher.event.observer.EventObserver`
|
|||
That is base observer for all events.
|
||||
|
||||
### Base registering method
|
||||
Method: `<observer>.register(callback, filter1, filter2, ...)`
|
||||
Method: `<observer>.register()`
|
||||
|
||||
| Argument | Type | Description |
|
||||
| --- | --- | --- |
|
||||
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
|
||||
| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set |
|
||||
|
||||
Will return original callback.
|
||||
|
||||
|
|
@ -28,14 +27,15 @@ Will return original callback.
|
|||
|
||||
Usage:
|
||||
```python3
|
||||
@<observer>(filter1, filter2, ...)
|
||||
@<observer>()
|
||||
async def handler(*args, **kwargs):
|
||||
pass
|
||||
```
|
||||
|
||||
## TelegramEventObserver
|
||||
Is subclass of [EventObserver](#eventobserver) with some differences.
|
||||
In this handler can be bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers.
|
||||
Here you can register handler with filters or bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers.
|
||||
This observer will stops event propagation when first handler is pass.
|
||||
|
||||
### Registering bound filters
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ Bound filter should be subclass of [BaseFilter](filters/index.md)
|
|||
`#!python3 <observer>.bind_filter(MyFilter)`
|
||||
|
||||
### Registering handlers
|
||||
Method: `EventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)`
|
||||
Method: `TelegramEventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)`
|
||||
In this method is added bound filters keywords interface.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|
|
@ -52,3 +52,13 @@ In this method is added bound filters keywords interface.
|
|||
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
|
||||
| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set |
|
||||
| `**bound_filters` | `#!python3 Any` | Bound filters |
|
||||
|
||||
|
||||
### Decorator-style registering event handler with filters
|
||||
|
||||
Usage:
|
||||
```python3
|
||||
@<observer>(filter1, filter2, ...)
|
||||
async def handler(*args, **kwargs):
|
||||
pass
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import copy
|
||||
from typing import AsyncContextManager
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -123,11 +122,3 @@ class TestAiohttpSession:
|
|||
assert session == ctx
|
||||
mocked_close.awaited_once()
|
||||
mocked_create_session.awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepcopy(self):
|
||||
# Session should be copied without aiohttp.ClientSession
|
||||
async with AiohttpSession() as session:
|
||||
cloned_session = copy.deepcopy(session)
|
||||
assert cloned_session != session
|
||||
assert cloned_session._session is None
|
||||
|
|
|
|||
|
|
@ -39,68 +39,38 @@ class MyFilter3(MyFilter1):
|
|||
|
||||
|
||||
class TestEventObserver:
|
||||
@pytest.mark.parametrize(
|
||||
"count,handler,filters",
|
||||
(
|
||||
pytest.param(5, my_handler, []),
|
||||
pytest.param(3, my_handler, [lambda event: True]),
|
||||
pytest.param(
|
||||
2,
|
||||
my_handler,
|
||||
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_register_filters(self, count, handler, filters):
|
||||
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||
def test_register_filters(self, count, handler):
|
||||
observer = EventObserver()
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer.register(wrapped_handler, *filters)
|
||||
observer.register(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert len(registered_handler.filters) == len(filters)
|
||||
assert not registered_handler.filters
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"count,handler,filters",
|
||||
(
|
||||
pytest.param(5, my_handler, []),
|
||||
pytest.param(3, my_handler, [lambda event: True]),
|
||||
pytest.param(
|
||||
2,
|
||||
my_handler,
|
||||
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_register_filters_via_decorator(self, count, handler, filters):
|
||||
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||
def test_register_filters_via_decorator(self, count, handler):
|
||||
observer = EventObserver()
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer(*filters)(wrapped_handler)
|
||||
observer()(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert len(registered_handler.filters) == len(filters)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_rejected(self):
|
||||
observer = EventObserver()
|
||||
observer.register(my_handler, lambda event: False)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == []
|
||||
assert not registered_handler.filters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_accepted_bool(self):
|
||||
observer = EventObserver()
|
||||
observer.register(my_handler, lambda event: True)
|
||||
observer.register(my_handler)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [42]
|
||||
|
|
@ -108,23 +78,12 @@ class TestEventObserver:
|
|||
@pytest.mark.asyncio
|
||||
async def test_trigger_with_skip(self):
|
||||
observer = EventObserver()
|
||||
observer.register(skip_my_handler, lambda event: True)
|
||||
observer.register(my_handler, lambda event: False)
|
||||
observer.register(my_handler, lambda event: True)
|
||||
observer.register(skip_my_handler)
|
||||
observer.register(my_handler)
|
||||
observer.register(my_handler)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [42]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_right_context_in_handlers(self):
|
||||
observer = EventObserver()
|
||||
observer.register(
|
||||
pipe_handler, lambda event: {"a": 1}, lambda event: False
|
||||
) # {"a": 1} should not be in result
|
||||
observer.register(pipe_handler, lambda event: {"b": 2})
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [((42,), {"b": 2})]
|
||||
assert results == [42, 42]
|
||||
|
||||
|
||||
class TestTelegramEventObserver:
|
||||
|
|
@ -144,9 +103,9 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter in event_observer.filters
|
||||
|
||||
def test_resolve_filters_chain(self):
|
||||
router1 = Router()
|
||||
router2 = Router()
|
||||
router3 = Router()
|
||||
router1 = Router(use_builtin_filters=False)
|
||||
router2 = Router(use_builtin_filters=False)
|
||||
router3 = Router(use_builtin_filters=False)
|
||||
router1.include_router(router2)
|
||||
router2.include_router(router3)
|
||||
|
||||
|
|
@ -168,7 +127,7 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter3 in filters_chain3
|
||||
|
||||
def test_resolve_filters(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.bind_filter(MyFilter1)
|
||||
|
||||
|
|
@ -189,7 +148,7 @@ class TestTelegramEventObserver:
|
|||
assert observer.resolve_filters({"test": ...})
|
||||
|
||||
def test_register(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.bind_filter(MyFilter1)
|
||||
|
||||
|
|
@ -214,7 +173,7 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter1(test="PASS") in callbacks
|
||||
|
||||
def test_register_decorator(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
|
||||
@observer()
|
||||
|
|
@ -226,7 +185,7 @@ class TestTelegramEventObserver:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.register(my_handler, test="ok")
|
||||
|
|
@ -241,3 +200,38 @@ class TestTelegramEventObserver:
|
|||
|
||||
results = [result async for result in observer.trigger(message)]
|
||||
assert results == [message]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"count,handler,filters",
|
||||
(
|
||||
[5, my_handler, []],
|
||||
[3, my_handler, [lambda event: True]],
|
||||
[2, my_handler, [lambda event: True, lambda event: False, lambda event: {"ok": True}]],
|
||||
),
|
||||
)
|
||||
def test_register_filters_via_decorator(self, count, handler, filters):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer(*filters)(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert len(registered_handler.filters) == len(filters)
|
||||
|
||||
#
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_right_context_in_handlers(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.register(
|
||||
pipe_handler, lambda event: {"a": 1}, lambda event: False
|
||||
) # {"a": 1} should not be in result
|
||||
observer.register(pipe_handler, lambda event: {"b": 2})
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [((42,), {"b": 2})]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,13 @@ from tests.mocked_bot import MockedBot
|
|||
|
||||
|
||||
class TestCommandFilter:
|
||||
def test_convert_to_list(self):
|
||||
cmd = Command(commands="start")
|
||||
assert cmd.commands
|
||||
assert isinstance(cmd.commands, list)
|
||||
assert cmd.commands[0] == "start"
|
||||
assert cmd == Command(commands=["start"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_command(self, bot: MockedBot):
|
||||
# TODO: parametrize
|
||||
|
|
|
|||
|
|
@ -1,18 +1,37 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiogram.api.types import ContentType, Message
|
||||
from aiogram.dispatcher.filters import ContentTypesFilter
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinimalMessage:
|
||||
content_type: str
|
||||
|
||||
|
||||
class TestContentTypesFilter:
|
||||
def test_validator_empty(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_empty(self):
|
||||
filter_ = ContentTypesFilter()
|
||||
assert filter_.content_types == ["text"]
|
||||
assert not filter_.content_types
|
||||
await filter_(cast(Message, MinimalMessage(ContentType.TEXT)))
|
||||
assert filter_.content_types == [ContentType.TEXT]
|
||||
|
||||
def test_validator_empty_list(self):
|
||||
filter_ = ContentTypesFilter(content_types=[])
|
||||
assert filter_.content_types == ["text"]
|
||||
|
||||
def test_convert_to_list(self):
|
||||
filter_ = ContentTypesFilter(content_types="text")
|
||||
assert filter_.content_types
|
||||
assert isinstance(filter_.content_types, list)
|
||||
assert filter_.content_types[0] == "text"
|
||||
assert filter_ == ContentTypesFilter(content_types=["text"])
|
||||
|
||||
@pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]])
|
||||
def test_validator_with_values(self, values):
|
||||
filter_ = ContentTypesFilter(content_types=values)
|
||||
|
|
@ -22,3 +41,19 @@ class TestContentTypesFilter:
|
|||
def test_validator_with_bad_values(self, values):
|
||||
with pytest.raises(ValidationError):
|
||||
ContentTypesFilter(content_types=values)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"values,content_type,result",
|
||||
[
|
||||
[[], ContentType.TEXT, True],
|
||||
[[ContentType.TEXT], ContentType.TEXT, True],
|
||||
[[ContentType.PHOTO], ContentType.TEXT, False],
|
||||
[[ContentType.ANY], ContentType.TEXT, True],
|
||||
[[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
|
||||
[[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_call(self, values, content_type, result):
|
||||
filter_ = ContentTypesFilter(content_types=values)
|
||||
assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
from itertools import permutations
|
||||
from typing import Type
|
||||
from typing import Sequence, Type
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -46,11 +46,11 @@ class TestText:
|
|||
@pytest.mark.parametrize(
|
||||
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
|
||||
)
|
||||
@pytest.mark.parametrize("input_type", [str, list, tuple, set])
|
||||
@pytest.mark.parametrize("input_type", [str, list, tuple])
|
||||
def test_validator_convert_to_list(self, argument: str, input_type: Type):
|
||||
text = Text(**{argument: input_type("test")})
|
||||
assert hasattr(text, argument)
|
||||
assert isinstance(getattr(text, argument), list)
|
||||
assert isinstance(getattr(text, argument), Sequence)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"argument,ignore_case,input_value,update_type,result",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue