mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
Refactor command filter and fix typing
This commit is contained in:
parent
6cb1f96ded
commit
1cd993009e
4 changed files with 43 additions and 24 deletions
|
|
@ -2,15 +2,13 @@ from __future__ import annotations
|
|||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AnyStr, Dict, List, Match, Optional, Pattern, Union
|
||||
|
||||
from pydantic import root_validator
|
||||
from typing import Any, Dict, List, Match, Optional, Pattern, Union
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.api.types import Message
|
||||
from aiogram.dispatcher.filters import BaseFilter
|
||||
|
||||
CommandPatterType = Union[str, re.Pattern]
|
||||
CommandPatterType = Union[str, re.Pattern] # type: ignore
|
||||
|
||||
|
||||
class Command(BaseFilter):
|
||||
|
|
@ -19,14 +17,6 @@ class Command(BaseFilter):
|
|||
commands_ignore_case: bool = False
|
||||
commands_ignore_mention: bool = False
|
||||
|
||||
@root_validator
|
||||
def validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "commands" not in values:
|
||||
raise ValueError("Commands required")
|
||||
if not isinstance(values["commands"], list):
|
||||
values["commands"] = [values["commands"]]
|
||||
return values
|
||||
|
||||
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
|
||||
if not message.text:
|
||||
return False
|
||||
|
|
@ -54,12 +44,10 @@ class Command(BaseFilter):
|
|||
return False
|
||||
|
||||
# Validate mention
|
||||
if (
|
||||
mention
|
||||
and not self.commands_ignore_mention
|
||||
and mention.lower() != (await bot.me()).username.lower()
|
||||
):
|
||||
return False
|
||||
if mention and not self.commands_ignore_mention:
|
||||
me = await bot.me()
|
||||
if me.username and mention.lower() != me.username.lower():
|
||||
return False
|
||||
|
||||
# Validate command
|
||||
for allowed_command in self.commands:
|
||||
|
|
@ -106,17 +94,18 @@ class CommandObject:
|
|||
"""Command prefix"""
|
||||
command: str = ""
|
||||
"""Command without prefix and mention"""
|
||||
mention: str = None
|
||||
mention: Optional[str] = None
|
||||
"""Mention (if available)"""
|
||||
args: str = field(repr=False, default=None)
|
||||
args: Optional[str] = field(repr=False, default=None)
|
||||
"""Command argument"""
|
||||
match: Optional[Match[AnyStr]] = None
|
||||
match: Optional[Match[str]] = field(repr=False, default=None)
|
||||
"""Will be presented match result if the command is presented as regexp in filter"""
|
||||
|
||||
@property
|
||||
def mentioned(self) -> bool:
|
||||
"""
|
||||
This command has mention?
|
||||
|
||||
:return:
|
||||
"""
|
||||
return bool(self.mention)
|
||||
|
|
@ -125,10 +114,11 @@ class CommandObject:
|
|||
def text(self) -> str:
|
||||
"""
|
||||
Generate original text from object
|
||||
|
||||
:return:
|
||||
"""
|
||||
line = self.prefix + self.command
|
||||
if self.mentioned:
|
||||
if self.mention:
|
||||
line += "@" + self.mention
|
||||
if self.args:
|
||||
line += " " + self.args
|
||||
|
|
|
|||
|
|
@ -22,5 +22,5 @@ class MessageHandlerCommandMixin(BaseHandlerMixin):
|
|||
@property
|
||||
def command(self) -> Optional[CommandObject]:
|
||||
if "command" in self.data:
|
||||
return self.command["data"]
|
||||
return self.data["command"]
|
||||
return None
|
||||
|
|
|
|||
6
tests/test_dispatcher/test_filters/test_command.py
Normal file
6
tests/test_dispatcher/test_filters/test_command.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from aiogram.dispatcher.filters import Command
|
||||
|
||||
|
||||
class TestCommandFilter:
|
||||
def test_validator(self):
|
||||
pass
|
||||
|
|
@ -4,7 +4,8 @@ from typing import Any
|
|||
import pytest
|
||||
|
||||
from aiogram.api.types import Chat, Message, User
|
||||
from aiogram.dispatcher.handler.message import MessageHandler
|
||||
from aiogram.dispatcher.filters import CommandObject
|
||||
from aiogram.dispatcher.handler.message import MessageHandler, MessageHandlerCommandMixin
|
||||
|
||||
|
||||
class MyHandler(MessageHandler):
|
||||
|
|
@ -26,3 +27,25 @@ class TestClassBasedMessageHandler:
|
|||
|
||||
assert handler.from_user == event.from_user
|
||||
assert handler.chat == event.chat
|
||||
|
||||
|
||||
class HandlerWithCommand(MessageHandlerCommandMixin, MessageHandler):
|
||||
async def handle(self) -> Any:
|
||||
return self.command
|
||||
|
||||
|
||||
class TestBaseMessageHandlerCommandMixin:
|
||||
def test_command_accessible(self):
|
||||
handler = HandlerWithCommand(
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
text="test",
|
||||
chat=Chat(id=42, type="private"),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
command=CommandObject(prefix="/", command="command", args="args"),
|
||||
)
|
||||
|
||||
assert isinstance(handler.command, CommandObject)
|
||||
assert handler.command.command == "command"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue