mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-13 18:43:53 +00:00
417 lines
10 KiB
Python
417 lines
10 KiB
Python
import asyncio
|
|
import copy
|
|
import inspect
|
|
import re
|
|
import typing
|
|
|
|
from ..types import ContentType
|
|
from ..utils import context
|
|
from ..utils.helper import Helper, HelperMode, Item
|
|
|
|
# Context key for a user state that StateFilter.check reads (via the context
# helpers) before falling back to the dispatcher's storage.
USER_STATE = 'USER_STATE'
|
|
|
|
|
|
class FiltersFactory:
    """
    Registry of filter factories bound to a dispatcher.

    Every registration is stored as a ``FilterConfig`` object in ``self._filters``.
    """

    def __init__(self, dispatcher):
        self._dispatcher = dispatcher
        self._filters = []

    @property
    def _default_filters(self):
        """Tuple of registered configs whose ``default`` flag is set."""
        # Registrations are FilterConfig objects, not sequences: read the attribute
        # (indexing with ``item[-1]`` raised TypeError).
        return tuple(item for item in self._filters if item.default)

    def bind(self, filter_, *args, default=False, with_dispatcher=False):
        """
        Register a filter factory.

        :param filter_: callable (usually a Filter subclass) used to build the filter
        :param args: config keys forwarded to ``FilterConfig`` as its ``args``
        :param default: stored on the FilterConfig as its ``default`` flag
        :param with_dispatcher: inject the dispatcher into the factory's kwargs
        """
        # ``args`` used to be accepted and silently dropped.
        self._filters.append(FilterConfig(self._dispatcher, filter_, default=default,
                                          with_dispatcher=with_dispatcher, args=args))

    def unbind(self, filter_):
        """
        Remove the first registration whose factory is ``filter_`` (identity match).

        :return: True
        :raises ValueError: if ``filter_`` is not registered
        """
        for item in self._filters:
            if item.filter is filter_:
                # Safe to mutate while iterating: we return immediately.
                self._filters.remove(item)
                return True
        raise ValueError(f'{filter_} is not binded.')

    def replace(self, original, new):
        """
        Swap the factory of the first registration whose factory is ``original``.

        :return: True
        :raises ValueError: if ``original`` is not registered
        """
        for item in self._filters:
            if item.filter is original:
                item.filter = new
                return True
        raise ValueError(f'{original} is not binded.')

    def parse(self, *args, **kwargs):
        """
        Generate filters list

        :param args: filters appended as-is
        :param kwargs: Filter instances keyed by arbitrary names
        :return: list of filters (positional ones first)
        :raises ValueError: if a keyword value is not a Filter instance
        """
        filters = list(args)

        # TODO: registered filters (self._filters) are not consulted yet.

        for key, filter_ in kwargs.items():
            if not isinstance(filter_, Filter):
                raise ValueError(f"Unknown filter with key '{key}'")
            filters.append(filter_)

        return filters
|
|
|
|
|
|
class FilterConfig:
    """
    Registration record for a filter factory.

    :param dispatcher: passed to the factory as ``dispatcher=`` when ``with_dispatcher`` is set
    :param filter_: callable that builds the filter from keyword arguments
    :param default: when True, :meth:`check` accepts even if required keys are missing
    :param with_dispatcher: inject ``dispatcher`` into the kwargs in :meth:`configure`
    :param args: either a sequence of required config keys, or a mapping
        ``key -> type`` (values for ``bool`` keys are coerced with ``bool()``)
    """

    def __init__(self, dispatcher, filter_: typing.Callable,
                 default: bool = False, with_dispatcher: bool = False,
                 args: typing.Union[tuple, set, list, dict] = ()):
        self.dispatcher = dispatcher
        self.filter = filter_
        self.default = default
        self.with_dispatcher = with_dispatcher
        self.args = args

    def _check_list(self, config):
        """Pop every key in ``self.args`` from ``config``; reject on the first missing one."""
        result = {}
        accept = True

        for item in self.args:
            value = config.pop(item, None)
            if value is None:
                accept = False
                break
            result[item] = value

        return accept or self.default, result

    def _check_dict(self, config):
        """Like :meth:`_check_list`, but ``self.args`` maps each key to a type."""
        result = {}
        accept = True

        for key, type_ in self.args.items():
            value = config.pop(key, None)
            if value is None:
                accept = False
                break
            if type_ is bool:
                # NOTE(review): this branch had no body (syntax error); coercing to
                # bool is the assumed intent — confirm.
                value = bool(value)
            result[key] = value

        return accept or self.default, result

    def check(self, config):
        """
        Extract this filter's keys from ``config`` (a dict; matched keys are popped).

        :return: ``(accepted, extracted_kwargs)``
        """
        # Both helpers call ``config.pop(key, None)``, so ``config`` must be a dict
        # either way; the shape of ``self.args`` selects the strategy.
        if isinstance(self.args, dict):
            return self._check_dict(config)
        return self._check_list(config)

    def parse(self, config):
        """Not implemented yet; always returns None."""
        return None

    def configure(self, config=None):
        """
        Build the filter by calling ``self.filter(**config)``.

        The caller's dict is never mutated.
        """
        if config is None:
            config = {}
        if self.with_dispatcher:
            # Copy unconditionally: the old ``if config:`` skipped the copy for an
            # empty dict and then wrote 'dispatcher' into the caller's object.
            config = copy.deepcopy(config)
            config['dispatcher'] = self.dispatcher
        return self.filter(**config)
|
|
|
|
|
|
async def check_filter(filter_, args):
    """
    Helper for executing filter

    Calls ``filter_(*args)`` and awaits the result when it is awaitable. This
    supports plain functions, coroutine functions, AsyncFilter instances and
    objects with an ``async def __call__``.

    :param filter_: callable filter
    :param args: positional arguments passed to the filter
    :return: the filter's (awaited) result
    :raises TypeError: if ``filter_`` is not callable
    """
    if not callable(filter_):
        raise TypeError('Filter must be callable and/or awaitable!')

    # Inspect the *result*, not the callable. A callable that returns a coroutine
    # without being a coroutine function used to leak an un-awaited (truthy)
    # coroutine.
    result = filter_(*args)
    if inspect.isawaitable(result):
        return await result
    return result
|
|
|
|
|
|
async def check_filters(filters, args):
    """
    Check list of filters

    Runs the filters in order and stops at the first falsy result.

    :param filters: iterable of filters, or None (meaning "no filters")
    :param args: positional arguments forwarded to every filter
    :return: True if every filter passed, otherwise False
    """
    if filters is None:
        return True

    for current in filters:
        passed = await check_filter(current, args)
        if not passed:
            return False

    return True
|
|
|
|
|
|
class Filter:
    """
    Base class for filters.

    Subclasses implement :meth:`check`; calling an instance delegates to it.
    """

    def __init__(self, *args, **kwargs):
        # Keep the raw construction arguments for subclasses / introspection.
        self._args, self._kwargs = args, kwargs

    def __call__(self, *args, **kwargs):
        check = self.check
        return check(*args, **kwargs)

    def check(self, *args, **kwargs):
        """Decide whether the filter passes; subclasses must override this."""
        raise NotImplementedError
|
|
|
|
|
|
class AsyncFilter(Filter):
    """
    Base class for asynchronous filters.

    Subclasses override the coroutine :meth:`check`; calling an instance returns
    that coroutine.
    """

    def __aiter__(self):
        # NOTE(review): returns None, so ``async for`` over a filter does not work;
        # kept unchanged for compatibility.
        return None

    def __await__(self):
        # Marker: defining __await__ makes ``inspect.isawaitable(instance)`` true, so
        # a caller testing the filter itself with isawaitable awaits the call result.
        # Awaiting the instance directly is not supported (a bound method is not an
        # iterator).
        return self.check

    async def check(self, *args, **kwargs):
        """Decide whether the filter passes; subclasses must override this."""
        # Consistent with Filter.check. The old ``pass`` made an unimplemented
        # filter silently reject everything.
        raise NotImplementedError
|
|
|
|
|
|
class AnyFilter(AsyncFilter):
    """
    One filter from many

    Passes when at least one wrapped filter passes. All wrapped filters are run
    through ``asyncio.gather``. With no filters it never passes.
    """

    def __init__(self, *filters: typing.Callable):
        # ``typing.Callable`` rather than the builtin ``callable`` function, which is
        # not a type.
        self.filters = filters
        super().__init__()

    async def check(self, *args):
        results = await asyncio.gather(*(check_filter(filter_, args) for filter_ in self.filters))
        return any(results)
|
|
|
|
|
|
class NotFilter(AsyncFilter):
    """
    Reverse filter

    Passes exactly when the wrapped filter does not.
    """

    def __init__(self, filter_: typing.Callable):
        # ``typing.Callable`` rather than the builtin ``callable`` function, which is
        # not a type.
        self.filter = filter_
        super().__init__()

    async def check(self, *args):
        return not await check_filter(self.filter, args)
|
|
|
|
|
|
class CommandsFilter(AsyncFilter):
    """
    Check commands in message

    Passes when the first word of ``message.text`` is ``/<command>`` or
    ``/<command>@<bot username>`` and ``<command>`` is in :attr:`commands`.
    """

    def __init__(self, commands):
        # A bare string would turn the membership test into a substring test
        # ('sta' in 'start').
        if isinstance(commands, str):
            commands = [commands]
        self.commands = commands
        super().__init__()

    async def check(self, message):
        if not message.is_command():
            return False

        # Drop the leading '/' and split off an optional '@mention'.
        command, _, mention = message.text.split()[0][1:].partition('@')

        # Telegram usernames are case-insensitive, so compare the mention that way.
        if mention and mention.lower() != (await message.bot.me).username.lower():
            return False

        return command in self.commands
|
|
|
|
|
|
class RegexpFilter(Filter):
    """
    Regexp filter for messages

    Passes when the pattern is found anywhere in ``message.text``. String patterns
    are compiled with IGNORECASE | MULTILINE; pre-compiled patterns keep their own
    flags.
    """

    def __init__(self, regexp):
        if isinstance(regexp, (str, bytes)):
            self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
        else:
            # re.compile() raises ValueError when flags are given with an already
            # compiled pattern; without flags it returns the pattern unchanged.
            self.regexp = re.compile(regexp)
        super().__init__()

    def check(self, message):
        if not message.text:
            return False
        return bool(self.regexp.search(message.text))
|
|
|
|
|
|
class RegexpCommandsFilter(AsyncFilter):
    """
    Check commands by regexp in message

    Passes when the message is a command (not addressed to another bot) and any of
    the patterns is found in the full ``message.text``. The match object is stored
    in ``message.conf['regexp_command']``.
    """

    def __init__(self, regexp_commands):
        # A bare string would otherwise be iterated character by character.
        if isinstance(regexp_commands, str):
            regexp_commands = [regexp_commands]
        self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE)
                                for command in regexp_commands]
        super().__init__()

    async def check(self, message):
        if not message.is_command():
            return False

        # Only the optional '@mention' of the first word matters here.
        _, _, mention = message.text.split()[0][1:].partition('@')

        # Telegram usernames are case-insensitive, so compare the mention that way.
        if mention and mention.lower() != (await message.bot.me).username.lower():
            return False

        for pattern in self.regexp_commands:
            match = pattern.search(message.text)
            if match:
                # Presumably read later by the handler — confirm against callers.
                message.conf['regexp_command'] = match
                return True
        return False
|
|
|
|
|
|
class ContentTypeFilter(Filter):
    """
    Check message content type

    Passes when ``message.content_type`` is listed, or when the first item of
    ``ContentType.ANY`` (used as a wildcard) is listed.
    """

    def __init__(self, content_types):
        # A bare string would turn membership into a substring test
        # ('video' in 'video_note').
        if isinstance(content_types, str):
            content_types = [content_types]
        self.content_types = content_types
        super().__init__()

    def check(self, message):
        return ContentType.ANY[0] in self.content_types or \
            message.content_type in self.content_types
|
|
|
|
|
|
class CancelFilter(Filter):
    """
    Find cancel in message text

    Passes when the whole message text, lower-cased, is one of ``cancel_set``.
    Returns None for messages without text.
    """

    def __init__(self, cancel_set=None):
        self.cancel_set = ['/cancel', 'cancel', 'cancel.'] if cancel_set is None else cancel_set
        super().__init__()

    def check(self, message):
        text = message.text
        if not text:
            return None
        return text.lower() in self.cancel_set
|
|
|
|
|
|
class StateFilter(AsyncFilter):
    """
    Check user state

    ``state`` may be a single state (including ``None``) or a list/set/tuple of
    states. A ``'*'`` entry matches any state. Single values are wrapped in a
    one-element tuple, so every comparison below is a membership test.
    """

    def __init__(self, dispatcher, state):
        self.dispatcher = dispatcher
        # Wrap every scalar, not only str: ``state=None`` (passed through by
        # generate_default_filters) previously made ``... in None`` raise TypeError.
        if not isinstance(state, (list, set, tuple, frozenset)):
            state = (state,)
        self.state = state
        super().__init__()

    def get_target(self, obj):
        """Return ``(chat_id, user_id)`` of ``obj``; either may be None."""
        return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)

    async def check(self, obj):
        # ``self.state`` is always a collection, so ``self.state == '*'`` (the old
        # test) could never be true; test membership instead.
        if '*' in self.state:
            return True

        # A state stored in the context takes precedence over the storage lookup.
        if context.check_value(USER_STATE):
            return context.get_value(USER_STATE) in self.state

        chat, user = self.get_target(obj)
        if chat or user:
            return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
        return False
|
|
|
|
|
|
class StatesListFilter(StateFilter):
    """
    List of states

    Passes when the state stored for the object's chat/user is one of ``state``.
    This override looks only at the storage; it ignores any context value and
    gives ``'*'`` no special meaning.
    """

    async def check(self, obj):
        chat, user = self.get_target(obj)
        if not (chat or user):
            return False
        current = await self.dispatcher.storage.get_state(chat=chat, user=user)
        return current in self.state
|
|
|
|
|
|
class ExceptionsFilter(Filter):
    """
    Filter for exceptions

    Passes when the handled exception is an instance of :attr:`exception`, or a
    subclass of it when an exception class is passed. :attr:`exception` may be a
    class or a tuple of classes.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__()

    def check(self, dispatcher, update, exception):
        # Plain type checks instead of ``raise exception`` / ``except``:
        # - re-raising extended the exception's __traceback__;
        # - the bare ``except:`` hid unrelated errors.
        if isinstance(exception, type):
            return issubclass(exception, self.exception)
        return isinstance(exception, self.exception)
|
|
|
|
|
|
def generate_default_filters(dispatcher, *args, **kwargs):
    """
    Prepare filters

    Translates the keyword arguments named after the DefaultFilters items into
    filter instances. Any Filter instance passed under another keyword is kept
    as-is, and the positional filters are appended at the end.

    :param dispatcher: handed to the state filters
    :param args: extra filters appended unchanged
    :param kwargs: default-filter values or custom Filter instances; ``None``
        values are skipped except for the state key
    :return: list of filters
    """
    prepared = []

    for key, value in kwargs.items():
        if key == DefaultFilters.STATE:
            # ``state=None`` is deliberately not skipped; it goes to the state filter.
            state_cls = StatesListFilter if isinstance(value, (list, set, tuple)) else StateFilter
            prepared.append(state_cls(dispatcher, value))
            continue

        if value is None:
            continue

        if key == DefaultFilters.COMMANDS:
            prepared.append(CommandsFilter([value] if isinstance(value, str) else value))
        elif key == DefaultFilters.REGEXP:
            prepared.append(RegexpFilter(value))
        elif key == DefaultFilters.CONTENT_TYPES:
            prepared.append(ContentTypeFilter(value))
        elif key == DefaultFilters.FUNC:
            prepared.append(value)
        elif isinstance(value, Filter):
            prepared.append(value)

    prepared.extend(args)
    return prepared
|
|
|
|
|
|
class DefaultFilters(Helper):
    """
    Keyword names recognised by ``generate_default_filters``.

    Values come from ``Item()`` under ``HelperMode.snake_case``; presumably each
    value is the attribute name in snake case — confirm against utils.helper.
    """
    mode = HelperMode.snake_case

    # Declaration order kept as-is in case the Helper base relies on it.
    COMMANDS = Item()
    REGEXP = Item()
    CONTENT_TYPES = Item()
    FUNC = Item()
    STATE = Item()
|