mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 02:03:04 +00:00
Refactor filters.
This commit is contained in:
parent
539c76a062
commit
f957883082
5 changed files with 191 additions and 80 deletions
|
|
@ -6,8 +6,9 @@ import time
|
|||
import typing
|
||||
from contextvars import ContextVar
|
||||
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \
|
||||
RegexpFilter, StateFilter
|
||||
from aiogram.dispatcher.filters import Command
|
||||
from .filters import ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \
|
||||
Regexp, StateFilter
|
||||
from .handler import Handler
|
||||
from .middlewares import MiddlewareManager
|
||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
||||
|
|
@ -84,10 +85,10 @@ class Dispatcher:
|
|||
self.message_handlers, self.edited_message_handlers,
|
||||
self.channel_post_handlers, self.edited_channel_post_handlers,
|
||||
])
|
||||
filters_factory.bind(CommandsFilter, event_handlers=[
|
||||
filters_factory.bind(Command, event_handlers=[
|
||||
self.message_handlers, self.edited_message_handlers
|
||||
])
|
||||
filters_factory.bind(RegexpFilter, event_handlers=[
|
||||
filters_factory.bind(Regexp, event_handlers=[
|
||||
self.message_handlers, self.edited_message_handlers,
|
||||
self.channel_post_handlers, self.edited_channel_post_handlers,
|
||||
self.callback_query_handlers
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \
|
||||
StateFilter, Text
|
||||
from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
|
||||
RegexpCommandsFilter, StateFilter, Text
|
||||
from .factory import FiltersFactory
|
||||
from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||
|
||||
__all__ = [
|
||||
'AbstractFilter',
|
||||
'BaseFilter',
|
||||
'BoundFilter',
|
||||
'Command',
|
||||
'CommandsFilter',
|
||||
'CommandStart',
|
||||
'CommandHelp',
|
||||
'ContentTypeFilter',
|
||||
'ExceptionsFilter',
|
||||
'Filter',
|
||||
|
|
@ -15,7 +16,7 @@ __all__ = [
|
|||
'FilterRecord',
|
||||
'FiltersFactory',
|
||||
'RegexpCommandsFilter',
|
||||
'RegexpFilter',
|
||||
'Regexp',
|
||||
'StateFilter',
|
||||
'Text',
|
||||
'check_filter',
|
||||
|
|
|
|||
|
|
@ -2,15 +2,30 @@ import inspect
|
|||
import re
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Iterable, Optional, Union
|
||||
|
||||
from aiogram import types
|
||||
from aiogram.dispatcher.filters.filters import BaseFilter, Filter
|
||||
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
|
||||
from aiogram.types import CallbackQuery, ContentType, Message
|
||||
|
||||
|
||||
class Command(Filter):
|
||||
def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False):
|
||||
"""
|
||||
You can handle commands by using this filter
|
||||
"""
|
||||
|
||||
def __init__(self, commands: Union[Iterable, str],
|
||||
prefixes: Union[Iterable, str] = '/',
|
||||
ignore_case: bool = True,
|
||||
ignore_mention: bool = False):
|
||||
"""
|
||||
Filter can be initialized from filters factory or by simply creating instance of this class
|
||||
|
||||
:param commands: command or list of commands
|
||||
:param prefixes:
|
||||
:param ignore_case:
|
||||
:param ignore_mention:
|
||||
"""
|
||||
if isinstance(commands, str):
|
||||
commands = (commands,)
|
||||
|
||||
|
|
@ -19,6 +34,26 @@ class Command(Filter):
|
|||
self.ignore_case = ignore_case
|
||||
self.ignore_mention = ignore_mention
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Validator for filters factory
|
||||
|
||||
:param full_config:
|
||||
:return: config or empty dict
|
||||
"""
|
||||
config = {}
|
||||
if 'commands' in full_config:
|
||||
config['commands'] = full_config.pop('commands')
|
||||
if 'commands_prefix' in full_config:
|
||||
config['prefixes'] = full_config.pop('commands_prefix')
|
||||
if 'commands_ignore_mention' in full_config:
|
||||
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
|
||||
return config
|
||||
|
||||
async def check(self, message: types.Message):
|
||||
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
||||
|
||||
@staticmethod
|
||||
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
|
||||
full_command = message.text.split()[0]
|
||||
|
|
@ -33,9 +68,6 @@ class Command(Filter):
|
|||
|
||||
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
|
||||
|
||||
async def check(self, message: types.Message):
|
||||
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
||||
|
||||
@dataclass
|
||||
class CommandObj:
|
||||
prefix: str = '/'
|
||||
|
|
@ -57,29 +89,36 @@ class Command(Filter):
|
|||
return line
|
||||
|
||||
|
||||
class CommandsFilter(BaseFilter):
|
||||
"""
|
||||
Check commands in message
|
||||
"""
|
||||
key = 'commands'
|
||||
class CommandStart(Command):
|
||||
def __init__(self):
|
||||
super(CommandStart, self).__init__(['start'])
|
||||
|
||||
def __init__(self, dispatcher, commands):
|
||||
super().__init__(dispatcher)
|
||||
if isinstance(commands, str):
|
||||
commands = (commands,)
|
||||
self.commands = commands
|
||||
|
||||
async def check(self, message):
|
||||
return await Command.check_command(message, self.commands, '/')
|
||||
class CommandHelp(Command):
|
||||
def __init__(self):
|
||||
super(CommandHelp, self).__init__(['help'])
|
||||
|
||||
|
||||
class Text(Filter):
|
||||
"""
|
||||
Simple text filter
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
equals: Optional[str] = None,
|
||||
contains: Optional[str] = None,
|
||||
startswith: Optional[str] = None,
|
||||
endswith: Optional[str] = None,
|
||||
ignore_case=False):
|
||||
"""
|
||||
Check text for one of pattern. Only one mode can be used in one filter.
|
||||
|
||||
:param equals:
|
||||
:param contains:
|
||||
:param startswith:
|
||||
:param endswith:
|
||||
:param ignore_case: case insensitive
|
||||
"""
|
||||
# Only one mode can be used. check it.
|
||||
check = sum(map(bool, (equals, contains, startswith, endswith)))
|
||||
if check > 1:
|
||||
|
|
@ -98,8 +137,27 @@ class Text(Filter):
|
|||
self.startswith = startswith
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
async def check(self, message: types.Message):
|
||||
text = message.text.lower() if self.ignore_case else message.text
|
||||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]):
|
||||
if 'text' in full_config:
|
||||
return {'equals': full_config.pop('text')}
|
||||
elif 'text_contains' in full_config:
|
||||
return {'contains': full_config.pop('text_contains')}
|
||||
elif 'text_startswith' in full_config:
|
||||
return {'startswith': full_config.pop('text_startswith')}
|
||||
elif 'text_endswith' in full_config:
|
||||
return {'endswith': full_config.pop('text_endswith')}
|
||||
|
||||
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||
if isinstance(obj, Message):
|
||||
text = obj.text or obj.caption or ''
|
||||
elif isinstance(obj, CallbackQuery):
|
||||
text = obj.data
|
||||
else:
|
||||
return False
|
||||
|
||||
if self.ignore_case:
|
||||
text = text.lower()
|
||||
|
||||
if self.equals:
|
||||
return text == self.equals
|
||||
|
|
@ -113,24 +171,24 @@ class Text(Filter):
|
|||
return False
|
||||
|
||||
|
||||
class RegexpFilter(BaseFilter):
|
||||
class Regexp(Filter):
|
||||
"""
|
||||
Regexp filter for messages and callback query
|
||||
"""
|
||||
key = 'regexp'
|
||||
|
||||
def __init__(self, dispatcher, regexp):
|
||||
super().__init__(dispatcher)
|
||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
def __init__(self, regexp):
|
||||
if not isinstance(regexp, re.Pattern):
|
||||
regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
self.regexp = regexp
|
||||
|
||||
async def check(self, obj):
|
||||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]):
|
||||
if 'regexp' in full_config:
|
||||
return {'regexp': full_config.pop('regexp')}
|
||||
|
||||
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||
if isinstance(obj, Message):
|
||||
if obj.text:
|
||||
match = self.regexp.search(obj.text)
|
||||
elif obj.caption:
|
||||
match = self.regexp.search(obj.caption)
|
||||
else:
|
||||
return False
|
||||
match = self.regexp.search(obj.text or obj.caption or '')
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
match = self.regexp.search(obj.data)
|
||||
else:
|
||||
|
|
@ -141,15 +199,14 @@ class RegexpFilter(BaseFilter):
|
|||
return False
|
||||
|
||||
|
||||
class RegexpCommandsFilter(BaseFilter):
|
||||
class RegexpCommandsFilter(BoundFilter):
|
||||
"""
|
||||
Check commands by regexp in message
|
||||
"""
|
||||
|
||||
key = 'regexp_commands'
|
||||
|
||||
def __init__(self, dispatcher, regexp_commands):
|
||||
super().__init__(dispatcher)
|
||||
def __init__(self, regexp_commands):
|
||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||
|
||||
async def check(self, message):
|
||||
|
|
@ -169,7 +226,7 @@ class RegexpCommandsFilter(BaseFilter):
|
|||
return False
|
||||
|
||||
|
||||
class ContentTypeFilter(BaseFilter):
|
||||
class ContentTypeFilter(BoundFilter):
|
||||
"""
|
||||
Check message content type
|
||||
"""
|
||||
|
|
@ -178,8 +235,7 @@ class ContentTypeFilter(BaseFilter):
|
|||
required = True
|
||||
default = types.ContentType.TEXT
|
||||
|
||||
def __init__(self, dispatcher, content_types):
|
||||
super().__init__(dispatcher)
|
||||
def __init__(self, content_types):
|
||||
self.content_types = content_types
|
||||
|
||||
async def check(self, message):
|
||||
|
|
@ -187,7 +243,7 @@ class ContentTypeFilter(BaseFilter):
|
|||
message.content_type in self.content_types
|
||||
|
||||
|
||||
class StateFilter(BaseFilter):
|
||||
class StateFilter(BoundFilter):
|
||||
"""
|
||||
Check user state
|
||||
"""
|
||||
|
|
@ -199,7 +255,7 @@ class StateFilter(BaseFilter):
|
|||
def __init__(self, dispatcher, state):
|
||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||
|
||||
super().__init__(dispatcher)
|
||||
self.dispatcher = dispatcher
|
||||
states = []
|
||||
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
||||
state = [state, ]
|
||||
|
|
@ -237,7 +293,7 @@ class StateFilter(BaseFilter):
|
|||
return False
|
||||
|
||||
|
||||
class ExceptionsFilter(BaseFilter):
|
||||
class ExceptionsFilter(BoundFilter):
|
||||
"""
|
||||
Filter for exceptions
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -21,29 +21,35 @@ def wrap_async(func):
|
|||
return async_wrapper
|
||||
|
||||
|
||||
async def check_filter(filter_, args):
|
||||
async def check_filter(dispatcher, filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param dispatcher:
|
||||
:param filter_:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
kwargs = {}
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
||||
spec = inspect.getfullargspec(filter_)
|
||||
if 'dispatcher' in spec:
|
||||
kwargs['dispatcher'] = dispatcher
|
||||
if inspect.isawaitable(filter_) \
|
||||
or inspect.iscoroutinefunction(filter_) \
|
||||
or isinstance(filter_, AbstractFilter):
|
||||
return await filter_(*args)
|
||||
return await filter_(*args, **kwargs)
|
||||
else:
|
||||
return filter_(*args)
|
||||
return filter_(*args, **kwargs)
|
||||
|
||||
|
||||
async def check_filters(filters, args):
|
||||
async def check_filters(dispatcher, filters, args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param dispatcher:
|
||||
:param filters:
|
||||
:param args:
|
||||
:return:
|
||||
|
|
@ -51,7 +57,7 @@ async def check_filters(filters, args):
|
|||
data = {}
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(filter_, args)
|
||||
f = await check_filter(dispatcher, filter_, args)
|
||||
if not f:
|
||||
raise FilterNotPassed()
|
||||
elif isinstance(f, dict):
|
||||
|
|
@ -89,11 +95,16 @@ class FilterRecord:
|
|||
return
|
||||
config = self.resolver(full_config)
|
||||
if config:
|
||||
if 'dispatcher' not in config:
|
||||
spec = inspect.getfullargspec(self.callback)
|
||||
if 'dispatcher' in spec.args:
|
||||
config['dispatcher'] = dispatcher
|
||||
|
||||
for key in config:
|
||||
if key in full_config:
|
||||
full_config.pop(key)
|
||||
|
||||
return self.callback(dispatcher, **config)
|
||||
return self.callback(**config)
|
||||
|
||||
def _check_event_handler(self, event_handler) -> bool:
|
||||
if self.event_handlers:
|
||||
|
|
@ -108,8 +119,6 @@ class AbstractFilter(abc.ABC):
|
|||
Abstract class for custom filters
|
||||
"""
|
||||
|
||||
key = None
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
|
|
@ -138,13 +147,29 @@ class AbstractFilter(abc.ABC):
|
|||
return NotFilter(self)
|
||||
|
||||
def __and__(self, other):
|
||||
if isinstance(self, AndFilter):
|
||||
self.append(other)
|
||||
return self
|
||||
return AndFilter(self, other)
|
||||
|
||||
def __or__(self, other):
|
||||
if isinstance(self, OrFilter):
|
||||
self.append(other)
|
||||
return self
|
||||
return OrFilter(self, other)
|
||||
|
||||
|
||||
class BaseFilter(AbstractFilter):
|
||||
class Filter(AbstractFilter):
|
||||
"""
|
||||
You can make subclasses of that class for custom filters
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
pass
|
||||
|
||||
|
||||
class BoundFilter(Filter):
|
||||
"""
|
||||
Base class for filters with default validator
|
||||
"""
|
||||
|
|
@ -152,10 +177,6 @@ class BaseFilter(AbstractFilter):
|
|||
required = False
|
||||
default = None
|
||||
|
||||
def __init__(self, dispatcher, **config):
|
||||
self.dispatcher = dispatcher
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||
if cls.key is not None:
|
||||
|
|
@ -165,33 +186,65 @@ class BaseFilter(AbstractFilter):
|
|||
return {cls.key: cls.default}
|
||||
|
||||
|
||||
class Filter(AbstractFilter):
|
||||
class _LogicFilter(Filter):
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
raise RuntimeError('This filter can\'t be passed as kwargs')
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
||||
raise ValueError('That filter can\'t be used in filters factory!')
|
||||
|
||||
|
||||
class NotFilter(Filter):
|
||||
class NotFilter(_LogicFilter):
|
||||
def __init__(self, target):
|
||||
self.target = wrap_async(target)
|
||||
|
||||
async def check(self, *args):
|
||||
return await self.target(*args)
|
||||
return not bool(await self.target(*args))
|
||||
|
||||
|
||||
class AndFilter(Filter):
|
||||
def __init__(self, target, target2):
|
||||
self.target = wrap_async(target)
|
||||
self.target2 = wrap_async(target2)
|
||||
class AndFilter(_LogicFilter):
|
||||
|
||||
def __init__(self, *targets):
|
||||
self.targets = list(wrap_async(target) for target in targets)
|
||||
|
||||
async def check(self, *args):
|
||||
return (await self.target(*args)) and (await self.target2(*args))
|
||||
"""
|
||||
All filters must return a positive result
|
||||
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
data = {}
|
||||
for target in self.targets:
|
||||
result = await target(*args)
|
||||
if not result:
|
||||
return False
|
||||
if isinstance(result, dict):
|
||||
data.update(result)
|
||||
if not data:
|
||||
return True
|
||||
return data
|
||||
|
||||
def append(self, target):
|
||||
self.targets.append(wrap_async(target))
|
||||
|
||||
|
||||
class OrFilter(Filter):
|
||||
def __init__(self, target, target2):
|
||||
self.target = wrap_async(target)
|
||||
self.target2 = wrap_async(target2)
|
||||
class OrFilter(_LogicFilter):
|
||||
def __init__(self, *targets):
|
||||
self.targets = list(wrap_async(target) for target in targets)
|
||||
|
||||
async def check(self, *args):
|
||||
return (await self.target(*args)) or (await self.target2(*args))
|
||||
"""
|
||||
One of filters must return a positive result
|
||||
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
for target in self.targets:
|
||||
result = await target(*args)
|
||||
if result:
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return True
|
||||
return False
|
||||
|
||||
def append(self, target):
|
||||
self.targets.append(wrap_async(target))
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class Handler:
|
|||
try:
|
||||
for filters, handler in self.handlers:
|
||||
try:
|
||||
data.update(await check_filters(filters, args))
|
||||
data.update(await check_filters(self.dispatcher, filters, args))
|
||||
except FilterNotPassed:
|
||||
continue
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue