diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 8f89d4bf..414eee81 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -5,8 +5,8 @@ import logging import time import typing -from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, \ - USER_STATE, generate_default_filters +from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter, USER_STATE, \ + generate_default_filters from .handler import CancelHandler, Handler, SkipHandler from .middlewares import MiddlewareManager from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ @@ -37,12 +37,15 @@ class Dispatcher: def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None, run_tasks_by_default: bool = False, - throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False): + throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False, + filters_factory=None): if loop is None: loop = bot.loop if storage is None: storage = DisabledStorage() + if filters_factory is None: + filters_factory = FiltersFactory(self) self.bot: Bot = bot self.loop = loop @@ -54,6 +57,7 @@ class Dispatcher: self.last_update_id = 0 + self.filters_factory: FiltersFactory = filters_factory self.updates_handler = Handler(self, middleware_key='update') self.message_handlers = Handler(self, middleware_key='message') self.edited_message_handlers = Handler(self, middleware_key='edited_message') @@ -74,6 +78,12 @@ class Dispatcher: self._closed = True self._close_waiter = loop.create_future() + filters_factory.bind(filters.CommandsFilter, 'commands') + filters_factory.bind(filters.RegexpFilter, 'regexp') + filters_factory.bind(filters.RegexpCommandsFilter, 'regexp_commands') + filters_factory.bind(filters.ContentTypeFilter, 'content_types') + filters_factory.bind(filters.StateFilter, 'state', with_dispatcher=True, default=True) + def __del__(self): self.stop_polling() @@ -310,8 +320,8 @@ class Dispatcher: """ return self._polling - def register_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None, - state=None, custom_filters=None, run_task=None, **kwargs): + def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None, + func=None, state=None, run_task=None, **kwargs): """ Register handler for message @@ -334,23 +344,23 @@ class Dispatcher: :param content_types: List of content types. :param func: custom any callable object :param custom_filters: list of custom filters + :param run_task: :param kwargs: :param state: :return: decorated function """ if content_types is None: content_types = ContentType.TEXT - if custom_filters is None: - custom_filters = [] + if func is not None: + custom_filters = list(custom_filters) + custom_filters.append(func) - filters_set = generate_default_filters(self, - *custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - func=func, - state=state, - **kwargs) + filters_set = self.filters_factory.parse(*custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + state=state, + **kwargs) self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set) def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None, @@ -426,9 +436,9 @@ class Dispatcher: """ def decorator(callback): - self.register_message_handler(callback, + self.register_message_handler(callback, *custom_filters, commands=commands, regexp=regexp, content_types=content_types, - func=func, state=state, custom_filters=custom_filters, run_task=run_task, + func=func, state=state, run_task=run_task, **kwargs) return callback diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index 3b3b4d51..e277f217 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -9,6 +9,84 @@ from ..utils.helper import Helper, HelperMode, Item USER_STATE = 'USER_STATE' +class FiltersFactory: + def __init__(self, dispatcher): + self._dispatcher = dispatcher + self._filters = [] + + @property + def _default_filters(self): + return tuple(filter(lambda item: item[-1], self._filters)) + + def bind(self, filter_, *args, default=False, with_dispatcher=False): + self._filters.append((filter_, args, with_dispatcher, default)) + + def unbind(self, filter_): + for item in self._filters: + if filter_ is item[0]: + self._filters.remove(item) + return True + raise ValueError(f'{filter_} is not binded.') + + def replace(self, original, new): + for item in self._filters: + if original is item[0]: + item[0] = new + return True + raise ValueError(f'{original} is not binded.') + + def parse(self, *args, **kwargs): + """ + Generate filters list + + :param args: + :param kwargs: + :return: + """ + used = [] + filters = [] + + filters.extend(args) + + # Registered filters filters + for filter_, args_list, with_dispatcher, default in self._filters: + config = {} + accept = True + + for item in args_list: + value = kwargs.pop(item, None) + if value is None: + accept = False + break + config[item] = value + + if accept: + if with_dispatcher: + config['dispatcher'] = self._dispatcher + + filters.append(filter_(**config)) + used.append(filter_) + + elif default: + if filter_ not in used: + used.append(filter_) + if isinstance(filter_, Filter): + if with_dispatcher: + filters.append(filter_(dispatcher=self._dispatcher)) + else: + filters.append(filter_()) + + # Not registered filters + for key, filter_ in kwargs.items(): + if isinstance(filter_, Filter): + filters.append(filter_) + used.append(filter_.__class__) + else: + raise ValueError(f"Unknown filter with key '{key}'") + + return filters + + async def check_filter(filter_, args): """ Helper for executing filter @@ -48,6 +126,9 @@ class Filter: Base class for filters """ + def __init__(self, *args, **kwargs): + pass + def __call__(self, *args, **kwargs): return self.check(*args, **kwargs) @@ -77,6 +158,7 @@ class AnyFilter(AsyncFilter): def __init__(self, *filters: callable): self.filters = filters + super().__init__() async def check(self, *args): f = (check_filter(filter_, args) for filter_ in self.filters) @@ -90,6 +172,7 @@ class NotFilter(AsyncFilter): def __init__(self, filter_: callable): self.filter = filter_ + super().__init__() async def check(self, *args): return not await check_filter(self.filter, args) @@ -102,6 +185,7 @@ class CommandsFilter(AsyncFilter): def __init__(self, commands): self.commands = commands + super().__init__() async def check(self, message): if not message.is_command(): @@ -126,6 +210,7 @@ class RegexpFilter(Filter): def __init__(self, regexp): self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) + super().__init__() def check(self, message): if message.text: @@ -139,6 +224,7 @@ class RegexpCommandsFilter(AsyncFilter): def __init__(self, regexp_commands): self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] + super().__init__() async def check(self, message): if not message.is_command(): @@ -165,10 +251,11 @@ class ContentTypeFilter(Filter): def __init__(self, content_types): self.content_types = content_types + super().__init__() def check(self, message): return ContentType.ANY[0] in self.content_types or \ - message.content_type in self.content_types + message.content_type in self.content_types class CancelFilter(Filter): @@ -180,6 +267,7 @@ class CancelFilter(Filter): if cancel_set is None: cancel_set = ['/cancel', 'cancel', 'cancel.'] self.cancel_set = cancel_set + super().__init__() def check(self, message): if message.text: @@ -193,7 +281,10 @@ class StateFilter(AsyncFilter): def __init__(self, dispatcher, state): self.dispatcher = dispatcher + if isinstance(state, str): + state = (state,) self.state = state + super().__init__() def get_target(self, obj): return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) @@ -209,7 +300,7 @@ class StateFilter(AsyncFilter): chat, user = self.get_target(obj) if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state + return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state return False @@ -233,6 +324,7 @@ class ExceptionsFilter(Filter): def __init__(self, exception): self.exception = exception + super().__init__() def check(self, dispatcher, update, exception): try: