mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-09 01:15:31 +00:00
Merge remote-tracking branch 'origin/dev-3.x' into dev-3.x
This commit is contained in:
commit
da24ca6b07
16 changed files with 187 additions and 143 deletions
|
|
@ -1,7 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
from typing import Callable, Optional, TypeVar, cast
|
||||||
from typing import Any, Callable, Dict, Optional, TypeVar, cast
|
|
||||||
|
|
||||||
from aiohttp import ClientSession, FormData
|
from aiohttp import ClientSession, FormData
|
||||||
|
|
||||||
|
|
@ -60,17 +59,3 @@ class AiohttpSession(BaseSession):
|
||||||
async def __aenter__(self) -> AiohttpSession:
|
async def __aenter__(self) -> AiohttpSession:
|
||||||
await self.create_session()
|
await self.create_session()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __deepcopy__(self: T, memo: Optional[Dict[int, Any]] = None) -> T:
|
|
||||||
if memo is None: # pragma: no cover
|
|
||||||
# This block was never be called
|
|
||||||
memo = {}
|
|
||||||
|
|
||||||
cls = self.__class__
|
|
||||||
result = cls.__new__(cls)
|
|
||||||
memo[id(self)] = result
|
|
||||||
for key, value in self.__dict__.items():
|
|
||||||
# aiohttp ClientSession cannot be copied.
|
|
||||||
copied_value = copy.deepcopy(value, memo=memo) if key != "_session" else None
|
|
||||||
setattr(result, key, copied_value)
|
|
||||||
return result
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from aiogram.dispatcher.filters.base import BaseFilter
|
from aiogram.dispatcher.filters.base import BaseFilter
|
||||||
from aiogram.dispatcher.handler.base import BaseHandler
|
from aiogram.dispatcher.handler.base import BaseHandler
|
||||||
|
|
@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
|
||||||
SyncFilter = Callable[[Any], Any]
|
SyncFilter = Callable[[Any], Any]
|
||||||
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
||||||
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
||||||
HandlerType = Union[CallbackType, BaseHandler]
|
HandlerType = Union[FilterType, BaseHandler]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -47,7 +47,7 @@ class FilterObject(CallableMixin):
|
||||||
@dataclass
|
@dataclass
|
||||||
class HandlerObject(CallableMixin):
|
class HandlerObject(CallableMixin):
|
||||||
callback: HandlerType
|
callback: HandlerType
|
||||||
filters: List[FilterObject]
|
filters: Optional[List[FilterObject]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super(HandlerObject, self).__post_init__()
|
super(HandlerObject, self).__post_init__()
|
||||||
|
|
@ -56,6 +56,8 @@ class HandlerObject(CallableMixin):
|
||||||
self.awaitable = True
|
self.awaitable = True
|
||||||
|
|
||||||
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
||||||
|
if not self.filters:
|
||||||
|
return True, {}
|
||||||
for event_filter in self.filters:
|
for event_filter in self.filters:
|
||||||
check = await event_filter.call(*args, **kwargs)
|
check = await event_filter.call(*args, **kwargs)
|
||||||
if not check:
|
if not check:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
|
@ -34,15 +34,11 @@ class EventObserver:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.handlers: List[HandlerObject] = []
|
self.handlers: List[HandlerObject] = []
|
||||||
|
|
||||||
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
|
def register(self, callback: HandlerType) -> HandlerType:
|
||||||
"""
|
"""
|
||||||
Register callback with filters
|
Register callback with filters
|
||||||
"""
|
"""
|
||||||
self.handlers.append(
|
self.handlers.append(HandlerObject(callback=callback))
|
||||||
HandlerObject(
|
|
||||||
callback=callback, filters=[FilterObject(filter_) for filter_ in filters]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||||
|
|
@ -51,22 +47,18 @@ class EventObserver:
|
||||||
Handler will be called when all its filters is pass.
|
Handler will be called when all its filters is pass.
|
||||||
"""
|
"""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
kwargs_copy = copy.copy(kwargs)
|
|
||||||
result, data = await handler.check(*args, **kwargs)
|
|
||||||
if result:
|
|
||||||
kwargs_copy.update(data)
|
|
||||||
try:
|
try:
|
||||||
yield await handler.call(*args, **kwargs_copy)
|
yield await handler.call(*args, **kwargs)
|
||||||
except SkipHandler:
|
except SkipHandler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
|
def __call__(self) -> Callable[[CallbackType], CallbackType]:
|
||||||
"""
|
"""
|
||||||
Decorator for registering event handlers
|
Decorator for registering event handlers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(callback: CallbackType) -> CallbackType:
|
def wrapper(callback: CallbackType) -> CallbackType:
|
||||||
self.register(callback, *args)
|
self.register(callback)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
@ -148,15 +140,27 @@ class TelegramEventObserver(EventObserver):
|
||||||
Register event handler
|
Register event handler
|
||||||
"""
|
"""
|
||||||
resolved_filters = self.resolve_filters(bound_filters)
|
resolved_filters = self.resolve_filters(bound_filters)
|
||||||
return super().register(callback, *filters, *resolved_filters)
|
self.handlers.append(
|
||||||
|
HandlerObject(
|
||||||
|
callback=callback,
|
||||||
|
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return callback
|
||||||
|
|
||||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||||
"""
|
"""
|
||||||
Propagate event to handlers and stops propagation on first match.
|
Propagate event to handlers and stops propagation on first match.
|
||||||
Handler will be called when all its filters is pass.
|
Handler will be called when all its filters is pass.
|
||||||
"""
|
"""
|
||||||
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
for handler in self.handlers:
|
||||||
yield result
|
result, data = await handler.check(*args, **kwargs)
|
||||||
|
if result:
|
||||||
|
kwargs.update(data)
|
||||||
|
try:
|
||||||
|
yield await handler.call(*args, **kwargs)
|
||||||
|
except SkipHandler:
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ class BaseFilter(ABC, BaseModel):
|
||||||
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
|
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
|
||||||
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
|
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
|
||||||
|
|
||||||
__call__: Any
|
pass
|
||||||
else: # pragma: no cover
|
else: # pragma: no cover
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Match, Optional, Pattern, Union
|
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast
|
||||||
|
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.api.types import Message
|
from aiogram.api.types import Message
|
||||||
|
|
@ -12,11 +14,19 @@ CommandPatterType = Union[str, re.Pattern] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseFilter):
|
class Command(BaseFilter):
|
||||||
commands: List[CommandPatterType]
|
commands: Union[Sequence[CommandPatterType], CommandPatterType]
|
||||||
commands_prefix: str = "/"
|
commands_prefix: str = "/"
|
||||||
commands_ignore_case: bool = False
|
commands_ignore_case: bool = False
|
||||||
commands_ignore_mention: bool = False
|
commands_ignore_mention: bool = False
|
||||||
|
|
||||||
|
@validator("commands", always=True)
|
||||||
|
def _validate_commands(
|
||||||
|
cls, value: Union[Sequence[CommandPatterType], CommandPatterType]
|
||||||
|
) -> Sequence[CommandPatterType]:
|
||||||
|
if isinstance(value, (str, re.Pattern)):
|
||||||
|
value = [value]
|
||||||
|
return value
|
||||||
|
|
||||||
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
|
||||||
if not message.text:
|
if not message.text:
|
||||||
return False
|
return False
|
||||||
|
|
@ -50,7 +60,7 @@ class Command(BaseFilter):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate command
|
# Validate command
|
||||||
for allowed_command in self.commands:
|
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
|
||||||
# Command can be presented as regexp pattern or raw string
|
# Command can be presented as regexp pattern or raw string
|
||||||
# then need to validate that in different ways
|
# then need to validate that in different ways
|
||||||
if isinstance(allowed_command, Pattern): # Regexp
|
if isinstance(allowed_command, Pattern): # Regexp
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
|
|
||||||
|
|
@ -8,12 +8,16 @@ from .base import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
class ContentTypesFilter(BaseFilter):
|
class ContentTypesFilter(BaseFilter):
|
||||||
content_types: Optional[List[str]] = None
|
content_types: Optional[Union[Sequence[str], str]] = None
|
||||||
|
|
||||||
@validator("content_types", always=True)
|
@validator("content_types")
|
||||||
def _validate_content_types(cls, value: Optional[List[str]]) -> Optional[List[str]]:
|
def _validate_content_types(
|
||||||
|
cls, value: Optional[Union[Sequence[str], str]]
|
||||||
|
) -> Optional[Sequence[str]]:
|
||||||
if not value:
|
if not value:
|
||||||
value = [ContentType.TEXT]
|
value = [ContentType.TEXT]
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = [value]
|
||||||
allowed_content_types = set(ContentType.all())
|
allowed_content_types = set(ContentType.all())
|
||||||
bad_content_types = set(value) - allowed_content_types
|
bad_content_types = set(value) - allowed_content_types
|
||||||
if bad_content_types:
|
if bad_content_types:
|
||||||
|
|
@ -23,5 +27,5 @@ class ContentTypesFilter(BaseFilter):
|
||||||
async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]:
|
||||||
if not self.content_types: # pragma: no cover
|
if not self.content_types: # pragma: no cover
|
||||||
# Is impossible but needed for valid typechecking
|
# Is impossible but needed for valid typechecking
|
||||||
return False
|
self.content_types = [ContentType.TEXT]
|
||||||
return ContentType.ANY in self.content_types or message.content_type in self.content_types
|
return ContentType.ANY in self.content_types or message.content_type in self.content_types
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
|
from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll
|
||||||
from aiogram.dispatcher.filters import BaseFilter
|
from aiogram.dispatcher.filters import BaseFilter
|
||||||
|
|
||||||
|
TextType = str
|
||||||
|
|
||||||
|
|
||||||
class Text(BaseFilter):
|
class Text(BaseFilter):
|
||||||
"""
|
"""
|
||||||
|
|
@ -12,14 +14,14 @@ class Text(BaseFilter):
|
||||||
InlineQuery or Poll question.
|
InlineQuery or Poll question.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
text: Optional[Union[Sequence[TextType], TextType]] = None
|
||||||
text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
text_contains: Optional[Union[Sequence[TextType], TextType]] = None
|
||||||
text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
text_startswith: Optional[Union[Sequence[TextType], TextType]] = None
|
||||||
text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None
|
text_endswith: Optional[Union[Sequence[TextType], TextType]] = None
|
||||||
text_ignore_case: bool = False
|
text_ignore_case: bool = False
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
# Validate that only one text filter type is presented
|
# Validate that only one text filter type is presented
|
||||||
used_args = set(
|
used_args = set(
|
||||||
key for key, value in values.items() if key != "text_ignore_case" and value is not None
|
key for key, value in values.items() if key != "text_ignore_case" and value is not None
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex
|
||||||
## Specification
|
## Specification
|
||||||
| Argument | Type | Description |
|
| Argument | Type | Description |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `commands` | `#!python3 List[CommandPatterType]` | List of commands (string or compiled regexp patterns) |
|
| `commands` | `#!python3 Union[Sequence[Union[str, re.Pattern]], Union[str, re.Pattern]]` | List of commands (string or compiled regexp patterns) |
|
||||||
| `commands_prefix` | `#!python3 str` | Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes, for example: `"/!"` will work with commands prefixed by `"/"` or `"!"` (Default: `"/"`). |
|
| `commands_prefix` | `#!python3 str` | Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes, for example: `"/!"` will work with commands prefixed by `"/"` or `"!"` (Default: `"/"`). |
|
||||||
| `commands_ignore_case` | `#!python3 bool` | Ignore case (Does not work with regexp, use flags instead. Default: `False`) |
|
| `commands_ignore_case` | `#!python3 bool` | Ignore case (Does not work with regexp, use flags instead. Default: `False`) |
|
||||||
| `commands_ignore_mention` | `#!python3 bool` | Ignore bot mention. By default bot can not handle commands intended for other bots (Default: `False`) |
|
| `commands_ignore_mention` | `#!python3 bool` | Ignore bot mention. By default bot can not handle commands intended for other bots (Default: `False`) |
|
||||||
|
|
@ -15,7 +15,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
1. Filter single variant of commands: `#!python3 Command(commands=["start"])`
|
1. Filter single variant of commands: `#!python3 Command(commands=["start"])` or `#!python3 Command(commands="start")`
|
||||||
1. Handle command by regexp pattern: `#!python3 Command(commands=[re.compile(r"item_(\d+)")])`
|
1. Handle command by regexp pattern: `#!python3 Command(commands=[re.compile(r"item_(\d+)")])`
|
||||||
1. Match command by multiple variants: `#!python3 Command(commands=["item", re.compile(r"item_(\d+)")])`
|
1. Match command by multiple variants: `#!python3 Command(commands=["item", re.compile(r"item_(\d+)")])`
|
||||||
1. Handle commands in public chats intended for other bots: `#!python3 Command(commands=["command"], commands)`
|
1. Handle commands in public chats intended for other bots: `#!python3 Command(commands=["command"], commands)`
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,11 @@ Or used from filters factory by passing corresponding arguments to handler regis
|
||||||
|
|
||||||
| Argument | Type | Description |
|
| Argument | Type | Description |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `content_types` | `#!python3 Optional[List[str]]` | List of allowed content types |
|
| `content_types` | `#!python3 Optional[Union[Sequence[str], str]]` | List of allowed content types |
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])`
|
1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])` or `#!python3 ContentTypesFilter(content_types="sticker")`
|
||||||
1. Multiple content types: `#!python3 ContentTypesFilter(content_types=["sticker", "photo"])`
|
1. Multiple content types: `#!python3 ContentTypesFilter(content_types=["sticker", "photo"])`
|
||||||
1. Recommended: With usage of `ContentType` helper: `#!python3 ContentTypesFilter(content_types=[ContentType.PHOTO])`
|
1. Recommended: With usage of `ContentType` helper: `#!python3 ContentTypesFilter(content_types=[ContentType.PHOTO])`
|
||||||
1. Any content type: `#!python3 ContentTypesFilter(content_types=[ContentType.ANY])`
|
1. Any content type: `#!python3 ContentTypesFilter(content_types=[ContentType.ANY])`
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,10 @@ Or used from filters factory by passing corresponding arguments to handler regis
|
||||||
|
|
||||||
| Argument | Type | Description |
|
| Argument | Type | Description |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values |
|
| `text` | `#!python3 Optional[Union[Sequence[str], str]]` | Text equals value or one of values |
|
||||||
| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values |
|
| `text_contains` | `#!python3 Optional[Union[Sequence[str], str]]` | Text contains value or one of values |
|
||||||
| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values |
|
| `text_startswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text starts with value or one of values |
|
||||||
| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values |
|
| `text_endswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text ends with value or one of values |
|
||||||
| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) |
|
| `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) |
|
||||||
|
|
||||||
!!! warning
|
!!! warning
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,11 @@ Reference: `#!python3 aiogram.dispatcher.event.observer.EventObserver`
|
||||||
That is base observer for all events.
|
That is base observer for all events.
|
||||||
|
|
||||||
### Base registering method
|
### Base registering method
|
||||||
Method: `<observer>.register(callback, filter1, filter2, ...)`
|
Method: `<observer>.register()`
|
||||||
|
|
||||||
| Argument | Type | Description |
|
| Argument | Type | Description |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
|
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
|
||||||
| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set |
|
|
||||||
|
|
||||||
Will return original callback.
|
Will return original callback.
|
||||||
|
|
||||||
|
|
@ -28,14 +27,15 @@ Will return original callback.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
```python3
|
```python3
|
||||||
@<observer>(filter1, filter2, ...)
|
@<observer>()
|
||||||
async def handler(*args, **kwargs):
|
async def handler(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
```
|
```
|
||||||
|
|
||||||
## TelegramEventObserver
|
## TelegramEventObserver
|
||||||
Is subclass of [EventObserver](#eventobserver) with some differences.
|
Is subclass of [EventObserver](#eventobserver) with some differences.
|
||||||
In this handler can be bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers.
|
Here you can register handler with filters or bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers.
|
||||||
|
This observer will stops event propagation when first handler is pass.
|
||||||
|
|
||||||
### Registering bound filters
|
### Registering bound filters
|
||||||
|
|
||||||
|
|
@ -44,7 +44,7 @@ Bound filter should be subclass of [BaseFilter](filters/index.md)
|
||||||
`#!python3 <observer>.bind_filter(MyFilter)`
|
`#!python3 <observer>.bind_filter(MyFilter)`
|
||||||
|
|
||||||
### Registering handlers
|
### Registering handlers
|
||||||
Method: `EventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)`
|
Method: `TelegramEventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)`
|
||||||
In this method is added bound filters keywords interface.
|
In this method is added bound filters keywords interface.
|
||||||
|
|
||||||
| Argument | Type | Description |
|
| Argument | Type | Description |
|
||||||
|
|
@ -52,3 +52,13 @@ In this method is added bound filters keywords interface.
|
||||||
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
|
| `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler |
|
||||||
| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set |
|
| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set |
|
||||||
| `**bound_filters` | `#!python3 Any` | Bound filters |
|
| `**bound_filters` | `#!python3 Any` | Bound filters |
|
||||||
|
|
||||||
|
|
||||||
|
### Decorator-style registering event handler with filters
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```python3
|
||||||
|
@<observer>(filter1, filter2, ...)
|
||||||
|
async def handler(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import copy
|
|
||||||
from typing import AsyncContextManager
|
from typing import AsyncContextManager
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
@ -123,11 +122,3 @@ class TestAiohttpSession:
|
||||||
assert session == ctx
|
assert session == ctx
|
||||||
mocked_close.awaited_once()
|
mocked_close.awaited_once()
|
||||||
mocked_create_session.awaited_once()
|
mocked_create_session.awaited_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_deepcopy(self):
|
|
||||||
# Session should be copied without aiohttp.ClientSession
|
|
||||||
async with AiohttpSession() as session:
|
|
||||||
cloned_session = copy.deepcopy(session)
|
|
||||||
assert cloned_session != session
|
|
||||||
assert cloned_session._session is None
|
|
||||||
|
|
|
||||||
|
|
@ -39,68 +39,38 @@ class MyFilter3(MyFilter1):
|
||||||
|
|
||||||
|
|
||||||
class TestEventObserver:
|
class TestEventObserver:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||||
"count,handler,filters",
|
def test_register_filters(self, count, handler):
|
||||||
(
|
|
||||||
pytest.param(5, my_handler, []),
|
|
||||||
pytest.param(3, my_handler, [lambda event: True]),
|
|
||||||
pytest.param(
|
|
||||||
2,
|
|
||||||
my_handler,
|
|
||||||
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def test_register_filters(self, count, handler, filters):
|
|
||||||
observer = EventObserver()
|
observer = EventObserver()
|
||||||
|
|
||||||
for index in range(count):
|
for index in range(count):
|
||||||
wrapped_handler = functools.partial(handler, index=index)
|
wrapped_handler = functools.partial(handler, index=index)
|
||||||
observer.register(wrapped_handler, *filters)
|
observer.register(wrapped_handler)
|
||||||
registered_handler = observer.handlers[index]
|
registered_handler = observer.handlers[index]
|
||||||
|
|
||||||
assert len(observer.handlers) == index + 1
|
assert len(observer.handlers) == index + 1
|
||||||
assert isinstance(registered_handler, HandlerObject)
|
assert isinstance(registered_handler, HandlerObject)
|
||||||
assert registered_handler.callback == wrapped_handler
|
assert registered_handler.callback == wrapped_handler
|
||||||
assert len(registered_handler.filters) == len(filters)
|
assert not registered_handler.filters
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||||
"count,handler,filters",
|
def test_register_filters_via_decorator(self, count, handler):
|
||||||
(
|
|
||||||
pytest.param(5, my_handler, []),
|
|
||||||
pytest.param(3, my_handler, [lambda event: True]),
|
|
||||||
pytest.param(
|
|
||||||
2,
|
|
||||||
my_handler,
|
|
||||||
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def test_register_filters_via_decorator(self, count, handler, filters):
|
|
||||||
observer = EventObserver()
|
observer = EventObserver()
|
||||||
|
|
||||||
for index in range(count):
|
for index in range(count):
|
||||||
wrapped_handler = functools.partial(handler, index=index)
|
wrapped_handler = functools.partial(handler, index=index)
|
||||||
observer(*filters)(wrapped_handler)
|
observer()(wrapped_handler)
|
||||||
registered_handler = observer.handlers[index]
|
registered_handler = observer.handlers[index]
|
||||||
|
|
||||||
assert len(observer.handlers) == index + 1
|
assert len(observer.handlers) == index + 1
|
||||||
assert isinstance(registered_handler, HandlerObject)
|
assert isinstance(registered_handler, HandlerObject)
|
||||||
assert registered_handler.callback == wrapped_handler
|
assert registered_handler.callback == wrapped_handler
|
||||||
assert len(registered_handler.filters) == len(filters)
|
assert not registered_handler.filters
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_rejected(self):
|
|
||||||
observer = EventObserver()
|
|
||||||
observer.register(my_handler, lambda event: False)
|
|
||||||
|
|
||||||
results = [result async for result in observer.trigger(42)]
|
|
||||||
assert results == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_accepted_bool(self):
|
async def test_trigger_accepted_bool(self):
|
||||||
observer = EventObserver()
|
observer = EventObserver()
|
||||||
observer.register(my_handler, lambda event: True)
|
observer.register(my_handler)
|
||||||
|
|
||||||
results = [result async for result in observer.trigger(42)]
|
results = [result async for result in observer.trigger(42)]
|
||||||
assert results == [42]
|
assert results == [42]
|
||||||
|
|
@ -108,23 +78,12 @@ class TestEventObserver:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_with_skip(self):
|
async def test_trigger_with_skip(self):
|
||||||
observer = EventObserver()
|
observer = EventObserver()
|
||||||
observer.register(skip_my_handler, lambda event: True)
|
observer.register(skip_my_handler)
|
||||||
observer.register(my_handler, lambda event: False)
|
observer.register(my_handler)
|
||||||
observer.register(my_handler, lambda event: True)
|
observer.register(my_handler)
|
||||||
|
|
||||||
results = [result async for result in observer.trigger(42)]
|
results = [result async for result in observer.trigger(42)]
|
||||||
assert results == [42]
|
assert results == [42, 42]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_right_context_in_handlers(self):
|
|
||||||
observer = EventObserver()
|
|
||||||
observer.register(
|
|
||||||
pipe_handler, lambda event: {"a": 1}, lambda event: False
|
|
||||||
) # {"a": 1} should not be in result
|
|
||||||
observer.register(pipe_handler, lambda event: {"b": 2})
|
|
||||||
|
|
||||||
results = [result async for result in observer.trigger(42)]
|
|
||||||
assert results == [((42,), {"b": 2})]
|
|
||||||
|
|
||||||
|
|
||||||
class TestTelegramEventObserver:
|
class TestTelegramEventObserver:
|
||||||
|
|
@ -144,9 +103,9 @@ class TestTelegramEventObserver:
|
||||||
assert MyFilter in event_observer.filters
|
assert MyFilter in event_observer.filters
|
||||||
|
|
||||||
def test_resolve_filters_chain(self):
|
def test_resolve_filters_chain(self):
|
||||||
router1 = Router()
|
router1 = Router(use_builtin_filters=False)
|
||||||
router2 = Router()
|
router2 = Router(use_builtin_filters=False)
|
||||||
router3 = Router()
|
router3 = Router(use_builtin_filters=False)
|
||||||
router1.include_router(router2)
|
router1.include_router(router2)
|
||||||
router2.include_router(router3)
|
router2.include_router(router3)
|
||||||
|
|
||||||
|
|
@ -168,7 +127,7 @@ class TestTelegramEventObserver:
|
||||||
assert MyFilter3 in filters_chain3
|
assert MyFilter3 in filters_chain3
|
||||||
|
|
||||||
def test_resolve_filters(self):
|
def test_resolve_filters(self):
|
||||||
router = Router()
|
router = Router(use_builtin_filters=False)
|
||||||
observer = router.message_handler
|
observer = router.message_handler
|
||||||
observer.bind_filter(MyFilter1)
|
observer.bind_filter(MyFilter1)
|
||||||
|
|
||||||
|
|
@ -189,7 +148,7 @@ class TestTelegramEventObserver:
|
||||||
assert observer.resolve_filters({"test": ...})
|
assert observer.resolve_filters({"test": ...})
|
||||||
|
|
||||||
def test_register(self):
|
def test_register(self):
|
||||||
router = Router()
|
router = Router(use_builtin_filters=False)
|
||||||
observer = router.message_handler
|
observer = router.message_handler
|
||||||
observer.bind_filter(MyFilter1)
|
observer.bind_filter(MyFilter1)
|
||||||
|
|
||||||
|
|
@ -214,7 +173,7 @@ class TestTelegramEventObserver:
|
||||||
assert MyFilter1(test="PASS") in callbacks
|
assert MyFilter1(test="PASS") in callbacks
|
||||||
|
|
||||||
def test_register_decorator(self):
|
def test_register_decorator(self):
|
||||||
router = Router()
|
router = Router(use_builtin_filters=False)
|
||||||
observer = router.message_handler
|
observer = router.message_handler
|
||||||
|
|
||||||
@observer()
|
@observer()
|
||||||
|
|
@ -226,7 +185,7 @@ class TestTelegramEventObserver:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger(self):
|
async def test_trigger(self):
|
||||||
router = Router()
|
router = Router(use_builtin_filters=False)
|
||||||
observer = router.message_handler
|
observer = router.message_handler
|
||||||
observer.bind_filter(MyFilter1)
|
observer.bind_filter(MyFilter1)
|
||||||
observer.register(my_handler, test="ok")
|
observer.register(my_handler, test="ok")
|
||||||
|
|
@ -241,3 +200,38 @@ class TestTelegramEventObserver:
|
||||||
|
|
||||||
results = [result async for result in observer.trigger(message)]
|
results = [result async for result in observer.trigger(message)]
|
||||||
assert results == [message]
|
assert results == [message]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"count,handler,filters",
|
||||||
|
(
|
||||||
|
[5, my_handler, []],
|
||||||
|
[3, my_handler, [lambda event: True]],
|
||||||
|
[2, my_handler, [lambda event: True, lambda event: False, lambda event: {"ok": True}]],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_register_filters_via_decorator(self, count, handler, filters):
|
||||||
|
router = Router(use_builtin_filters=False)
|
||||||
|
observer = router.message_handler
|
||||||
|
|
||||||
|
for index in range(count):
|
||||||
|
wrapped_handler = functools.partial(handler, index=index)
|
||||||
|
observer(*filters)(wrapped_handler)
|
||||||
|
registered_handler = observer.handlers[index]
|
||||||
|
|
||||||
|
assert len(observer.handlers) == index + 1
|
||||||
|
assert isinstance(registered_handler, HandlerObject)
|
||||||
|
assert registered_handler.callback == wrapped_handler
|
||||||
|
assert len(registered_handler.filters) == len(filters)
|
||||||
|
|
||||||
|
#
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_right_context_in_handlers(self):
|
||||||
|
router = Router(use_builtin_filters=False)
|
||||||
|
observer = router.message_handler
|
||||||
|
observer.register(
|
||||||
|
pipe_handler, lambda event: {"a": 1}, lambda event: False
|
||||||
|
) # {"a": 1} should not be in result
|
||||||
|
observer.register(pipe_handler, lambda event: {"b": 2})
|
||||||
|
|
||||||
|
results = [result async for result in observer.trigger(42)]
|
||||||
|
assert results == [((42,), {"b": 2})]
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,13 @@ from tests.mocked_bot import MockedBot
|
||||||
|
|
||||||
|
|
||||||
class TestCommandFilter:
|
class TestCommandFilter:
|
||||||
|
def test_convert_to_list(self):
|
||||||
|
cmd = Command(commands="start")
|
||||||
|
assert cmd.commands
|
||||||
|
assert isinstance(cmd.commands, list)
|
||||||
|
assert cmd.commands[0] == "start"
|
||||||
|
assert cmd == Command(commands=["start"])
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_command(self, bot: MockedBot):
|
async def test_parse_command(self, bot: MockedBot):
|
||||||
# TODO: parametrize
|
# TODO: parametrize
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,37 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from aiogram.api.types import ContentType, Message
|
||||||
from aiogram.dispatcher.filters import ContentTypesFilter
|
from aiogram.dispatcher.filters import ContentTypesFilter
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MinimalMessage:
|
||||||
|
content_type: str
|
||||||
|
|
||||||
|
|
||||||
class TestContentTypesFilter:
|
class TestContentTypesFilter:
|
||||||
def test_validator_empty(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_validator_empty(self):
|
||||||
filter_ = ContentTypesFilter()
|
filter_ = ContentTypesFilter()
|
||||||
assert filter_.content_types == ["text"]
|
assert not filter_.content_types
|
||||||
|
await filter_(cast(Message, MinimalMessage(ContentType.TEXT)))
|
||||||
|
assert filter_.content_types == [ContentType.TEXT]
|
||||||
|
|
||||||
def test_validator_empty_list(self):
|
def test_validator_empty_list(self):
|
||||||
filter_ = ContentTypesFilter(content_types=[])
|
filter_ = ContentTypesFilter(content_types=[])
|
||||||
assert filter_.content_types == ["text"]
|
assert filter_.content_types == ["text"]
|
||||||
|
|
||||||
|
def test_convert_to_list(self):
|
||||||
|
filter_ = ContentTypesFilter(content_types="text")
|
||||||
|
assert filter_.content_types
|
||||||
|
assert isinstance(filter_.content_types, list)
|
||||||
|
assert filter_.content_types[0] == "text"
|
||||||
|
assert filter_ == ContentTypesFilter(content_types=["text"])
|
||||||
|
|
||||||
@pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]])
|
@pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]])
|
||||||
def test_validator_with_values(self, values):
|
def test_validator_with_values(self, values):
|
||||||
filter_ = ContentTypesFilter(content_types=values)
|
filter_ = ContentTypesFilter(content_types=values)
|
||||||
|
|
@ -22,3 +41,19 @@ class TestContentTypesFilter:
|
||||||
def test_validator_with_bad_values(self, values):
|
def test_validator_with_bad_values(self, values):
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
ContentTypesFilter(content_types=values)
|
ContentTypesFilter(content_types=values)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"values,content_type,result",
|
||||||
|
[
|
||||||
|
[[], ContentType.TEXT, True],
|
||||||
|
[[ContentType.TEXT], ContentType.TEXT, True],
|
||||||
|
[[ContentType.PHOTO], ContentType.TEXT, False],
|
||||||
|
[[ContentType.ANY], ContentType.TEXT, True],
|
||||||
|
[[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
|
||||||
|
[[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call(self, values, content_type, result):
|
||||||
|
filter_ = ContentTypesFilter(content_types=values)
|
||||||
|
assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
from typing import Type
|
from typing import Sequence, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
@ -46,11 +46,11 @@ class TestText:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
|
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("input_type", [str, list, tuple, set])
|
@pytest.mark.parametrize("input_type", [str, list, tuple])
|
||||||
def test_validator_convert_to_list(self, argument: str, input_type: Type):
|
def test_validator_convert_to_list(self, argument: str, input_type: Type):
|
||||||
text = Text(**{argument: input_type("test")})
|
text = Text(**{argument: input_type("test")})
|
||||||
assert hasattr(text, argument)
|
assert hasattr(text, argument)
|
||||||
assert isinstance(getattr(text, argument), list)
|
assert isinstance(getattr(text, argument), Sequence)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"argument,ignore_case,input_value,update_type,result",
|
"argument,ignore_case,input_value,update_type,result",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue