diff --git a/aiogram/api/client/session/aiohttp.py b/aiogram/api/client/session/aiohttp.py index 28ca6db0..9ac73eaa 100644 --- a/aiogram/api/client/session/aiohttp.py +++ b/aiogram/api/client/session/aiohttp.py @@ -1,7 +1,6 @@ from __future__ import annotations -import copy -from typing import Any, Callable, Dict, Optional, TypeVar, cast +from typing import Callable, Optional, TypeVar, cast from aiohttp import ClientSession, FormData @@ -60,17 +59,3 @@ class AiohttpSession(BaseSession): async def __aenter__(self) -> AiohttpSession: await self.create_session() return self - - def __deepcopy__(self: T, memo: Optional[Dict[int, Any]] = None) -> T: - if memo is None: # pragma: no cover - # This block was never be called - memo = {} - - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for key, value in self.__dict__.items(): - # aiohttp ClientSession cannot be copied. - copied_value = copy.deepcopy(value, memo=memo) if key != "_session" else None - setattr(result, key, copied_value) - return result diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 42c6202d..52e8c0da 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -1,7 +1,7 @@ import inspect from dataclasses import dataclass, field from functools import partial -from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.handler.base import BaseHandler @@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]] SyncFilter = Callable[[Any], Any] AsyncFilter = Callable[[Any], Awaitable[Any]] FilterType = Union[SyncFilter, AsyncFilter, BaseFilter] -HandlerType = Union[CallbackType, BaseHandler] +HandlerType = Union[FilterType, BaseHandler] @dataclass @@ -47,7 +47,7 @@ class FilterObject(CallableMixin): @dataclass class HandlerObject(CallableMixin): callback: HandlerType - filters: List[FilterObject] + filters: Optional[List[FilterObject]] = None def __post_init__(self): super(HandlerObject, self).__post_init__() @@ -56,6 +56,8 @@ class HandlerObject(CallableMixin): self.awaitable = True async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: + if not self.filters: + return True, {} for event_filter in self.filters: check = await event_filter.call(*args, **kwargs) if not check: diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index 93115ab7..93f4aac6 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -1,6 +1,6 @@ from __future__ import annotations -import copy +from itertools import chain from typing import ( TYPE_CHECKING, Any, @@ -34,15 +34,11 @@ class EventObserver: def __init__(self) -> None: self.handlers: List[HandlerObject] = [] - def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType: + def register(self, callback: HandlerType) -> HandlerType: """ Register callback with filters """ - self.handlers.append( - HandlerObject( - callback=callback, filters=[FilterObject(filter_) for filter_ in filters] - ) - ) + self.handlers.append(HandlerObject(callback=callback)) return callback async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: @@ -51,22 +47,18 @@ class EventObserver: Handler will be called when all its filters is pass. """ for handler in self.handlers: - kwargs_copy = copy.copy(kwargs) - result, data = await handler.check(*args, **kwargs) - if result: - kwargs_copy.update(data) - try: - yield await handler.call(*args, **kwargs_copy) - except SkipHandler: - continue + try: + yield await handler.call(*args, **kwargs) + except SkipHandler: + continue - def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]: + def __call__(self) -> Callable[[CallbackType], CallbackType]: """ Decorator for registering event handlers """ def wrapper(callback: CallbackType) -> CallbackType: - self.register(callback, *args) + self.register(callback) return callback return wrapper @@ -148,16 +140,28 @@ class TelegramEventObserver(EventObserver): Register event handler """ resolved_filters = self.resolve_filters(bound_filters) - return super().register(callback, *filters, *resolved_filters) + self.handlers.append( + HandlerObject( + callback=callback, + filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)], + ) + ) + return callback async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: """ Propagate event to handlers and stops propagation on first match. Handler will be called when all its filters is pass. """ - async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs): - yield result - break + for handler in self.handlers: + result, data = await handler.check(*args, **kwargs) + if result: + kwargs.update(data) + try: + yield await handler.call(*args, **kwargs) + except SkipHandler: + continue + break def __call__( self, *args: FilterType, **bound_filters: BaseFilter diff --git a/aiogram/dispatcher/filters/base.py b/aiogram/dispatcher/filters/base.py index fabaaf21..c0a6c377 100644 --- a/aiogram/dispatcher/filters/base.py +++ b/aiogram/dispatcher/filters/base.py @@ -10,7 +10,7 @@ class BaseFilter(ABC, BaseModel): # error: Signature of "__call__" incompatible with supertype "BaseFilter" [override] # https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override - __call__: Any + pass else: # pragma: no cover @abstractmethod diff --git a/aiogram/dispatcher/filters/command.py b/aiogram/dispatcher/filters/command.py index c97f79fa..6eef0c3c 100644 --- a/aiogram/dispatcher/filters/command.py +++ b/aiogram/dispatcher/filters/command.py @@ -2,7 +2,9 @@ from __future__ import annotations import re from dataclasses import dataclass, field -from typing import Any, Dict, List, Match, Optional, Pattern, Union +from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast + +from pydantic import validator from aiogram import Bot from aiogram.api.types import Message @@ -12,11 +14,19 @@ CommandPatterType = Union[str, re.Pattern] # type: ignore class Command(BaseFilter): - commands: List[CommandPatterType] + commands: Union[Sequence[CommandPatterType], CommandPatterType] commands_prefix: str = "/" commands_ignore_case: bool = False commands_ignore_mention: bool = False + @validator("commands", always=True) + def _validate_commands( + cls, value: Union[Sequence[CommandPatterType], CommandPatterType] + ) -> Sequence[CommandPatterType]: + if isinstance(value, (str, re.Pattern)): + value = [value] + return value + async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]: if not message.text: return False @@ -50,7 +60,7 @@ class Command(BaseFilter): return False # Validate command - for allowed_command in self.commands: + for allowed_command in cast(Sequence[CommandPatterType], self.commands): # Command can be presented as regexp pattern or raw string # then need to validate that in different ways if isinstance(allowed_command, Pattern): # Regexp diff --git a/aiogram/dispatcher/filters/content_types.py b/aiogram/dispatcher/filters/content_types.py index 0f8134e2..1f195663 100644 --- a/aiogram/dispatcher/filters/content_types.py +++ b/aiogram/dispatcher/filters/content_types.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Sequence, Union from pydantic import validator @@ -8,12 +8,16 @@ from .base import BaseFilter class ContentTypesFilter(BaseFilter): - content_types: Optional[List[str]] = None + content_types: Optional[Union[Sequence[str], str]] = None - @validator("content_types", always=True) - def _validate_content_types(cls, value: Optional[List[str]]) -> Optional[List[str]]: + @validator("content_types") + def _validate_content_types( + cls, value: Optional[Union[Sequence[str], str]] + ) -> Optional[Sequence[str]]: if not value: value = [ContentType.TEXT] + if isinstance(value, str): + value = [value] allowed_content_types = set(ContentType.all()) bad_content_types = set(value) - allowed_content_types if bad_content_types: @@ -23,5 +27,5 @@ class ContentTypesFilter(BaseFilter): async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]: if not self.content_types: # pragma: no cover # Is impossible but needed for valid typechecking - return False + self.content_types = [ContentType.TEXT] return ContentType.ANY in self.content_types or message.content_type in self.content_types diff --git a/aiogram/dispatcher/filters/text.py b/aiogram/dispatcher/filters/text.py index 71586e7e..8a94cd7b 100644 --- a/aiogram/dispatcher/filters/text.py +++ b/aiogram/dispatcher/filters/text.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Union from pydantic import root_validator from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll from aiogram.dispatcher.filters import BaseFilter +TextType = str + class Text(BaseFilter): """ @@ -12,14 +14,14 @@ class Text(BaseFilter): InlineQuery or Poll question. """ - text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None - text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None - text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None - text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None + text: Optional[Union[Sequence[TextType], TextType]] = None + text_contains: Optional[Union[Sequence[TextType], TextType]] = None + text_startswith: Optional[Union[Sequence[TextType], TextType]] = None + text_endswith: Optional[Union[Sequence[TextType], TextType]] = None text_ignore_case: bool = False @root_validator - def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: # Validate that only one text filter type is presented used_args = set( key for key, value in values.items() if key != "text_ignore_case" and value is not None diff --git a/docs/dispatcher/filters/command.md b/docs/dispatcher/filters/command.md index 5ce340c9..5d8ab771 100644 --- a/docs/dispatcher/filters/command.md +++ b/docs/dispatcher/filters/command.md @@ -7,7 +7,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex ## Specification | Argument | Type | Description | | --- | --- | --- | -| `commands` | `#!python3 List[CommandPatterType]` | List of commands (string or compiled regexp patterns) | +| `commands` | `#!python3 Union[Sequence[Union[str, re.Pattern]], Union[str, re.Pattern]]` | List of commands (string or compiled regexp patterns) | | `commands_prefix` | `#!python3 str` | Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes, for example: `"/!"` will work with commands prefixed by `"/"` or `"!"` (Default: `"/"`). | | `commands_ignore_case` | `#!python3 bool` | Ignore case (Does not work with regexp, use flags instead. Default: `False`) | | `commands_ignore_mention` | `#!python3 bool` | Ignore bot mention. By default bot can not handle commands intended for other bots (Default: `False`) | @@ -15,7 +15,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex ## Usage -1. Filter single variant of commands: `#!python3 Command(commands=["start"])` +1. Filter single variant of commands: `#!python3 Command(commands=["start"])` or `#!python3 Command(commands="start")` 1. Handle command by regexp pattern: `#!python3 Command(commands=[re.compile(r"item_(\d+)")])` 1. Match command by multiple variants: `#!python3 Command(commands=["item", re.compile(r"item_(\d+)")])` 1. Handle commands in public chats intended for other bots: `#!python3 Command(commands=["command"], commands)` diff --git a/docs/dispatcher/filters/content_types.md b/docs/dispatcher/filters/content_types.md index 30a635b1..a37006e3 100644 --- a/docs/dispatcher/filters/content_types.md +++ b/docs/dispatcher/filters/content_types.md @@ -19,11 +19,11 @@ Or used from filters factory by passing corresponding arguments to handler regis | Argument | Type | Description | | --- | --- | --- | -| `content_types` | `#!python3 Optional[List[str]]` | List of allowed content types | +| `content_types` | `#!python3 Optional[Union[Sequence[str], str]]` | List of allowed content types | ## Usage -1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])` +1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])` or `#!python3 ContentTypesFilter(content_types="sticker")` 1. Multiple content types: `#!python3 ContentTypesFilter(content_types=["sticker", "photo"])` 1. Recommended: With usage of `ContentType` helper: `#!python3 ContentTypesFilter(content_types=[ContentType.PHOTO])` 1. Any content type: `#!python3 ContentTypesFilter(content_types=[ContentType.ANY])` diff --git a/docs/dispatcher/filters/text.md b/docs/dispatcher/filters/text.md index 7c7338aa..0691459b 100644 --- a/docs/dispatcher/filters/text.md +++ b/docs/dispatcher/filters/text.md @@ -18,10 +18,10 @@ Or used from filters factory by passing corresponding arguments to handler regis | Argument | Type | Description | | --- | --- | --- | -| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values | -| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values | -| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values | -| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values | +| `text` | `#!python3 Optional[Union[Sequence[str], str]]` | Text equals value or one of values | +| `text_contains` | `#!python3 Optional[Union[Sequence[str], str]]` | Text contains value or one of values | +| `text_startswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text starts with value or one of values | +| `text_endswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text ends with value or one of values | | `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) | !!! warning diff --git a/docs/dispatcher/observer.md b/docs/dispatcher/observer.md index 222503e9..b98db9b6 100644 --- a/docs/dispatcher/observer.md +++ b/docs/dispatcher/observer.md @@ -14,12 +14,11 @@ Reference: `#!python3 aiogram.dispatcher.event.observer.EventObserver` That is base observer for all events. ### Base registering method -Method: `.register(callback, filter1, filter2, ...)` +Method: `.register()` | Argument | Type | Description | | --- | --- | --- | | `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler | -| `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set | Will return original callback. @@ -28,14 +27,15 @@ Will return original callback. Usage: ```python3 -@(filter1, filter2, ...) +@() async def handler(*args, **kwargs): pass ``` ## TelegramEventObserver Is subclass of [EventObserver](#eventobserver) with some differences. -In this handler can be bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers. +Here you can register handler with filters or bounded filters which can be used as keyword arguments instead of writing full references when you register new handlers. +This observer will stops event propagation when first handler is pass. ### Registering bound filters @@ -44,7 +44,7 @@ Bound filter should be subclass of [BaseFilter](filters/index.md) `#!python3 .bind_filter(MyFilter)` ### Registering handlers -Method: `EventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)` +Method: `TelegramEventObserver.register(callback, filter1, filter2, ..., bound_filter=value, ...)` In this method is added bound filters keywords interface. | Argument | Type | Description | @@ -52,3 +52,13 @@ In this method is added bound filters keywords interface. | `callback` | `#!python3 Callable[[Any], Awaitable[Any]]` | Event handler | | `*filters` | `#!python3 Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]], BaseFilter]` | Ordered filters set | | `**bound_filters` | `#!python3 Any` | Bound filters | + + +### Decorator-style registering event handler with filters + +Usage: +```python3 +@(filter1, filter2, ...) +async def handler(*args, **kwargs): + pass +``` diff --git a/tests/test_api/test_client/test_session/test_aiohttp_session.py b/tests/test_api/test_client/test_session/test_aiohttp_session.py index ac93b1a7..edde8057 100644 --- a/tests/test_api/test_client/test_session/test_aiohttp_session.py +++ b/tests/test_api/test_client/test_session/test_aiohttp_session.py @@ -1,4 +1,3 @@ -import copy from typing import AsyncContextManager import aiohttp @@ -123,11 +122,3 @@ class TestAiohttpSession: assert session == ctx mocked_close.awaited_once() mocked_create_session.awaited_once() - - @pytest.mark.asyncio - async def test_deepcopy(self): - # Session should be copied without aiohttp.ClientSession - async with AiohttpSession() as session: - cloned_session = copy.deepcopy(session) - assert cloned_session != session - assert cloned_session._session is None diff --git a/tests/test_dispatcher/test_event/test_observer.py b/tests/test_dispatcher/test_event/test_observer.py index 2f157850..a4029197 100644 --- a/tests/test_dispatcher/test_event/test_observer.py +++ b/tests/test_dispatcher/test_event/test_observer.py @@ -39,68 +39,38 @@ class MyFilter3(MyFilter1): class TestEventObserver: - @pytest.mark.parametrize( - "count,handler,filters", - ( - pytest.param(5, my_handler, []), - pytest.param(3, my_handler, [lambda event: True]), - pytest.param( - 2, - my_handler, - [lambda event: True, lambda event: False, lambda event: {"ok": True}], - ), - ), - ) - def test_register_filters(self, count, handler, filters): + @pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler])) + def test_register_filters(self, count, handler): observer = EventObserver() for index in range(count): wrapped_handler = functools.partial(handler, index=index) - observer.register(wrapped_handler, *filters) + observer.register(wrapped_handler) registered_handler = observer.handlers[index] assert len(observer.handlers) == index + 1 assert isinstance(registered_handler, HandlerObject) assert registered_handler.callback == wrapped_handler - assert len(registered_handler.filters) == len(filters) + assert not registered_handler.filters - @pytest.mark.parametrize( - "count,handler,filters", - ( - pytest.param(5, my_handler, []), - pytest.param(3, my_handler, [lambda event: True]), - pytest.param( - 2, - my_handler, - [lambda event: True, lambda event: False, lambda event: {"ok": True}], - ), - ), - ) - def test_register_filters_via_decorator(self, count, handler, filters): + @pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler])) + def test_register_filters_via_decorator(self, count, handler): observer = EventObserver() for index in range(count): wrapped_handler = functools.partial(handler, index=index) - observer(*filters)(wrapped_handler) + observer()(wrapped_handler) registered_handler = observer.handlers[index] assert len(observer.handlers) == index + 1 assert isinstance(registered_handler, HandlerObject) assert registered_handler.callback == wrapped_handler - assert len(registered_handler.filters) == len(filters) - - @pytest.mark.asyncio - async def test_trigger_rejected(self): - observer = EventObserver() - observer.register(my_handler, lambda event: False) - - results = [result async for result in observer.trigger(42)] - assert results == [] + assert not registered_handler.filters @pytest.mark.asyncio async def test_trigger_accepted_bool(self): observer = EventObserver() - observer.register(my_handler, lambda event: True) + observer.register(my_handler) results = [result async for result in observer.trigger(42)] assert results == [42] @@ -108,23 +78,12 @@ class TestEventObserver: @pytest.mark.asyncio async def test_trigger_with_skip(self): observer = EventObserver() - observer.register(skip_my_handler, lambda event: True) - observer.register(my_handler, lambda event: False) - observer.register(my_handler, lambda event: True) + observer.register(skip_my_handler) + observer.register(my_handler) + observer.register(my_handler) results = [result async for result in observer.trigger(42)] - assert results == [42] - - @pytest.mark.asyncio - async def test_trigger_right_context_in_handlers(self): - observer = EventObserver() - observer.register( - pipe_handler, lambda event: {"a": 1}, lambda event: False - ) # {"a": 1} should not be in result - observer.register(pipe_handler, lambda event: {"b": 2}) - - results = [result async for result in observer.trigger(42)] - assert results == [((42,), {"b": 2})] + assert results == [42, 42] class TestTelegramEventObserver: @@ -144,9 +103,9 @@ class TestTelegramEventObserver: assert MyFilter in event_observer.filters def test_resolve_filters_chain(self): - router1 = Router() - router2 = Router() - router3 = Router() + router1 = Router(use_builtin_filters=False) + router2 = Router(use_builtin_filters=False) + router3 = Router(use_builtin_filters=False) router1.include_router(router2) router2.include_router(router3) @@ -168,7 +127,7 @@ class TestTelegramEventObserver: assert MyFilter3 in filters_chain3 def test_resolve_filters(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler observer.bind_filter(MyFilter1) @@ -189,7 +148,7 @@ class TestTelegramEventObserver: assert observer.resolve_filters({"test": ...}) def test_register(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler observer.bind_filter(MyFilter1) @@ -214,7 +173,7 @@ class TestTelegramEventObserver: assert MyFilter1(test="PASS") in callbacks def test_register_decorator(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler @observer() @@ -226,7 +185,7 @@ class TestTelegramEventObserver: @pytest.mark.asyncio async def test_trigger(self): - router = Router() + router = Router(use_builtin_filters=False) observer = router.message_handler observer.bind_filter(MyFilter1) observer.register(my_handler, test="ok") @@ -241,3 +200,38 @@ class TestTelegramEventObserver: results = [result async for result in observer.trigger(message)] assert results == [message] + + @pytest.mark.parametrize( + "count,handler,filters", + ( + [5, my_handler, []], + [3, my_handler, [lambda event: True]], + [2, my_handler, [lambda event: True, lambda event: False, lambda event: {"ok": True}]], + ), + ) + def test_register_filters_via_decorator(self, count, handler, filters): + router = Router(use_builtin_filters=False) + observer = router.message_handler + + for index in range(count): + wrapped_handler = functools.partial(handler, index=index) + observer(*filters)(wrapped_handler) + registered_handler = observer.handlers[index] + + assert len(observer.handlers) == index + 1 + assert isinstance(registered_handler, HandlerObject) + assert registered_handler.callback == wrapped_handler + assert len(registered_handler.filters) == len(filters) + + # + @pytest.mark.asyncio + async def test_trigger_right_context_in_handlers(self): + router = Router(use_builtin_filters=False) + observer = router.message_handler + observer.register( + pipe_handler, lambda event: {"a": 1}, lambda event: False + ) # {"a": 1} should not be in result + observer.register(pipe_handler, lambda event: {"b": 2}) + + results = [result async for result in observer.trigger(42)] + assert results == [((42,), {"b": 2})] diff --git a/tests/test_dispatcher/test_filters/test_command.py b/tests/test_dispatcher/test_filters/test_command.py index 063bddaa..f312c235 100644 --- a/tests/test_dispatcher/test_filters/test_command.py +++ b/tests/test_dispatcher/test_filters/test_command.py @@ -11,6 +11,13 @@ from tests.mocked_bot import MockedBot class TestCommandFilter: + def test_convert_to_list(self): + cmd = Command(commands="start") + assert cmd.commands + assert isinstance(cmd.commands, list) + assert cmd.commands[0] == "start" + assert cmd == Command(commands=["start"]) + @pytest.mark.asyncio async def test_parse_command(self, bot: MockedBot): # TODO: parametrize diff --git a/tests/test_dispatcher/test_filters/test_content_types.py b/tests/test_dispatcher/test_filters/test_content_types.py index 8efa5084..f5252323 100644 --- a/tests/test_dispatcher/test_filters/test_content_types.py +++ b/tests/test_dispatcher/test_filters/test_content_types.py @@ -1,18 +1,37 @@ +from dataclasses import dataclass +from typing import cast + import pytest from pydantic import ValidationError +from aiogram.api.types import ContentType, Message from aiogram.dispatcher.filters import ContentTypesFilter +@dataclass +class MinimalMessage: + content_type: str + + class TestContentTypesFilter: - def test_validator_empty(self): + @pytest.mark.asyncio + async def test_validator_empty(self): filter_ = ContentTypesFilter() - assert filter_.content_types == ["text"] + assert not filter_.content_types + await filter_(cast(Message, MinimalMessage(ContentType.TEXT))) + assert filter_.content_types == [ContentType.TEXT] def test_validator_empty_list(self): filter_ = ContentTypesFilter(content_types=[]) assert filter_.content_types == ["text"] + def test_convert_to_list(self): + filter_ = ContentTypesFilter(content_types="text") + assert filter_.content_types + assert isinstance(filter_.content_types, list) + assert filter_.content_types[0] == "text" + assert filter_ == ContentTypesFilter(content_types=["text"]) + @pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]]) def test_validator_with_values(self, values): filter_ = ContentTypesFilter(content_types=values) @@ -22,3 +41,19 @@ class TestContentTypesFilter: def test_validator_with_bad_values(self, values): with pytest.raises(ValidationError): ContentTypesFilter(content_types=values) + + @pytest.mark.parametrize( + "values,content_type,result", + [ + [[], ContentType.TEXT, True], + [[ContentType.TEXT], ContentType.TEXT, True], + [[ContentType.PHOTO], ContentType.TEXT, False], + [[ContentType.ANY], ContentType.TEXT, True], + [[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True], + [[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True], + ], + ) + @pytest.mark.asyncio + async def test_call(self, values, content_type, result): + filter_ = ContentTypesFilter(content_types=values) + assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result diff --git a/tests/test_dispatcher/test_filters/test_text.py b/tests/test_dispatcher/test_filters/test_text.py index df1d26c1..a5f4daaf 100644 --- a/tests/test_dispatcher/test_filters/test_text.py +++ b/tests/test_dispatcher/test_filters/test_text.py @@ -1,6 +1,6 @@ import datetime from itertools import permutations -from typing import Type +from typing import Sequence, Type import pytest from pydantic import ValidationError @@ -46,11 +46,11 @@ class TestText: @pytest.mark.parametrize( "argument", ["text", "text_contains", "text_startswith", "text_endswith"] ) - @pytest.mark.parametrize("input_type", [str, list, tuple, set]) + @pytest.mark.parametrize("input_type", [str, list, tuple]) def test_validator_convert_to_list(self, argument: str, input_type: Type): text = Text(**{argument: input_type("test")}) assert hasattr(text, argument) - assert isinstance(getattr(text, argument), list) + assert isinstance(getattr(text, argument), Sequence) @pytest.mark.parametrize( "argument,ignore_case,input_value,update_type,result",