mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 09:55:21 +00:00
Implement filters factory.
This commit is contained in:
parent
a3856e33bd
commit
8d7a00204d
2 changed files with 121 additions and 19 deletions
|
|
@ -5,8 +5,8 @@ import logging
|
|||
import time
|
||||
import typing
|
||||
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, \
|
||||
USER_STATE, generate_default_filters
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter, USER_STATE, \
|
||||
generate_default_filters
|
||||
from .handler import CancelHandler, Handler, SkipHandler
|
||||
from .middlewares import MiddlewareManager
|
||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
||||
|
|
@ -37,12 +37,15 @@ class Dispatcher:
|
|||
|
||||
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
||||
run_tasks_by_default: bool = False,
|
||||
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
|
||||
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False,
|
||||
filters_factory=None):
|
||||
|
||||
if loop is None:
|
||||
loop = bot.loop
|
||||
if storage is None:
|
||||
storage = DisabledStorage()
|
||||
if filters_factory is None:
|
||||
filters_factory = FiltersFactory(self)
|
||||
|
||||
self.bot: Bot = bot
|
||||
self.loop = loop
|
||||
|
|
@ -54,6 +57,7 @@ class Dispatcher:
|
|||
|
||||
self.last_update_id = 0
|
||||
|
||||
self.filters_factory: FiltersFactory = filters_factory
|
||||
self.updates_handler = Handler(self, middleware_key='update')
|
||||
self.message_handlers = Handler(self, middleware_key='message')
|
||||
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
|
||||
|
|
@ -74,6 +78,12 @@ class Dispatcher:
|
|||
self._closed = True
|
||||
self._close_waiter = loop.create_future()
|
||||
|
||||
filters_factory.bind(filters.CommandsFilter, 'commands')
|
||||
filters_factory.bind(filters.RegexpFilter, 'regexp')
|
||||
filters_factory.bind(filters.RegexpCommandsFilter, 'regexp_commands')
|
||||
filters_factory.bind(filters.ContentTypeFilter, 'content_types')
|
||||
filters_factory.bind(filters.StateFilter, 'state', with_dispatcher=True, default=True)
|
||||
|
||||
def __del__(self):
|
||||
self.stop_polling()
|
||||
|
||||
|
|
@ -310,8 +320,8 @@ class Dispatcher:
|
|||
"""
|
||||
return self._polling
|
||||
|
||||
def register_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None,
|
||||
state=None, custom_filters=None, run_task=None, **kwargs):
|
||||
def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None,
|
||||
func=None, state=None, run_task=None, **kwargs):
|
||||
"""
|
||||
Register handler for message
|
||||
|
||||
|
|
@ -334,23 +344,23 @@ class Dispatcher:
|
|||
:param content_types: List of content types.
|
||||
:param func: custom any callable object
|
||||
:param custom_filters: list of custom filters
|
||||
:param run_task:
|
||||
:param kwargs:
|
||||
:param state:
|
||||
:return: decorated function
|
||||
"""
|
||||
if content_types is None:
|
||||
content_types = ContentType.TEXT
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
if func is not None:
|
||||
custom_filters = list(custom_filters)
|
||||
custom_filters.append(func)
|
||||
|
||||
filters_set = generate_default_filters(self,
|
||||
*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_types=content_types,
|
||||
func=func,
|
||||
state=state,
|
||||
**kwargs)
|
||||
filters_set = self.filters_factory.parse(*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_types=content_types,
|
||||
state=state,
|
||||
**kwargs)
|
||||
self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
|
||||
|
||||
def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None,
|
||||
|
|
@ -426,9 +436,9 @@ class Dispatcher:
|
|||
"""
|
||||
|
||||
def decorator(callback):
|
||||
self.register_message_handler(callback,
|
||||
self.register_message_handler(callback, *custom_filters,
|
||||
commands=commands, regexp=regexp, content_types=content_types,
|
||||
func=func, state=state, custom_filters=custom_filters, run_task=run_task,
|
||||
func=func, state=state, run_task=run_task,
|
||||
**kwargs)
|
||||
return callback
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,84 @@ from ..utils.helper import Helper, HelperMode, Item
|
|||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
class FiltersFactory:
|
||||
def __init__(self, dispatcher):
|
||||
self._dispatcher = dispatcher
|
||||
self._filters = []
|
||||
|
||||
@property
|
||||
def _default_filters(self):
|
||||
return tuple(filter(lambda item: item[-1], self._filters))
|
||||
|
||||
def bind(self, filter_, *args, default=False, with_dispatcher=False):
|
||||
self._filters.append((filter_, args, with_dispatcher, default))
|
||||
|
||||
def unbind(self, filter_):
|
||||
for item in self._filters:
|
||||
if filter_ is item[0]:
|
||||
self._filters.remove(item)
|
||||
return True
|
||||
raise ValueError(f'{filter_} is not binded.')
|
||||
|
||||
def replace(self, original, new):
|
||||
for item in self._filters:
|
||||
if original is item[0]:
|
||||
item[0] = new
|
||||
return True
|
||||
raise ValueError(f'{original} is not binded.')
|
||||
|
||||
def parse(self, *args, **kwargs):
|
||||
"""
|
||||
Generate filters list
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
used = []
|
||||
filters = []
|
||||
|
||||
filters.extend(args)
|
||||
|
||||
# Registered filters filters
|
||||
for filter_, args_list, with_dispatcher, default in self._filters:
|
||||
config = {}
|
||||
accept = True
|
||||
|
||||
for item in args_list:
|
||||
value = kwargs.pop(item, None)
|
||||
if value is None:
|
||||
accept = False
|
||||
break
|
||||
config[item] = value
|
||||
|
||||
if accept:
|
||||
if with_dispatcher:
|
||||
config['dispatcher'] = self._dispatcher
|
||||
|
||||
filters.append(filter_(**config))
|
||||
used.append(filter_)
|
||||
|
||||
elif default:
|
||||
if filter_ not in used:
|
||||
used.append(filter_)
|
||||
if isinstance(filter_, Filter):
|
||||
if with_dispatcher:
|
||||
filters.append(filter_(dispatcher=self._dispatcher))
|
||||
else:
|
||||
filters.append(filter_())
|
||||
|
||||
# Not registered filters
|
||||
for key, filter_ in kwargs.items():
|
||||
if isinstance(filter_, Filter):
|
||||
filters.append(filter_)
|
||||
used.append(filter_.__class__)
|
||||
else:
|
||||
raise ValueError(f"Unknown filter with key '{key}'")
|
||||
|
||||
return filters
|
||||
|
||||
|
||||
async def check_filter(filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
|
@ -48,6 +126,9 @@ class Filter:
|
|||
Base class for filters
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.check(*args, **kwargs)
|
||||
|
||||
|
|
@ -77,6 +158,7 @@ class AnyFilter(AsyncFilter):
|
|||
|
||||
def __init__(self, *filters: callable):
|
||||
self.filters = filters
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||
|
|
@ -90,6 +172,7 @@ class NotFilter(AsyncFilter):
|
|||
|
||||
def __init__(self, filter_: callable):
|
||||
self.filter = filter_
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
return not await check_filter(self.filter, args)
|
||||
|
|
@ -102,6 +185,7 @@ class CommandsFilter(AsyncFilter):
|
|||
|
||||
def __init__(self, commands):
|
||||
self.commands = commands
|
||||
super().__init__()
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
|
|
@ -126,6 +210,7 @@ class RegexpFilter(Filter):
|
|||
|
||||
def __init__(self, regexp):
|
||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
super().__init__()
|
||||
|
||||
def check(self, message):
|
||||
if message.text:
|
||||
|
|
@ -139,6 +224,7 @@ class RegexpCommandsFilter(AsyncFilter):
|
|||
|
||||
def __init__(self, regexp_commands):
|
||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||
super().__init__()
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
|
|
@ -165,10 +251,11 @@ class ContentTypeFilter(Filter):
|
|||
|
||||
def __init__(self, content_types):
|
||||
self.content_types = content_types
|
||||
super().__init__()
|
||||
|
||||
def check(self, message):
|
||||
return ContentType.ANY[0] in self.content_types or \
|
||||
message.content_type in self.content_types
|
||||
message.content_type in self.content_types
|
||||
|
||||
|
||||
class CancelFilter(Filter):
|
||||
|
|
@ -180,6 +267,7 @@ class CancelFilter(Filter):
|
|||
if cancel_set is None:
|
||||
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
||||
self.cancel_set = cancel_set
|
||||
super().__init__()
|
||||
|
||||
def check(self, message):
|
||||
if message.text:
|
||||
|
|
@ -193,7 +281,10 @@ class StateFilter(AsyncFilter):
|
|||
|
||||
def __init__(self, dispatcher, state):
|
||||
self.dispatcher = dispatcher
|
||||
if isinstance(state, str):
|
||||
state = (state,)
|
||||
self.state = state
|
||||
super().__init__()
|
||||
|
||||
def get_target(self, obj):
|
||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
|
|
@ -209,7 +300,7 @@ class StateFilter(AsyncFilter):
|
|||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -233,6 +324,7 @@ class ExceptionsFilter(Filter):
|
|||
|
||||
def __init__(self, exception):
|
||||
self.exception = exception
|
||||
super().__init__()
|
||||
|
||||
def check(self, dispatcher, update, exception):
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue