diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 373dafe5..e109eb6d 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -10,7 +10,7 @@ from babel.support import LazyProxy from aiogram import types from aiogram.dispatcher.filters.filters import BoundFilter, Filter -from aiogram.types import CallbackQuery, ChatType, InlineQuery, Message, Poll, ChatMemberUpdated +from aiogram.types import CallbackQuery, ChatType, InlineQuery, Message, Poll, ChatMemberUpdated, BotCommand ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str, int] @@ -34,7 +34,7 @@ class Command(Filter): By default this filter is registered for messages and edited messages handlers. """ - def __init__(self, commands: Union[Iterable, str], + def __init__(self, commands: Union[Iterable[Union[str, BotCommand]], str, BotCommand], prefixes: Union[Iterable, str] = '/', ignore_case: bool = True, ignore_mention: bool = False, @@ -66,8 +66,19 @@ class Command(Filter): @dp.message_handler(commands=['myCommand'], commands_ignore_caption=False, content_types=ContentType.ANY) @dp.message_handler(Command(['myCommand'], ignore_caption=False), content_types=[ContentType.TEXT, ContentType.DOCUMENT]) """ - if isinstance(commands, str): + if isinstance(commands, (str, BotCommand)): commands = (commands,) + elif isinstance(commands, Iterable): + if not all(isinstance(cmd, (str, BotCommand)) for cmd in commands): + raise ValueError( + "Command filter only supports str, BotCommand object or their Iterable" + ) + else: + raise ValueError( + "Command filter doesn't support {} as input. " + "It only supports str, BotCommand object or their Iterable".format(type(commands)) + ) + commands = [cmd.command if isinstance(cmd, BotCommand) else cmd for cmd in commands] self.commands = list(map(str.lower, commands)) if ignore_case else commands self.prefixes = prefixes diff --git a/tests/test_dispatcher/test_filters/test_builtin.py b/tests/test_dispatcher/test_filters/test_builtin.py index 4f05cb22..35e6cb0a 100644 --- a/tests/test_dispatcher/test_filters/test_builtin.py +++ b/tests/test_dispatcher/test_filters/test_builtin.py @@ -1,4 +1,4 @@ -from typing import Set +from typing import Set, Union, Iterable from datetime import datetime import pytest @@ -6,9 +6,9 @@ import pytest from aiogram.dispatcher.filters.builtin import ( Text, extract_chat_ids, - ChatIDArgumentType, ForwardedMessageFilter, IDFilter, + ChatIDArgumentType, ForwardedMessageFilter, IDFilter, Command, ) -from aiogram.types import Message +from aiogram.types import Message, BotCommand from tests.types.dataset import MESSAGE, MESSAGE_FROM_CHANNEL @@ -108,3 +108,42 @@ class TestIDFilter: filter = IDFilter(chat_id=message_from_channel.chat.id) assert await filter.check(message_from_channel) + + +@pytest.mark.parametrize("command", [ + "/start", + "/start some args", +]) +@pytest.mark.parametrize("cmd_filter", [ + "start", + ("start",), + BotCommand(command="start", description="my desc"), + (BotCommand(command="start", description="bar"),), + (BotCommand(command="start", description="foo"), "help"), +]) +@pytest.mark.asyncio +async def test_commands_filter(command: str, cmd_filter: Union[Iterable[Union[str, BotCommand]], str, BotCommand]): + message_with_command = Message(**MESSAGE) + message_with_command.text = command + + start_filter = Command(commands=cmd_filter) + + assert await start_filter.check(message_with_command) + + +@pytest.mark.asyncio +async def test_commands_filter_not_checked(): + message_with_command = Message(**MESSAGE) + message_with_command.text = "/start" + + start_filter = Command(commands=["help", BotCommand("about", "my desc")]) + + assert not await start_filter.check(message_with_command) + + +def test_commands_filter_raises_error(): + with pytest.raises(ValueError): + start_filter = Command(commands=42) # noqa + with pytest.raises(ValueError): + start_filter = Command(commands=[42]) # noqa +