diff --git a/aiogram/dispatcher/filters/command.py b/aiogram/dispatcher/filters/command.py index c97f79fa..6eef0c3c 100644 --- a/aiogram/dispatcher/filters/command.py +++ b/aiogram/dispatcher/filters/command.py @@ -2,7 +2,9 @@ from __future__ import annotations import re from dataclasses import dataclass, field -from typing import Any, Dict, List, Match, Optional, Pattern, Union +from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast + +from pydantic import validator from aiogram import Bot from aiogram.api.types import Message @@ -12,11 +14,19 @@ CommandPatterType = Union[str, re.Pattern] # type: ignore class Command(BaseFilter): - commands: List[CommandPatterType] + commands: Union[Sequence[CommandPatterType], CommandPatterType] commands_prefix: str = "/" commands_ignore_case: bool = False commands_ignore_mention: bool = False + @validator("commands", always=True) + def _validate_commands( + cls, value: Union[Sequence[CommandPatterType], CommandPatterType] + ) -> Sequence[CommandPatterType]: + if isinstance(value, (str, re.Pattern)): + value = [value] + return value + async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]: if not message.text: return False @@ -50,7 +60,7 @@ class Command(BaseFilter): return False # Validate command - for allowed_command in self.commands: + for allowed_command in cast(Sequence[CommandPatterType], self.commands): # Command can be presented as regexp pattern or raw string # then need to validate that in different ways if isinstance(allowed_command, Pattern): # Regexp diff --git a/aiogram/dispatcher/filters/content_types.py b/aiogram/dispatcher/filters/content_types.py index f755f8e9..1f195663 100644 --- a/aiogram/dispatcher/filters/content_types.py +++ b/aiogram/dispatcher/filters/content_types.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Sequence, Union from pydantic import validator @@ -8,12 +8,16 @@ from .base import BaseFilter class ContentTypesFilter(BaseFilter): - content_types: Optional[List[str]] = None + content_types: Optional[Union[Sequence[str], str]] = None @validator("content_types") - def _validate_content_types(cls, value: Optional[List[str]]) -> Optional[List[str]]: + def _validate_content_types( + cls, value: Optional[Union[Sequence[str], str]] + ) -> Optional[Sequence[str]]: if not value: value = [ContentType.TEXT] + if isinstance(value, str): + value = [value] allowed_content_types = set(ContentType.all()) bad_content_types = set(value) - allowed_content_types if bad_content_types: diff --git a/aiogram/dispatcher/filters/text.py b/aiogram/dispatcher/filters/text.py index 71586e7e..8a94cd7b 100644 --- a/aiogram/dispatcher/filters/text.py +++ b/aiogram/dispatcher/filters/text.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Union from pydantic import root_validator from aiogram.api.types import CallbackQuery, InlineQuery, Message, Poll from aiogram.dispatcher.filters import BaseFilter +TextType = str + class Text(BaseFilter): """ @@ -12,14 +14,14 @@ class Text(BaseFilter): InlineQuery or Poll question. """ - text: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None - text_contains: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None - text_startswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None - text_endswith: Optional[Union[str, List[str], Set[str], Tuple[str]]] = None + text: Optional[Union[Sequence[TextType], TextType]] = None + text_contains: Optional[Union[Sequence[TextType], TextType]] = None + text_startswith: Optional[Union[Sequence[TextType], TextType]] = None + text_endswith: Optional[Union[Sequence[TextType], TextType]] = None text_ignore_case: bool = False @root_validator - def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: # Validate that only one text filter type is presented used_args = set( key for key, value in values.items() if key != "text_ignore_case" and value is not None diff --git a/docs/dispatcher/filters/command.md b/docs/dispatcher/filters/command.md index 5ce340c9..5d8ab771 100644 --- a/docs/dispatcher/filters/command.md +++ b/docs/dispatcher/filters/command.md @@ -7,7 +7,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex ## Specification | Argument | Type | Description | | --- | --- | --- | -| `commands` | `#!python3 List[CommandPatterType]` | List of commands (string or compiled regexp patterns) | +| `commands` | `#!python3 Union[Sequence[Union[str, re.Pattern]], Union[str, re.Pattern]]` | List of commands (string or compiled regexp patterns) | | `commands_prefix` | `#!python3 str` | Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes, for example: `"/!"` will work with commands prefixed by `"/"` or `"!"` (Default: `"/"`). | | `commands_ignore_case` | `#!python3 bool` | Ignore case (Does not work with regexp, use flags instead. Default: `False`) | | `commands_ignore_mention` | `#!python3 bool` | Ignore bot mention. By default bot can not handle commands intended for other bots (Default: `False`) | @@ -15,7 +15,7 @@ Works only with [Message](../../api/types/message.md) events which have the `tex ## Usage -1. Filter single variant of commands: `#!python3 Command(commands=["start"])` +1. Filter single variant of commands: `#!python3 Command(commands=["start"])` or `#!python3 Command(commands="start")` 1. Handle command by regexp pattern: `#!python3 Command(commands=[re.compile(r"item_(\d+)")])` 1. Match command by multiple variants: `#!python3 Command(commands=["item", re.compile(r"item_(\d+)")])` 1. Handle commands in public chats intended for other bots: `#!python3 Command(commands=["command"], commands)` diff --git a/docs/dispatcher/filters/content_types.md b/docs/dispatcher/filters/content_types.md index 30a635b1..a37006e3 100644 --- a/docs/dispatcher/filters/content_types.md +++ b/docs/dispatcher/filters/content_types.md @@ -19,11 +19,11 @@ Or used from filters factory by passing corresponding arguments to handler regis | Argument | Type | Description | | --- | --- | --- | -| `content_types` | `#!python3 Optional[List[str]]` | List of allowed content types | +| `content_types` | `#!python3 Optional[Union[Sequence[str], str]]` | List of allowed content types | ## Usage -1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])` +1. Single content type: `#!python3 ContentTypesFilter(content_types=["sticker"])` or `#!python3 ContentTypesFilter(content_types="sticker")` 1. Multiple content types: `#!python3 ContentTypesFilter(content_types=["sticker", "photo"])` 1. Recommended: With usage of `ContentType` helper: `#!python3 ContentTypesFilter(content_types=[ContentType.PHOTO])` 1. Any content type: `#!python3 ContentTypesFilter(content_types=[ContentType.ANY])` diff --git a/docs/dispatcher/filters/text.md b/docs/dispatcher/filters/text.md index 7c7338aa..0691459b 100644 --- a/docs/dispatcher/filters/text.md +++ b/docs/dispatcher/filters/text.md @@ -18,10 +18,10 @@ Or used from filters factory by passing corresponding arguments to handler regis | Argument | Type | Description | | --- | --- | --- | -| `text` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text equals value or one of values | -| `text_contains` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text contains value or one of values | -| `text_startswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text starts with value or one of values | -| `text_endswith` | `#!python3 Optional[Union[str, List[str], Set[str], Tuple[str]]]` | Text ends with value or one of values | +| `text` | `#!python3 Optional[Union[Sequence[str], str]]` | Text equals value or one of values | +| `text_contains` | `#!python3 Optional[Union[Sequence[str], str]]` | Text contains value or one of values | +| `text_startswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text starts with value or one of values | +| `text_endswith` | `#!python3 Optional[Union[Sequence[str], str]]` | Text ends with value or one of values | | `text_ignore_case` | `#!python3 bool` | Ignore case when checks (Default: `#!python3 False`) | !!! warning diff --git a/tests/test_dispatcher/test_filters/test_command.py b/tests/test_dispatcher/test_filters/test_command.py index 063bddaa..f312c235 100644 --- a/tests/test_dispatcher/test_filters/test_command.py +++ b/tests/test_dispatcher/test_filters/test_command.py @@ -11,6 +11,13 @@ from tests.mocked_bot import MockedBot class TestCommandFilter: + def test_convert_to_list(self): + cmd = Command(commands="start") + assert cmd.commands + assert isinstance(cmd.commands, list) + assert cmd.commands[0] == "start" + assert cmd == Command(commands=["start"]) + @pytest.mark.asyncio async def test_parse_command(self, bot: MockedBot): # TODO: parametrize diff --git a/tests/test_dispatcher/test_filters/test_content_types.py b/tests/test_dispatcher/test_filters/test_content_types.py index 9c83cc09..f5252323 100644 --- a/tests/test_dispatcher/test_filters/test_content_types.py +++ b/tests/test_dispatcher/test_filters/test_content_types.py @@ -25,6 +25,13 @@ class TestContentTypesFilter: filter_ = ContentTypesFilter(content_types=[]) assert filter_.content_types == ["text"] + def test_convert_to_list(self): + filter_ = ContentTypesFilter(content_types="text") + assert filter_.content_types + assert isinstance(filter_.content_types, list) + assert filter_.content_types[0] == "text" + assert filter_ == ContentTypesFilter(content_types=["text"]) + @pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]]) def test_validator_with_values(self, values): filter_ = ContentTypesFilter(content_types=values) diff --git a/tests/test_dispatcher/test_filters/test_text.py b/tests/test_dispatcher/test_filters/test_text.py index df1d26c1..a5f4daaf 100644 --- a/tests/test_dispatcher/test_filters/test_text.py +++ b/tests/test_dispatcher/test_filters/test_text.py @@ -1,6 +1,6 @@ import datetime from itertools import permutations -from typing import Type +from typing import Sequence, Type import pytest from pydantic import ValidationError @@ -46,11 +46,11 @@ class TestText: @pytest.mark.parametrize( "argument", ["text", "text_contains", "text_startswith", "text_endswith"] ) - @pytest.mark.parametrize("input_type", [str, list, tuple, set]) + @pytest.mark.parametrize("input_type", [str, list, tuple]) def test_validator_convert_to_list(self, argument: str, input_type: Type): text = Text(**{argument: input_type("test")}) assert hasattr(text, argument) - assert isinstance(getattr(text, argument), list) + assert isinstance(getattr(text, argument), Sequence) @pytest.mark.parametrize( "argument,ignore_case,input_value,update_type,result",