mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Improved filters.
This commit is contained in:
parent
cd4fee5eaa
commit
24184b1c8f
3 changed files with 155 additions and 20 deletions
|
|
@ -1,20 +1,23 @@
|
||||||
from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \
|
from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \
|
||||||
RegexpFilter, StateFilter
|
StateFilter, Text
|
||||||
from .factory import FiltersFactory
|
from .factory import FiltersFactory
|
||||||
from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AbstractFilter',
|
'AbstractFilter',
|
||||||
'BaseFilter',
|
'BaseFilter',
|
||||||
|
'Command',
|
||||||
'CommandsFilter',
|
'CommandsFilter',
|
||||||
'ContentTypeFilter',
|
'ContentTypeFilter',
|
||||||
'ExceptionsFilter',
|
'ExceptionsFilter',
|
||||||
|
'Filter',
|
||||||
|
'FilterNotPassed',
|
||||||
'FilterRecord',
|
'FilterRecord',
|
||||||
'FiltersFactory',
|
'FiltersFactory',
|
||||||
'RegexpCommandsFilter',
|
'RegexpCommandsFilter',
|
||||||
'RegexpFilter',
|
'RegexpFilter',
|
||||||
'StateFilter',
|
'StateFilter',
|
||||||
|
'Text',
|
||||||
'check_filter',
|
'check_filter',
|
||||||
'check_filters',
|
'check_filters'
|
||||||
'FilterNotPassed'
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,62 @@
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
from _contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from aiogram import types
|
from aiogram import types
|
||||||
from aiogram.dispatcher.filters.filters import BaseFilter
|
from aiogram.dispatcher.filters.filters import BaseFilter, Filter
|
||||||
from aiogram.types import CallbackQuery, ContentType, Message
|
from aiogram.types import CallbackQuery, ContentType, Message
|
||||||
|
|
||||||
|
|
||||||
|
class Command(Filter):
|
||||||
|
def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False):
|
||||||
|
if isinstance(commands, str):
|
||||||
|
commands = (commands,)
|
||||||
|
|
||||||
|
self.commands = list(map(str.lower, commands)) if ignore_case else commands
|
||||||
|
self.prefixes = prefixes
|
||||||
|
self.ignore_case = ignore_case
|
||||||
|
self.ignore_mention = ignore_mention
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
|
||||||
|
full_command = message.text.split()[0]
|
||||||
|
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
|
||||||
|
|
||||||
|
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
|
||||||
|
return False
|
||||||
|
elif prefix not in prefixes:
|
||||||
|
return False
|
||||||
|
elif (command.lower() if ignore_case else command) not in commands:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
|
||||||
|
|
||||||
|
async def check(self, message: types.Message):
|
||||||
|
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandObj:
|
||||||
|
prefix: str = '/'
|
||||||
|
command: str = ''
|
||||||
|
mention: str = None
|
||||||
|
args: str = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mentioned(self) -> bool:
|
||||||
|
return bool(self.mention)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
line = self.prefix + self.command
|
||||||
|
if self.mentioned:
|
||||||
|
line += '@' + self.mention
|
||||||
|
if self.args:
|
||||||
|
line += ' ' + self.args
|
||||||
|
return line
|
||||||
|
|
||||||
|
|
||||||
class CommandsFilter(BaseFilter):
|
class CommandsFilter(BaseFilter):
|
||||||
"""
|
"""
|
||||||
Check commands in message
|
Check commands in message
|
||||||
|
|
@ -15,23 +65,53 @@ class CommandsFilter(BaseFilter):
|
||||||
|
|
||||||
def __init__(self, dispatcher, commands):
|
def __init__(self, dispatcher, commands):
|
||||||
super().__init__(dispatcher)
|
super().__init__(dispatcher)
|
||||||
|
if isinstance(commands, str):
|
||||||
|
commands = (commands,)
|
||||||
self.commands = commands
|
self.commands = commands
|
||||||
|
|
||||||
async def check(self, message):
|
async def check(self, message):
|
||||||
if not message.is_command():
|
return await Command.check_command(message, self.commands, '/')
|
||||||
|
|
||||||
|
|
||||||
|
class Text(Filter):
|
||||||
|
def __init__(self,
|
||||||
|
equals: Optional[str] = None,
|
||||||
|
contains: Optional[str] = None,
|
||||||
|
startswith: Optional[str] = None,
|
||||||
|
endswith: Optional[str] = None,
|
||||||
|
ignore_case=False):
|
||||||
|
# Only one mode can be used. check it.
|
||||||
|
check = sum(map(bool, (equals, contains, startswith, endswith)))
|
||||||
|
if check > 1:
|
||||||
|
args = "' and '".join([arg[0] for arg in [('equals', equals),
|
||||||
|
('contains', contains),
|
||||||
|
('startswith', startswith),
|
||||||
|
('endswith', endswith)
|
||||||
|
] if arg[1]])
|
||||||
|
raise ValueError(f"Arguments '{args}' cannot be used together.")
|
||||||
|
elif check == 0:
|
||||||
|
raise ValueError(f"No one mode is specified!")
|
||||||
|
|
||||||
|
self.equals = equals
|
||||||
|
self.contains = contains
|
||||||
|
self.endswith = endswith
|
||||||
|
self.startswith = startswith
|
||||||
|
self.ignore_case = ignore_case
|
||||||
|
|
||||||
|
async def check(self, message: types.Message):
|
||||||
|
text = message.text.lower() if self.ignore_case else message.text
|
||||||
|
|
||||||
|
if self.equals:
|
||||||
|
return text == self.equals
|
||||||
|
elif self.contains:
|
||||||
|
return self.contains in text
|
||||||
|
elif self.startswith:
|
||||||
|
return text.startswith(self.startswith)
|
||||||
|
elif self.endswith:
|
||||||
|
return text.endswith(self.endswith)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
command = message.text.split()[0][1:]
|
|
||||||
command, _, mention = command.partition('@')
|
|
||||||
|
|
||||||
if mention and mention != (await message.bot.me).username:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if command not in self.commands:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class RegexpFilter(BaseFilter):
|
class RegexpFilter(BaseFilter):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,17 @@ class FilterNotPassed(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_async(func):
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
if inspect.isawaitable(func) \
|
||||||
|
or inspect.iscoroutinefunction(func) \
|
||||||
|
or isinstance(func, AbstractFilter):
|
||||||
|
return func
|
||||||
|
return async_wrapper
|
||||||
|
|
||||||
|
|
||||||
async def check_filter(filter_, args):
|
async def check_filter(filter_, args):
|
||||||
"""
|
"""
|
||||||
Helper for executing filter
|
Helper for executing filter
|
||||||
|
|
@ -99,10 +110,6 @@ class AbstractFilter(abc.ABC):
|
||||||
|
|
||||||
key = None
|
key = None
|
||||||
|
|
||||||
def __init__(self, dispatcher, **config):
|
|
||||||
self.dispatcher = dispatcher
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
|
|
@ -127,6 +134,15 @@ class AbstractFilter(abc.ABC):
|
||||||
async def __call__(self, obj: TelegramObject) -> bool:
|
async def __call__(self, obj: TelegramObject) -> bool:
|
||||||
return await self.check(obj)
|
return await self.check(obj)
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return NotFilter(self)
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
return AndFilter(self, other)
|
||||||
|
|
||||||
|
def __or__(self, other):
|
||||||
|
return OrFilter(self, other)
|
||||||
|
|
||||||
|
|
||||||
class BaseFilter(AbstractFilter):
|
class BaseFilter(AbstractFilter):
|
||||||
"""
|
"""
|
||||||
|
|
@ -136,6 +152,10 @@ class BaseFilter(AbstractFilter):
|
||||||
required = False
|
required = False
|
||||||
default = None
|
default = None
|
||||||
|
|
||||||
|
def __init__(self, dispatcher, **config):
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
self.config = config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||||
if cls.key is not None:
|
if cls.key is not None:
|
||||||
|
|
@ -143,3 +163,35 @@ class BaseFilter(AbstractFilter):
|
||||||
return {cls.key: full_config[cls.key]}
|
return {cls.key: full_config[cls.key]}
|
||||||
elif cls.required:
|
elif cls.required:
|
||||||
return {cls.key: cls.default}
|
return {cls.key: cls.default}
|
||||||
|
|
||||||
|
|
||||||
|
class Filter(AbstractFilter):
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
|
raise RuntimeError('This filter can\'t be passed as kwargs')
|
||||||
|
|
||||||
|
|
||||||
|
class NotFilter(Filter):
|
||||||
|
def __init__(self, target):
|
||||||
|
self.target = wrap_async(target)
|
||||||
|
|
||||||
|
async def check(self, *args):
|
||||||
|
return await self.target(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class AndFilter(Filter):
|
||||||
|
def __init__(self, target, target2):
|
||||||
|
self.target = wrap_async(target)
|
||||||
|
self.target2 = wrap_async(target2)
|
||||||
|
|
||||||
|
async def check(self, *args):
|
||||||
|
return (await self.target(*args)) and (await self.target2(*args))
|
||||||
|
|
||||||
|
|
||||||
|
class OrFilter(Filter):
|
||||||
|
def __init__(self, target, target2):
|
||||||
|
self.target = wrap_async(target)
|
||||||
|
self.target2 = wrap_async(target2)
|
||||||
|
|
||||||
|
async def check(self, *args):
|
||||||
|
return (await self.target(*args)) or (await self.target2(*args))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue