The middlewares is back in the new interpretation + Small refactoring.

This commit is contained in:
Alex Root Junior 2017-12-10 02:36:16 +02:00
parent 4ee54ea4e6
commit 6a64573e44
4 changed files with 141 additions and 78 deletions

View file

@ -1,12 +1,14 @@
import asyncio import asyncio
import functools import functools
import itertools
import logging import logging
import time import time
import typing import typing
from aiogram.dispatcher.middlewares import MiddlewareManager
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \ from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
generate_default_filters generate_default_filters
from .handler import Handler from .handler import CancelHandler, Handler, SkipHandler
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
from .webhook import BaseResponse from .webhook import BaseResponse
from ..bot import Bot from ..bot import Bot
@ -52,21 +54,22 @@ class Dispatcher:
self.last_update_id = 0 self.last_update_id = 0
self.updates_handler = Handler(self) self.updates_handler = Handler(self, middleware_key='update')
self.message_handlers = Handler(self) self.message_handlers = Handler(self, middleware_key='message')
self.edited_message_handlers = Handler(self) self.edited_message_handlers = Handler(self, middleware_key='edited_message')
self.channel_post_handlers = Handler(self) self.channel_post_handlers = Handler(self, middleware_key='channel_post')
self.edited_channel_post_handlers = Handler(self) self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
self.inline_query_handlers = Handler(self) self.inline_query_handlers = Handler(self, middleware_key='inline_query')
self.chosen_inline_result_handlers = Handler(self) self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
self.callback_query_handlers = Handler(self) self.callback_query_handlers = Handler(self, middleware_key='callback_query')
self.shipping_query_handlers = Handler(self) self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
self.pre_checkout_query_handlers = Handler(self) self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
self.errors_handlers = Handler(self, once=False, middleware_key='error')
self.middleware = MiddlewareManager(self)
self.updates_handler.register(self.process_update) self.updates_handler.register(self.process_update)
self.errors_handlers = Handler(self, once=False)
self._polling = False self._polling = False
def __del__(self): def __del__(self):
@ -111,7 +114,7 @@ class Dispatcher:
""" """
tasks = [] tasks = []
for update in updates: for update in updates:
tasks.append(self.process_update(update)) tasks.append(self.updates_handler.notify(update))
return await asyncio.gather(*tasks) return await asyncio.gather(*tasks)
async def process_update(self, update): async def process_update(self, update):
@ -124,69 +127,58 @@ class Dispatcher:
start = time.time() start = time.time()
success = True success = True
self.last_update_id = update.update_id
context.set_value(UPDATE_OBJECT, update)
try: try:
self.last_update_id = update.update_id
has_context = context.check_configured()
if has_context:
context.set_value(UPDATE_OBJECT, update)
if update.message: if update.message:
if has_context: state = await self.storage.get_state(chat=update.message.chat.id,
state = await self.storage.get_state(chat=update.message.chat.id, user=update.message.from_user.id)
user=update.message.from_user.id) context.update_state(chat=update.message.chat.id,
context.update_state(chat=update.message.chat.id, user=update.message.from_user.id,
user=update.message.from_user.id, state=state)
state=state)
return await self.message_handlers.notify(update.message) return await self.message_handlers.notify(update.message)
if update.edited_message: if update.edited_message:
if has_context: state = await self.storage.get_state(chat=update.edited_message.chat.id,
state = await self.storage.get_state(chat=update.edited_message.chat.id, user=update.edited_message.from_user.id)
user=update.edited_message.from_user.id) context.update_state(chat=update.edited_message.chat.id,
context.update_state(chat=update.edited_message.chat.id, user=update.edited_message.from_user.id,
user=update.edited_message.from_user.id, state=state)
state=state)
return await self.edited_message_handlers.notify(update.edited_message) return await self.edited_message_handlers.notify(update.edited_message)
if update.channel_post: if update.channel_post:
if has_context: state = await self.storage.get_state(chat=update.channel_post.chat.id)
state = await self.storage.get_state(chat=update.channel_post.chat.id) context.update_state(chat=update.channel_post.chat.id,
context.update_state(chat=update.channel_post.chat.id, state=state)
state=state)
return await self.channel_post_handlers.notify(update.channel_post) return await self.channel_post_handlers.notify(update.channel_post)
if update.edited_channel_post: if update.edited_channel_post:
if has_context: state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id) context.update_state(chat=update.edited_channel_post.chat.id,
context.update_state(chat=update.edited_channel_post.chat.id, state=state)
state=state)
return await self.edited_channel_post_handlers.notify(update.edited_channel_post) return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
if update.inline_query: if update.inline_query:
if has_context: state = await self.storage.get_state(user=update.inline_query.from_user.id)
state = await self.storage.get_state(user=update.inline_query.from_user.id) context.update_state(user=update.inline_query.from_user.id,
context.update_state(user=update.inline_query.from_user.id, state=state)
state=state)
return await self.inline_query_handlers.notify(update.inline_query) return await self.inline_query_handlers.notify(update.inline_query)
if update.chosen_inline_result: if update.chosen_inline_result:
if has_context: state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id) context.update_state(user=update.chosen_inline_result.from_user.id,
context.update_state(user=update.chosen_inline_result.from_user.id, state=state)
state=state)
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
if update.callback_query: if update.callback_query:
if has_context: state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
state = await self.storage.get_state(chat=update.callback_query.message.chat.id, user=update.callback_query.from_user.id)
user=update.callback_query.from_user.id) context.update_state(user=update.callback_query.from_user.id,
context.update_state(user=update.callback_query.from_user.id, state=state)
state=state)
return await self.callback_query_handlers.notify(update.callback_query) return await self.callback_query_handlers.notify(update.callback_query)
if update.shipping_query: if update.shipping_query:
if has_context: state = await self.storage.get_state(user=update.shipping_query.from_user.id)
state = await self.storage.get_state(user=update.shipping_query.from_user.id) context.update_state(user=update.shipping_query.from_user.id,
context.update_state(user=update.shipping_query.from_user.id, state=state)
state=state)
return await self.shipping_query_handlers.notify(update.shipping_query) return await self.shipping_query_handlers.notify(update.shipping_query)
if update.pre_checkout_query: if update.pre_checkout_query:
if has_context: state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id) context.update_state(user=update.pre_checkout_query.from_user.id,
context.update_state(user=update.pre_checkout_query.from_user.id, state=state)
state=state)
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
except Exception as e: except Exception as e:
success = False success = False
@ -276,8 +268,8 @@ class Dispatcher:
:param updates: list of updates. :param updates: list of updates.
""" """
need_to_call = [] need_to_call = []
for response in await self.process_updates(updates): for responses in itertools.chain.from_iterable(await self.process_updates(updates)):
for response in response: for response in responses:
if not isinstance(response, BaseResponse): if not isinstance(response, BaseResponse):
continue continue
need_to_call.append(response.execute_response(self.bot)) need_to_call.append(response.execute_response(self.bot))
@ -903,7 +895,8 @@ class Dispatcher:
""" """
def decorator(callback): def decorator(callback):
self.register_errors_handler(callback, func=func, exception=exception) self.register_errors_handler(self._wrap_async_task(callback, run_task),
func=func, exception=exception)
return callback return callback
return decorator return decorator
@ -948,7 +941,6 @@ class Dispatcher:
:return: bool :return: bool
""" """
if not self.storage.has_bucket(): if not self.storage.has_bucket():
print(self.storage)
raise RuntimeError('This storage does not provide Leaky Bucket') raise RuntimeError('This storage does not provide Leaky Bucket')
if no_error is None: if no_error is None:

View file

@ -9,7 +9,7 @@ from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE' USER_STATE = 'USER_STATE'
async def check_filter(filter_, args, kwargs): async def check_filter(filter_, args):
""" """
Helper for executing filter Helper for executing filter
@ -22,23 +22,22 @@ async def check_filter(filter_, args, kwargs):
raise TypeError('Filter must be callable and/or awaitable!') raise TypeError('Filter must be callable and/or awaitable!')
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_): if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
return await filter_(*args, **kwargs) return await filter_(*args)
else: else:
return filter_(*args, **kwargs) return filter_(*args)
async def check_filters(filters, args, kwargs): async def check_filters(filters, args):
""" """
Check list of filters Check list of filters
:param filters: :param filters:
:param args: :param args:
:param kwargs:
:return: :return:
""" """
if filters is not None: if filters is not None:
for filter_ in filters: for filter_ in filters:
f = await check_filter(filter_, args, kwargs) f = await check_filter(filter_, args)
if not f: if not f:
return False return False
return True return True
@ -76,8 +75,8 @@ class AnyFilter(AsyncFilter):
def __init__(self, *filters: callable): def __init__(self, *filters: callable):
self.filters = filters self.filters = filters
async def check(self, *args, **kwargs): async def check(self, *args):
f = (check_filter(filter_, args, kwargs) for filter_ in self.filters) f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f)) return any(await asyncio.gather(*f))
@ -88,8 +87,8 @@ class NotFilter(AsyncFilter):
def __init__(self, filter_: callable): def __init__(self, filter_: callable):
self.filter = filter_ self.filter = filter_
async def check(self, *args, **kwargs): async def check(self, *args):
return not await check_filter(self.filter, args, kwargs) return not await check_filter(self.filter, args)
class CommandsFilter(AsyncFilter): class CommandsFilter(AsyncFilter):

View file

@ -10,11 +10,12 @@ class CancelHandler(BaseException):
class Handler: class Handler:
def __init__(self, dispatcher, once=True): def __init__(self, dispatcher, once=True, middleware_key=None):
self.dispatcher = dispatcher self.dispatcher = dispatcher
self.once = once self.once = once
self.handlers = [] self.handlers = []
self.middleware_key = middleware_key
def register(self, handler, filters=None, index=None): def register(self, handler, filters=None, index=None):
""" """
@ -48,20 +49,23 @@ class Handler:
return True return True
raise ValueError('This handler is not registered!') raise ValueError('This handler is not registered!')
async def notify(self, *args, **kwargs): async def notify(self, *args):
""" """
Notify handlers Notify handlers
:param args: :param args:
:param kwargs:
:return: :return:
""" """
results = [] results = []
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
for filters, handler in self.handlers: for filters, handler in self.handlers:
if await check_filters(filters, args, kwargs): if await check_filters(filters, args):
try: try:
response = await handler(*args, **kwargs) if self.middleware_key:
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if results is not None: if results is not None:
results.append(response) results.append(response)
if self.once: if self.once:
@ -70,5 +74,8 @@ class Handler:
continue continue
except CancelHandler: except CancelHandler:
break break
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
return results return results

View file

@ -0,0 +1,65 @@
import logging
import typing
log = logging.getLogger('aiogram.Middleware')
class MiddlewareManager:
def __init__(self, dispatcher):
self.dispatcher = dispatcher
self.loop = dispatcher.loop
self.bot = dispatcher.bot
self.storage = dispatcher.storage
self.applications = []
def setup(self, middleware):
"""
Setup middleware
:param middleware:
:return:
"""
assert isinstance(middleware, BaseMiddleware)
if middleware.is_configured():
raise ValueError('That middleware is already used!')
self.applications.append(middleware)
middleware.setup(self)
log.debug(f"Loaded middleware '{middleware.__class__.__name__}'")
async def trigger(self, action: str, args: typing.Iterable):
"""
Call action to middlewares with args lilt.
:param action:
:param args:
:return:
"""
for app in self.applications:
await app.trigger(action, args)
class BaseMiddleware:
def __init__(self):
self._configured = False
self._manager = None
@property
def manager(self) -> MiddlewareManager:
if self._manager is None:
raise RuntimeError('Middleware is not configured!')
return self._manager
def setup(self, manager):
self._manager = manager
self._configured = True
def is_configured(self):
return self._configured
async def trigger(self, action, args):
handler_name = f"on_{action}"
handler = getattr(self, handler_name, None)
if not handler:
return None
await handler(*args)