From 6a64573e448e567402e19c9507c4cb492d52df42 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Sun, 10 Dec 2017 02:36:16 +0200 Subject: [PATCH] The middlewares is back in the new interpretation + Small refactoring. --- aiogram/dispatcher/__init__.py | 118 ++++++++++++++---------------- aiogram/dispatcher/filters.py | 19 +++-- aiogram/dispatcher/handler.py | 17 +++-- aiogram/dispatcher/middlewares.py | 65 ++++++++++++++++ 4 files changed, 141 insertions(+), 78 deletions(-) create mode 100644 aiogram/dispatcher/middlewares.py diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index b4179421..258dcb95 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -1,12 +1,14 @@ import asyncio import functools +import itertools import logging import time import typing +from aiogram.dispatcher.middlewares import MiddlewareManager from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \ generate_default_filters -from .handler import Handler +from .handler import CancelHandler, Handler, SkipHandler from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT from .webhook import BaseResponse from ..bot import Bot @@ -52,21 +54,22 @@ class Dispatcher: self.last_update_id = 0 - self.updates_handler = Handler(self) - self.message_handlers = Handler(self) - self.edited_message_handlers = Handler(self) - self.channel_post_handlers = Handler(self) - self.edited_channel_post_handlers = Handler(self) - self.inline_query_handlers = Handler(self) - self.chosen_inline_result_handlers = Handler(self) - self.callback_query_handlers = Handler(self) - self.shipping_query_handlers = Handler(self) - self.pre_checkout_query_handlers = Handler(self) + self.updates_handler = Handler(self, middleware_key='update') + self.message_handlers = Handler(self, middleware_key='message') + self.edited_message_handlers = Handler(self, middleware_key='edited_message') + self.channel_post_handlers = Handler(self, middleware_key='channel_post') + self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post') + self.inline_query_handlers = Handler(self, middleware_key='inline_query') + self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result') + self.callback_query_handlers = Handler(self, middleware_key='callback_query') + self.shipping_query_handlers = Handler(self, middleware_key='shipping_query') + self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query') + self.errors_handlers = Handler(self, once=False, middleware_key='error') + + self.middleware = MiddlewareManager(self) self.updates_handler.register(self.process_update) - self.errors_handlers = Handler(self, once=False) - self._polling = False def __del__(self): @@ -111,7 +114,7 @@ class Dispatcher: """ tasks = [] for update in updates: - tasks.append(self.process_update(update)) + tasks.append(self.updates_handler.notify(update)) return await asyncio.gather(*tasks) async def process_update(self, update): @@ -124,69 +127,58 @@ class Dispatcher: start = time.time() success = True + self.last_update_id = update.update_id + context.set_value(UPDATE_OBJECT, update) try: - self.last_update_id = update.update_id - has_context = context.check_configured() - if has_context: - context.set_value(UPDATE_OBJECT, update) if update.message: - if has_context: - state = await self.storage.get_state(chat=update.message.chat.id, - user=update.message.from_user.id) - context.update_state(chat=update.message.chat.id, - user=update.message.from_user.id, - state=state) + state = await self.storage.get_state(chat=update.message.chat.id, + user=update.message.from_user.id) + context.update_state(chat=update.message.chat.id, + user=update.message.from_user.id, + state=state) return await self.message_handlers.notify(update.message) if update.edited_message: - if has_context: - state = await self.storage.get_state(chat=update.edited_message.chat.id, - user=update.edited_message.from_user.id) - context.update_state(chat=update.edited_message.chat.id, - user=update.edited_message.from_user.id, - state=state) + state = await self.storage.get_state(chat=update.edited_message.chat.id, + user=update.edited_message.from_user.id) + context.update_state(chat=update.edited_message.chat.id, + user=update.edited_message.from_user.id, + state=state) return await self.edited_message_handlers.notify(update.edited_message) if update.channel_post: - if has_context: - state = await self.storage.get_state(chat=update.channel_post.chat.id) - context.update_state(chat=update.channel_post.chat.id, - state=state) + state = await self.storage.get_state(chat=update.channel_post.chat.id) + context.update_state(chat=update.channel_post.chat.id, + state=state) return await self.channel_post_handlers.notify(update.channel_post) if update.edited_channel_post: - if has_context: - state = await self.storage.get_state(chat=update.edited_channel_post.chat.id) - context.update_state(chat=update.edited_channel_post.chat.id, - state=state) + state = await self.storage.get_state(chat=update.edited_channel_post.chat.id) + context.update_state(chat=update.edited_channel_post.chat.id, + state=state) return await self.edited_channel_post_handlers.notify(update.edited_channel_post) if update.inline_query: - if has_context: - state = await self.storage.get_state(user=update.inline_query.from_user.id) - context.update_state(user=update.inline_query.from_user.id, - state=state) + state = await self.storage.get_state(user=update.inline_query.from_user.id) + context.update_state(user=update.inline_query.from_user.id, + state=state) return await self.inline_query_handlers.notify(update.inline_query) if update.chosen_inline_result: - if has_context: - state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id) - context.update_state(user=update.chosen_inline_result.from_user.id, - state=state) + state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id) + context.update_state(user=update.chosen_inline_result.from_user.id, + state=state) return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) if update.callback_query: - if has_context: - state = await self.storage.get_state(chat=update.callback_query.message.chat.id, - user=update.callback_query.from_user.id) - context.update_state(user=update.callback_query.from_user.id, - state=state) + state = await self.storage.get_state(chat=update.callback_query.message.chat.id, + user=update.callback_query.from_user.id) + context.update_state(user=update.callback_query.from_user.id, + state=state) return await self.callback_query_handlers.notify(update.callback_query) if update.shipping_query: - if has_context: - state = await self.storage.get_state(user=update.shipping_query.from_user.id) - context.update_state(user=update.shipping_query.from_user.id, - state=state) + state = await self.storage.get_state(user=update.shipping_query.from_user.id) + context.update_state(user=update.shipping_query.from_user.id, + state=state) return await self.shipping_query_handlers.notify(update.shipping_query) if update.pre_checkout_query: - if has_context: - state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id) - context.update_state(user=update.pre_checkout_query.from_user.id, - state=state) + state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id) + context.update_state(user=update.pre_checkout_query.from_user.id, + state=state) return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) except Exception as e: success = False @@ -276,8 +268,8 @@ class Dispatcher: :param updates: list of updates. """ need_to_call = [] - for response in await self.process_updates(updates): - for response in response: + for responses in itertools.chain.from_iterable(await self.process_updates(updates)): + for response in responses: if not isinstance(response, BaseResponse): continue need_to_call.append(response.execute_response(self.bot)) @@ -903,7 +895,8 @@ class Dispatcher: """ def decorator(callback): - self.register_errors_handler(callback, func=func, exception=exception) + self.register_errors_handler(self._wrap_async_task(callback, run_task), + func=func, exception=exception) return callback return decorator @@ -948,7 +941,6 @@ class Dispatcher: :return: bool """ if not self.storage.has_bucket(): - print(self.storage) raise RuntimeError('This storage does not provide Leaky Bucket') if no_error is None: diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index 5cb3392c..175352e4 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -9,7 +9,7 @@ from ..utils.helper import Helper, HelperMode, Item USER_STATE = 'USER_STATE' -async def check_filter(filter_, args, kwargs): +async def check_filter(filter_, args): """ Helper for executing filter @@ -22,23 +22,22 @@ async def check_filter(filter_, args, kwargs): raise TypeError('Filter must be callable and/or awaitable!') if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_): - return await filter_(*args, **kwargs) + return await filter_(*args) else: - return filter_(*args, **kwargs) + return filter_(*args) -async def check_filters(filters, args, kwargs): +async def check_filters(filters, args): """ Check list of filters :param filters: :param args: - :param kwargs: :return: """ if filters is not None: for filter_ in filters: - f = await check_filter(filter_, args, kwargs) + f = await check_filter(filter_, args) if not f: return False return True @@ -76,8 +75,8 @@ class AnyFilter(AsyncFilter): def __init__(self, *filters: callable): self.filters = filters - async def check(self, *args, **kwargs): - f = (check_filter(filter_, args, kwargs) for filter_ in self.filters) + async def check(self, *args): + f = (check_filter(filter_, args) for filter_ in self.filters) return any(await asyncio.gather(*f)) @@ -88,8 +87,8 @@ class NotFilter(AsyncFilter): def __init__(self, filter_: callable): self.filter = filter_ - async def check(self, *args, **kwargs): - return not await check_filter(self.filter, args, kwargs) + async def check(self, *args): + return not await check_filter(self.filter, args) class CommandsFilter(AsyncFilter): diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 517b8b75..157fb38e 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -10,11 +10,12 @@ class CancelHandler(BaseException): class Handler: - def __init__(self, dispatcher, once=True): + def __init__(self, dispatcher, once=True, middleware_key=None): self.dispatcher = dispatcher self.once = once self.handlers = [] + self.middleware_key = middleware_key def register(self, handler, filters=None, index=None): """ @@ -48,20 +49,23 @@ class Handler: return True raise ValueError('This handler is not registered!') - async def notify(self, *args, **kwargs): + async def notify(self, *args): """ Notify handlers :param args: - :param kwargs: :return: """ results = [] + if self.middleware_key: + await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args) for filters, handler in self.handlers: - if await check_filters(filters, args, kwargs): + if await check_filters(filters, args): try: - response = await handler(*args, **kwargs) + if self.middleware_key: + await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args) + response = await handler(*args) if results is not None: results.append(response) if self.once: @@ -70,5 +74,8 @@ class Handler: continue except CancelHandler: break + if self.middleware_key: + await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", + args + (results,)) return results diff --git a/aiogram/dispatcher/middlewares.py b/aiogram/dispatcher/middlewares.py new file mode 100644 index 00000000..d86f5453 --- /dev/null +++ b/aiogram/dispatcher/middlewares.py @@ -0,0 +1,65 @@ +import logging +import typing + +log = logging.getLogger('aiogram.Middleware') + + +class MiddlewareManager: + def __init__(self, dispatcher): + self.dispatcher = dispatcher + self.loop = dispatcher.loop + self.bot = dispatcher.bot + self.storage = dispatcher.storage + self.applications = [] + + def setup(self, middleware): + """ + Setup middleware + + :param middleware: + :return: + """ + assert isinstance(middleware, BaseMiddleware) + if middleware.is_configured(): + raise ValueError('That middleware is already used!') + + self.applications.append(middleware) + middleware.setup(self) + log.debug(f"Loaded middleware '{middleware.__class__.__name__}'") + + async def trigger(self, action: str, args: typing.Iterable): + """ + Call action to middlewares with args lilt. + + :param action: + :param args: + :return: + """ + for app in self.applications: + await app.trigger(action, args) + + +class BaseMiddleware: + def __init__(self): + self._configured = False + self._manager = None + + @property + def manager(self) -> MiddlewareManager: + if self._manager is None: + raise RuntimeError('Middleware is not configured!') + return self._manager + + def setup(self, manager): + self._manager = manager + self._configured = True + + def is_configured(self): + return self._configured + + async def trigger(self, action, args): + handler_name = f"on_{action}" + handler = getattr(self, handler_name, None) + if not handler: + return None + await handler(*args)