mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Refactor filters.
This commit is contained in:
parent
539c76a062
commit
f957883082
5 changed files with 191 additions and 80 deletions
|
|
@ -6,8 +6,9 @@ import time
|
||||||
import typing
|
import typing
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
|
||||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \
|
from aiogram.dispatcher.filters import Command
|
||||||
RegexpFilter, StateFilter
|
from .filters import ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \
|
||||||
|
Regexp, StateFilter
|
||||||
from .handler import Handler
|
from .handler import Handler
|
||||||
from .middlewares import MiddlewareManager
|
from .middlewares import MiddlewareManager
|
||||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
||||||
|
|
@ -84,10 +85,10 @@ class Dispatcher:
|
||||||
self.message_handlers, self.edited_message_handlers,
|
self.message_handlers, self.edited_message_handlers,
|
||||||
self.channel_post_handlers, self.edited_channel_post_handlers,
|
self.channel_post_handlers, self.edited_channel_post_handlers,
|
||||||
])
|
])
|
||||||
filters_factory.bind(CommandsFilter, event_handlers=[
|
filters_factory.bind(Command, event_handlers=[
|
||||||
self.message_handlers, self.edited_message_handlers
|
self.message_handlers, self.edited_message_handlers
|
||||||
])
|
])
|
||||||
filters_factory.bind(RegexpFilter, event_handlers=[
|
filters_factory.bind(Regexp, event_handlers=[
|
||||||
self.message_handlers, self.edited_message_handlers,
|
self.message_handlers, self.edited_message_handlers,
|
||||||
self.channel_post_handlers, self.edited_channel_post_handlers,
|
self.channel_post_handlers, self.edited_channel_post_handlers,
|
||||||
self.callback_query_handlers
|
self.callback_query_handlers
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \
|
from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
|
||||||
StateFilter, Text
|
RegexpCommandsFilter, StateFilter, Text
|
||||||
from .factory import FiltersFactory
|
from .factory import FiltersFactory
|
||||||
from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AbstractFilter',
|
'AbstractFilter',
|
||||||
'BaseFilter',
|
'BoundFilter',
|
||||||
'Command',
|
'Command',
|
||||||
'CommandsFilter',
|
'CommandStart',
|
||||||
|
'CommandHelp',
|
||||||
'ContentTypeFilter',
|
'ContentTypeFilter',
|
||||||
'ExceptionsFilter',
|
'ExceptionsFilter',
|
||||||
'Filter',
|
'Filter',
|
||||||
|
|
@ -15,7 +16,7 @@ __all__ = [
|
||||||
'FilterRecord',
|
'FilterRecord',
|
||||||
'FiltersFactory',
|
'FiltersFactory',
|
||||||
'RegexpCommandsFilter',
|
'RegexpCommandsFilter',
|
||||||
'RegexpFilter',
|
'Regexp',
|
||||||
'StateFilter',
|
'StateFilter',
|
||||||
'Text',
|
'Text',
|
||||||
'check_filter',
|
'check_filter',
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,30 @@ import inspect
|
||||||
import re
|
import re
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Any, Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
from aiogram import types
|
from aiogram import types
|
||||||
from aiogram.dispatcher.filters.filters import BaseFilter, Filter
|
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
|
||||||
from aiogram.types import CallbackQuery, ContentType, Message
|
from aiogram.types import CallbackQuery, ContentType, Message
|
||||||
|
|
||||||
|
|
||||||
class Command(Filter):
|
class Command(Filter):
|
||||||
def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False):
|
"""
|
||||||
|
You can handle commands by using this filter
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, commands: Union[Iterable, str],
|
||||||
|
prefixes: Union[Iterable, str] = '/',
|
||||||
|
ignore_case: bool = True,
|
||||||
|
ignore_mention: bool = False):
|
||||||
|
"""
|
||||||
|
Filter can be initialized from filters factory or by simply creating instance of this class
|
||||||
|
|
||||||
|
:param commands: command or list of commands
|
||||||
|
:param prefixes:
|
||||||
|
:param ignore_case:
|
||||||
|
:param ignore_mention:
|
||||||
|
"""
|
||||||
if isinstance(commands, str):
|
if isinstance(commands, str):
|
||||||
commands = (commands,)
|
commands = (commands,)
|
||||||
|
|
||||||
|
|
@ -19,6 +34,26 @@ class Command(Filter):
|
||||||
self.ignore_case = ignore_case
|
self.ignore_case = ignore_case
|
||||||
self.ignore_mention = ignore_mention
|
self.ignore_mention = ignore_mention
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Validator for filters factory
|
||||||
|
|
||||||
|
:param full_config:
|
||||||
|
:return: config or empty dict
|
||||||
|
"""
|
||||||
|
config = {}
|
||||||
|
if 'commands' in full_config:
|
||||||
|
config['commands'] = full_config.pop('commands')
|
||||||
|
if 'commands_prefix' in full_config:
|
||||||
|
config['prefixes'] = full_config.pop('commands_prefix')
|
||||||
|
if 'commands_ignore_mention' in full_config:
|
||||||
|
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def check(self, message: types.Message):
|
||||||
|
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
|
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
|
||||||
full_command = message.text.split()[0]
|
full_command = message.text.split()[0]
|
||||||
|
|
@ -33,9 +68,6 @@ class Command(Filter):
|
||||||
|
|
||||||
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
|
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
|
||||||
|
|
||||||
async def check(self, message: types.Message):
|
|
||||||
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommandObj:
|
class CommandObj:
|
||||||
prefix: str = '/'
|
prefix: str = '/'
|
||||||
|
|
@ -57,29 +89,36 @@ class Command(Filter):
|
||||||
return line
|
return line
|
||||||
|
|
||||||
|
|
||||||
class CommandsFilter(BaseFilter):
|
class CommandStart(Command):
|
||||||
"""
|
def __init__(self):
|
||||||
Check commands in message
|
super(CommandStart, self).__init__(['start'])
|
||||||
"""
|
|
||||||
key = 'commands'
|
|
||||||
|
|
||||||
def __init__(self, dispatcher, commands):
|
|
||||||
super().__init__(dispatcher)
|
|
||||||
if isinstance(commands, str):
|
|
||||||
commands = (commands,)
|
|
||||||
self.commands = commands
|
|
||||||
|
|
||||||
async def check(self, message):
|
class CommandHelp(Command):
|
||||||
return await Command.check_command(message, self.commands, '/')
|
def __init__(self):
|
||||||
|
super(CommandHelp, self).__init__(['help'])
|
||||||
|
|
||||||
|
|
||||||
class Text(Filter):
|
class Text(Filter):
|
||||||
|
"""
|
||||||
|
Simple text filter
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
equals: Optional[str] = None,
|
equals: Optional[str] = None,
|
||||||
contains: Optional[str] = None,
|
contains: Optional[str] = None,
|
||||||
startswith: Optional[str] = None,
|
startswith: Optional[str] = None,
|
||||||
endswith: Optional[str] = None,
|
endswith: Optional[str] = None,
|
||||||
ignore_case=False):
|
ignore_case=False):
|
||||||
|
"""
|
||||||
|
Check text for one of pattern. Only one mode can be used in one filter.
|
||||||
|
|
||||||
|
:param equals:
|
||||||
|
:param contains:
|
||||||
|
:param startswith:
|
||||||
|
:param endswith:
|
||||||
|
:param ignore_case: case insensitive
|
||||||
|
"""
|
||||||
# Only one mode can be used. check it.
|
# Only one mode can be used. check it.
|
||||||
check = sum(map(bool, (equals, contains, startswith, endswith)))
|
check = sum(map(bool, (equals, contains, startswith, endswith)))
|
||||||
if check > 1:
|
if check > 1:
|
||||||
|
|
@ -98,8 +137,27 @@ class Text(Filter):
|
||||||
self.startswith = startswith
|
self.startswith = startswith
|
||||||
self.ignore_case = ignore_case
|
self.ignore_case = ignore_case
|
||||||
|
|
||||||
async def check(self, message: types.Message):
|
@classmethod
|
||||||
text = message.text.lower() if self.ignore_case else message.text
|
def validate(cls, full_config: Dict[str, Any]):
|
||||||
|
if 'text' in full_config:
|
||||||
|
return {'equals': full_config.pop('text')}
|
||||||
|
elif 'text_contains' in full_config:
|
||||||
|
return {'contains': full_config.pop('text_contains')}
|
||||||
|
elif 'text_startswith' in full_config:
|
||||||
|
return {'startswith': full_config.pop('text_startswith')}
|
||||||
|
elif 'text_endswith' in full_config:
|
||||||
|
return {'endswith': full_config.pop('text_endswith')}
|
||||||
|
|
||||||
|
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||||
|
if isinstance(obj, Message):
|
||||||
|
text = obj.text or obj.caption or ''
|
||||||
|
elif isinstance(obj, CallbackQuery):
|
||||||
|
text = obj.data
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.ignore_case:
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
if self.equals:
|
if self.equals:
|
||||||
return text == self.equals
|
return text == self.equals
|
||||||
|
|
@ -113,24 +171,24 @@ class Text(Filter):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class RegexpFilter(BaseFilter):
|
class Regexp(Filter):
|
||||||
"""
|
"""
|
||||||
Regexp filter for messages and callback query
|
Regexp filter for messages and callback query
|
||||||
"""
|
"""
|
||||||
key = 'regexp'
|
|
||||||
|
|
||||||
def __init__(self, dispatcher, regexp):
|
def __init__(self, regexp):
|
||||||
super().__init__(dispatcher)
|
if not isinstance(regexp, re.Pattern):
|
||||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||||
|
self.regexp = regexp
|
||||||
|
|
||||||
async def check(self, obj):
|
@classmethod
|
||||||
|
def validate(cls, full_config: Dict[str, Any]):
|
||||||
|
if 'regexp' in full_config:
|
||||||
|
return {'regexp': full_config.pop('regexp')}
|
||||||
|
|
||||||
|
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||||
if isinstance(obj, Message):
|
if isinstance(obj, Message):
|
||||||
if obj.text:
|
match = self.regexp.search(obj.text or obj.caption or '')
|
||||||
match = self.regexp.search(obj.text)
|
|
||||||
elif obj.caption:
|
|
||||||
match = self.regexp.search(obj.caption)
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||||
match = self.regexp.search(obj.data)
|
match = self.regexp.search(obj.data)
|
||||||
else:
|
else:
|
||||||
|
|
@ -141,15 +199,14 @@ class RegexpFilter(BaseFilter):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class RegexpCommandsFilter(BaseFilter):
|
class RegexpCommandsFilter(BoundFilter):
|
||||||
"""
|
"""
|
||||||
Check commands by regexp in message
|
Check commands by regexp in message
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = 'regexp_commands'
|
key = 'regexp_commands'
|
||||||
|
|
||||||
def __init__(self, dispatcher, regexp_commands):
|
def __init__(self, regexp_commands):
|
||||||
super().__init__(dispatcher)
|
|
||||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||||
|
|
||||||
async def check(self, message):
|
async def check(self, message):
|
||||||
|
|
@ -169,7 +226,7 @@ class RegexpCommandsFilter(BaseFilter):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeFilter(BaseFilter):
|
class ContentTypeFilter(BoundFilter):
|
||||||
"""
|
"""
|
||||||
Check message content type
|
Check message content type
|
||||||
"""
|
"""
|
||||||
|
|
@ -178,8 +235,7 @@ class ContentTypeFilter(BaseFilter):
|
||||||
required = True
|
required = True
|
||||||
default = types.ContentType.TEXT
|
default = types.ContentType.TEXT
|
||||||
|
|
||||||
def __init__(self, dispatcher, content_types):
|
def __init__(self, content_types):
|
||||||
super().__init__(dispatcher)
|
|
||||||
self.content_types = content_types
|
self.content_types = content_types
|
||||||
|
|
||||||
async def check(self, message):
|
async def check(self, message):
|
||||||
|
|
@ -187,7 +243,7 @@ class ContentTypeFilter(BaseFilter):
|
||||||
message.content_type in self.content_types
|
message.content_type in self.content_types
|
||||||
|
|
||||||
|
|
||||||
class StateFilter(BaseFilter):
|
class StateFilter(BoundFilter):
|
||||||
"""
|
"""
|
||||||
Check user state
|
Check user state
|
||||||
"""
|
"""
|
||||||
|
|
@ -199,7 +255,7 @@ class StateFilter(BaseFilter):
|
||||||
def __init__(self, dispatcher, state):
|
def __init__(self, dispatcher, state):
|
||||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||||
|
|
||||||
super().__init__(dispatcher)
|
self.dispatcher = dispatcher
|
||||||
states = []
|
states = []
|
||||||
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
||||||
state = [state, ]
|
state = [state, ]
|
||||||
|
|
@ -237,7 +293,7 @@ class StateFilter(BaseFilter):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ExceptionsFilter(BaseFilter):
|
class ExceptionsFilter(BoundFilter):
|
||||||
"""
|
"""
|
||||||
Filter for exceptions
|
Filter for exceptions
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -21,29 +21,35 @@ def wrap_async(func):
|
||||||
return async_wrapper
|
return async_wrapper
|
||||||
|
|
||||||
|
|
||||||
async def check_filter(filter_, args):
|
async def check_filter(dispatcher, filter_, args):
|
||||||
"""
|
"""
|
||||||
Helper for executing filter
|
Helper for executing filter
|
||||||
|
|
||||||
|
:param dispatcher:
|
||||||
:param filter_:
|
:param filter_:
|
||||||
:param args:
|
:param args:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
kwargs = {}
|
||||||
if not callable(filter_):
|
if not callable(filter_):
|
||||||
raise TypeError('Filter must be callable and/or awaitable!')
|
raise TypeError('Filter must be callable and/or awaitable!')
|
||||||
|
|
||||||
|
spec = inspect.getfullargspec(filter_)
|
||||||
|
if 'dispatcher' in spec:
|
||||||
|
kwargs['dispatcher'] = dispatcher
|
||||||
if inspect.isawaitable(filter_) \
|
if inspect.isawaitable(filter_) \
|
||||||
or inspect.iscoroutinefunction(filter_) \
|
or inspect.iscoroutinefunction(filter_) \
|
||||||
or isinstance(filter_, AbstractFilter):
|
or isinstance(filter_, AbstractFilter):
|
||||||
return await filter_(*args)
|
return await filter_(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return filter_(*args)
|
return filter_(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def check_filters(filters, args):
|
async def check_filters(dispatcher, filters, args):
|
||||||
"""
|
"""
|
||||||
Check list of filters
|
Check list of filters
|
||||||
|
|
||||||
|
:param dispatcher:
|
||||||
:param filters:
|
:param filters:
|
||||||
:param args:
|
:param args:
|
||||||
:return:
|
:return:
|
||||||
|
|
@ -51,7 +57,7 @@ async def check_filters(filters, args):
|
||||||
data = {}
|
data = {}
|
||||||
if filters is not None:
|
if filters is not None:
|
||||||
for filter_ in filters:
|
for filter_ in filters:
|
||||||
f = await check_filter(filter_, args)
|
f = await check_filter(dispatcher, filter_, args)
|
||||||
if not f:
|
if not f:
|
||||||
raise FilterNotPassed()
|
raise FilterNotPassed()
|
||||||
elif isinstance(f, dict):
|
elif isinstance(f, dict):
|
||||||
|
|
@ -89,11 +95,16 @@ class FilterRecord:
|
||||||
return
|
return
|
||||||
config = self.resolver(full_config)
|
config = self.resolver(full_config)
|
||||||
if config:
|
if config:
|
||||||
|
if 'dispatcher' not in config:
|
||||||
|
spec = inspect.getfullargspec(self.callback)
|
||||||
|
if 'dispatcher' in spec.args:
|
||||||
|
config['dispatcher'] = dispatcher
|
||||||
|
|
||||||
for key in config:
|
for key in config:
|
||||||
if key in full_config:
|
if key in full_config:
|
||||||
full_config.pop(key)
|
full_config.pop(key)
|
||||||
|
|
||||||
return self.callback(dispatcher, **config)
|
return self.callback(**config)
|
||||||
|
|
||||||
def _check_event_handler(self, event_handler) -> bool:
|
def _check_event_handler(self, event_handler) -> bool:
|
||||||
if self.event_handlers:
|
if self.event_handlers:
|
||||||
|
|
@ -108,8 +119,6 @@ class AbstractFilter(abc.ABC):
|
||||||
Abstract class for custom filters
|
Abstract class for custom filters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
|
|
@ -138,13 +147,29 @@ class AbstractFilter(abc.ABC):
|
||||||
return NotFilter(self)
|
return NotFilter(self)
|
||||||
|
|
||||||
def __and__(self, other):
|
def __and__(self, other):
|
||||||
|
if isinstance(self, AndFilter):
|
||||||
|
self.append(other)
|
||||||
|
return self
|
||||||
return AndFilter(self, other)
|
return AndFilter(self, other)
|
||||||
|
|
||||||
def __or__(self, other):
|
def __or__(self, other):
|
||||||
|
if isinstance(self, OrFilter):
|
||||||
|
self.append(other)
|
||||||
|
return self
|
||||||
return OrFilter(self, other)
|
return OrFilter(self, other)
|
||||||
|
|
||||||
|
|
||||||
class BaseFilter(AbstractFilter):
|
class Filter(AbstractFilter):
|
||||||
|
"""
|
||||||
|
You can make subclasses of that class for custom filters
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoundFilter(Filter):
|
||||||
"""
|
"""
|
||||||
Base class for filters with default validator
|
Base class for filters with default validator
|
||||||
"""
|
"""
|
||||||
|
|
@ -152,10 +177,6 @@ class BaseFilter(AbstractFilter):
|
||||||
required = False
|
required = False
|
||||||
default = None
|
default = None
|
||||||
|
|
||||||
def __init__(self, dispatcher, **config):
|
|
||||||
self.dispatcher = dispatcher
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||||
if cls.key is not None:
|
if cls.key is not None:
|
||||||
|
|
@ -165,33 +186,65 @@ class BaseFilter(AbstractFilter):
|
||||||
return {cls.key: cls.default}
|
return {cls.key: cls.default}
|
||||||
|
|
||||||
|
|
||||||
class Filter(AbstractFilter):
|
class _LogicFilter(Filter):
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
||||||
raise RuntimeError('This filter can\'t be passed as kwargs')
|
raise ValueError('That filter can\'t be used in filters factory!')
|
||||||
|
|
||||||
|
|
||||||
class NotFilter(Filter):
|
class NotFilter(_LogicFilter):
|
||||||
def __init__(self, target):
|
def __init__(self, target):
|
||||||
self.target = wrap_async(target)
|
self.target = wrap_async(target)
|
||||||
|
|
||||||
async def check(self, *args):
|
async def check(self, *args):
|
||||||
return await self.target(*args)
|
return not bool(await self.target(*args))
|
||||||
|
|
||||||
|
|
||||||
class AndFilter(Filter):
|
class AndFilter(_LogicFilter):
|
||||||
def __init__(self, target, target2):
|
|
||||||
self.target = wrap_async(target)
|
def __init__(self, *targets):
|
||||||
self.target2 = wrap_async(target2)
|
self.targets = list(wrap_async(target) for target in targets)
|
||||||
|
|
||||||
async def check(self, *args):
|
async def check(self, *args):
|
||||||
return (await self.target(*args)) and (await self.target2(*args))
|
"""
|
||||||
|
All filters must return a positive result
|
||||||
|
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
data = {}
|
||||||
|
for target in self.targets:
|
||||||
|
result = await target(*args)
|
||||||
|
if not result:
|
||||||
|
return False
|
||||||
|
if isinstance(result, dict):
|
||||||
|
data.update(result)
|
||||||
|
if not data:
|
||||||
|
return True
|
||||||
|
return data
|
||||||
|
|
||||||
|
def append(self, target):
|
||||||
|
self.targets.append(wrap_async(target))
|
||||||
|
|
||||||
|
|
||||||
class OrFilter(Filter):
|
class OrFilter(_LogicFilter):
|
||||||
def __init__(self, target, target2):
|
def __init__(self, *targets):
|
||||||
self.target = wrap_async(target)
|
self.targets = list(wrap_async(target) for target in targets)
|
||||||
self.target2 = wrap_async(target2)
|
|
||||||
|
|
||||||
async def check(self, *args):
|
async def check(self, *args):
|
||||||
return (await self.target(*args)) or (await self.target2(*args))
|
"""
|
||||||
|
One of filters must return a positive result
|
||||||
|
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for target in self.targets:
|
||||||
|
result = await target(*args)
|
||||||
|
if result:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def append(self, target):
|
||||||
|
self.targets.append(wrap_async(target))
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class Handler:
|
||||||
try:
|
try:
|
||||||
for filters, handler in self.handlers:
|
for filters, handler in self.handlers:
|
||||||
try:
|
try:
|
||||||
data.update(await check_filters(filters, args))
|
data.update(await check_filters(self.dispatcher, filters, args))
|
||||||
except FilterNotPassed:
|
except FilterNotPassed:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue