mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Improved filters.
This commit is contained in:
parent
cd4fee5eaa
commit
24184b1c8f
3 changed files with 155 additions and 20 deletions
|
|
@ -1,20 +1,23 @@
|
|||
from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \
|
||||
RegexpFilter, StateFilter
|
||||
from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \
|
||||
StateFilter, Text
|
||||
from .factory import FiltersFactory
|
||||
from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||
from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||
|
||||
__all__ = [
|
||||
'AbstractFilter',
|
||||
'BaseFilter',
|
||||
'Command',
|
||||
'CommandsFilter',
|
||||
'ContentTypeFilter',
|
||||
'ExceptionsFilter',
|
||||
'Filter',
|
||||
'FilterNotPassed',
|
||||
'FilterRecord',
|
||||
'FiltersFactory',
|
||||
'RegexpCommandsFilter',
|
||||
'RegexpFilter',
|
||||
'StateFilter',
|
||||
'Text',
|
||||
'check_filter',
|
||||
'check_filters',
|
||||
'FilterNotPassed'
|
||||
'check_filters'
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,62 @@
|
|||
import inspect
|
||||
import re
|
||||
from _contextvars import ContextVar
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from aiogram import types
|
||||
from aiogram.dispatcher.filters.filters import BaseFilter
|
||||
from aiogram.dispatcher.filters.filters import BaseFilter, Filter
|
||||
from aiogram.types import CallbackQuery, ContentType, Message
|
||||
|
||||
|
||||
class Command(Filter):
|
||||
def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False):
|
||||
if isinstance(commands, str):
|
||||
commands = (commands,)
|
||||
|
||||
self.commands = list(map(str.lower, commands)) if ignore_case else commands
|
||||
self.prefixes = prefixes
|
||||
self.ignore_case = ignore_case
|
||||
self.ignore_mention = ignore_mention
|
||||
|
||||
@staticmethod
|
||||
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
|
||||
full_command = message.text.split()[0]
|
||||
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
|
||||
|
||||
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
|
||||
return False
|
||||
elif prefix not in prefixes:
|
||||
return False
|
||||
elif (command.lower() if ignore_case else command) not in commands:
|
||||
return False
|
||||
|
||||
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
|
||||
|
||||
async def check(self, message: types.Message):
|
||||
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
||||
|
||||
@dataclass
|
||||
class CommandObj:
|
||||
prefix: str = '/'
|
||||
command: str = ''
|
||||
mention: str = None
|
||||
args: str = None
|
||||
|
||||
@property
|
||||
def mentioned(self) -> bool:
|
||||
return bool(self.mention)
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
line = self.prefix + self.command
|
||||
if self.mentioned:
|
||||
line += '@' + self.mention
|
||||
if self.args:
|
||||
line += ' ' + self.args
|
||||
return line
|
||||
|
||||
|
||||
class CommandsFilter(BaseFilter):
|
||||
"""
|
||||
Check commands in message
|
||||
|
|
@ -15,22 +65,52 @@ class CommandsFilter(BaseFilter):
|
|||
|
||||
def __init__(self, dispatcher, commands):
|
||||
super().__init__(dispatcher)
|
||||
if isinstance(commands, str):
|
||||
commands = (commands,)
|
||||
self.commands = commands
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
return await Command.check_command(message, self.commands, '/')
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
class Text(Filter):
|
||||
def __init__(self,
|
||||
equals: Optional[str] = None,
|
||||
contains: Optional[str] = None,
|
||||
startswith: Optional[str] = None,
|
||||
endswith: Optional[str] = None,
|
||||
ignore_case=False):
|
||||
# Only one mode can be used. check it.
|
||||
check = sum(map(bool, (equals, contains, startswith, endswith)))
|
||||
if check > 1:
|
||||
args = "' and '".join([arg[0] for arg in [('equals', equals),
|
||||
('contains', contains),
|
||||
('startswith', startswith),
|
||||
('endswith', endswith)
|
||||
] if arg[1]])
|
||||
raise ValueError(f"Arguments '{args}' cannot be used together.")
|
||||
elif check == 0:
|
||||
raise ValueError(f"No one mode is specified!")
|
||||
|
||||
if command not in self.commands:
|
||||
return False
|
||||
self.equals = equals
|
||||
self.contains = contains
|
||||
self.endswith = endswith
|
||||
self.startswith = startswith
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
return True
|
||||
async def check(self, message: types.Message):
|
||||
text = message.text.lower() if self.ignore_case else message.text
|
||||
|
||||
if self.equals:
|
||||
return text == self.equals
|
||||
elif self.contains:
|
||||
return self.contains in text
|
||||
elif self.startswith:
|
||||
return text.startswith(self.startswith)
|
||||
elif self.endswith:
|
||||
return text.endswith(self.endswith)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RegexpFilter(BaseFilter):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,17 @@ class FilterNotPassed(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def wrap_async(func):
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isawaitable(func) \
|
||||
or inspect.iscoroutinefunction(func) \
|
||||
or isinstance(func, AbstractFilter):
|
||||
return func
|
||||
return async_wrapper
|
||||
|
||||
|
||||
async def check_filter(filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
|
@ -99,10 +110,6 @@ class AbstractFilter(abc.ABC):
|
|||
|
||||
key = None
|
||||
|
||||
def __init__(self, dispatcher, **config):
|
||||
self.dispatcher = dispatcher
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
|
|
@ -127,6 +134,15 @@ class AbstractFilter(abc.ABC):
|
|||
async def __call__(self, obj: TelegramObject) -> bool:
|
||||
return await self.check(obj)
|
||||
|
||||
def __invert__(self):
|
||||
return NotFilter(self)
|
||||
|
||||
def __and__(self, other):
|
||||
return AndFilter(self, other)
|
||||
|
||||
def __or__(self, other):
|
||||
return OrFilter(self, other)
|
||||
|
||||
|
||||
class BaseFilter(AbstractFilter):
|
||||
"""
|
||||
|
|
@ -136,6 +152,10 @@ class BaseFilter(AbstractFilter):
|
|||
required = False
|
||||
default = None
|
||||
|
||||
def __init__(self, dispatcher, **config):
|
||||
self.dispatcher = dispatcher
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||
if cls.key is not None:
|
||||
|
|
@ -143,3 +163,35 @@ class BaseFilter(AbstractFilter):
|
|||
return {cls.key: full_config[cls.key]}
|
||||
elif cls.required:
|
||||
return {cls.key: cls.default}
|
||||
|
||||
|
||||
class Filter(AbstractFilter):
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
raise RuntimeError('This filter can\'t be passed as kwargs')
|
||||
|
||||
|
||||
class NotFilter(Filter):
|
||||
def __init__(self, target):
|
||||
self.target = wrap_async(target)
|
||||
|
||||
async def check(self, *args):
|
||||
return await self.target(*args)
|
||||
|
||||
|
||||
class AndFilter(Filter):
|
||||
def __init__(self, target, target2):
|
||||
self.target = wrap_async(target)
|
||||
self.target2 = wrap_async(target2)
|
||||
|
||||
async def check(self, *args):
|
||||
return (await self.target(*args)) and (await self.target2(*args))
|
||||
|
||||
|
||||
class OrFilter(Filter):
|
||||
def __init__(self, target, target2):
|
||||
self.target = wrap_async(target)
|
||||
self.target2 = wrap_async(target2)
|
||||
|
||||
async def check(self, *args):
|
||||
return (await self.target(*args)) or (await self.target2(*args))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue