The middlewares is back in the new interpretation + Small refactoring.

This commit is contained in:
Alex Root Junior 2017-12-10 02:36:16 +02:00
parent 4ee54ea4e6
commit 6a64573e44
4 changed files with 141 additions and 78 deletions

View file

@ -1,12 +1,14 @@
import asyncio
import functools
import itertools
import logging
import time
import typing
from aiogram.dispatcher.middlewares import MiddlewareManager
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
generate_default_filters
from .handler import Handler
from .handler import CancelHandler, Handler, SkipHandler
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
from .webhook import BaseResponse
from ..bot import Bot
@ -52,21 +54,22 @@ class Dispatcher:
self.last_update_id = 0
self.updates_handler = Handler(self)
self.message_handlers = Handler(self)
self.edited_message_handlers = Handler(self)
self.channel_post_handlers = Handler(self)
self.edited_channel_post_handlers = Handler(self)
self.inline_query_handlers = Handler(self)
self.chosen_inline_result_handlers = Handler(self)
self.callback_query_handlers = Handler(self)
self.shipping_query_handlers = Handler(self)
self.pre_checkout_query_handlers = Handler(self)
self.updates_handler = Handler(self, middleware_key='update')
self.message_handlers = Handler(self, middleware_key='message')
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
self.channel_post_handlers = Handler(self, middleware_key='channel_post')
self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
self.inline_query_handlers = Handler(self, middleware_key='inline_query')
self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
self.callback_query_handlers = Handler(self, middleware_key='callback_query')
self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
self.errors_handlers = Handler(self, once=False, middleware_key='error')
self.middleware = MiddlewareManager(self)
self.updates_handler.register(self.process_update)
self.errors_handlers = Handler(self, once=False)
self._polling = False
def __del__(self):
@ -111,7 +114,7 @@ class Dispatcher:
"""
tasks = []
for update in updates:
tasks.append(self.process_update(update))
tasks.append(self.updates_handler.notify(update))
return await asyncio.gather(*tasks)
async def process_update(self, update):
@ -124,69 +127,58 @@ class Dispatcher:
start = time.time()
success = True
self.last_update_id = update.update_id
context.set_value(UPDATE_OBJECT, update)
try:
self.last_update_id = update.update_id
has_context = context.check_configured()
if has_context:
context.set_value(UPDATE_OBJECT, update)
if update.message:
if has_context:
state = await self.storage.get_state(chat=update.message.chat.id,
user=update.message.from_user.id)
context.update_state(chat=update.message.chat.id,
user=update.message.from_user.id,
state=state)
state = await self.storage.get_state(chat=update.message.chat.id,
user=update.message.from_user.id)
context.update_state(chat=update.message.chat.id,
user=update.message.from_user.id,
state=state)
return await self.message_handlers.notify(update.message)
if update.edited_message:
if has_context:
state = await self.storage.get_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id)
context.update_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id,
state=state)
state = await self.storage.get_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id)
context.update_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id,
state=state)
return await self.edited_message_handlers.notify(update.edited_message)
if update.channel_post:
if has_context:
state = await self.storage.get_state(chat=update.channel_post.chat.id)
context.update_state(chat=update.channel_post.chat.id,
state=state)
state = await self.storage.get_state(chat=update.channel_post.chat.id)
context.update_state(chat=update.channel_post.chat.id,
state=state)
return await self.channel_post_handlers.notify(update.channel_post)
if update.edited_channel_post:
if has_context:
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
context.update_state(chat=update.edited_channel_post.chat.id,
state=state)
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
context.update_state(chat=update.edited_channel_post.chat.id,
state=state)
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
if update.inline_query:
if has_context:
state = await self.storage.get_state(user=update.inline_query.from_user.id)
context.update_state(user=update.inline_query.from_user.id,
state=state)
state = await self.storage.get_state(user=update.inline_query.from_user.id)
context.update_state(user=update.inline_query.from_user.id,
state=state)
return await self.inline_query_handlers.notify(update.inline_query)
if update.chosen_inline_result:
if has_context:
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
context.update_state(user=update.chosen_inline_result.from_user.id,
state=state)
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
context.update_state(user=update.chosen_inline_result.from_user.id,
state=state)
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
if update.callback_query:
if has_context:
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
user=update.callback_query.from_user.id)
context.update_state(user=update.callback_query.from_user.id,
state=state)
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
user=update.callback_query.from_user.id)
context.update_state(user=update.callback_query.from_user.id,
state=state)
return await self.callback_query_handlers.notify(update.callback_query)
if update.shipping_query:
if has_context:
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
context.update_state(user=update.shipping_query.from_user.id,
state=state)
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
context.update_state(user=update.shipping_query.from_user.id,
state=state)
return await self.shipping_query_handlers.notify(update.shipping_query)
if update.pre_checkout_query:
if has_context:
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
context.update_state(user=update.pre_checkout_query.from_user.id,
state=state)
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
context.update_state(user=update.pre_checkout_query.from_user.id,
state=state)
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
except Exception as e:
success = False
@ -276,8 +268,8 @@ class Dispatcher:
:param updates: list of updates.
"""
need_to_call = []
for response in await self.process_updates(updates):
for response in response:
for responses in itertools.chain.from_iterable(await self.process_updates(updates)):
for response in responses:
if not isinstance(response, BaseResponse):
continue
need_to_call.append(response.execute_response(self.bot))
@ -903,7 +895,8 @@ class Dispatcher:
"""
def decorator(callback):
self.register_errors_handler(callback, func=func, exception=exception)
self.register_errors_handler(self._wrap_async_task(callback, run_task),
func=func, exception=exception)
return callback
return decorator
@ -948,7 +941,6 @@ class Dispatcher:
:return: bool
"""
if not self.storage.has_bucket():
print(self.storage)
raise RuntimeError('This storage does not provide Leaky Bucket')
if no_error is None:

View file

@ -9,7 +9,7 @@ from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE'
async def check_filter(filter_, args, kwargs):
async def check_filter(filter_, args):
"""
Helper for executing filter
@ -22,23 +22,22 @@ async def check_filter(filter_, args, kwargs):
raise TypeError('Filter must be callable and/or awaitable!')
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
return await filter_(*args, **kwargs)
return await filter_(*args)
else:
return filter_(*args, **kwargs)
return filter_(*args)
async def check_filters(filters, args, kwargs):
async def check_filters(filters, args):
"""
Check list of filters
:param filters:
:param args:
:param kwargs:
:return:
"""
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args, kwargs)
f = await check_filter(filter_, args)
if not f:
return False
return True
@ -76,8 +75,8 @@ class AnyFilter(AsyncFilter):
def __init__(self, *filters: callable):
self.filters = filters
async def check(self, *args, **kwargs):
f = (check_filter(filter_, args, kwargs) for filter_ in self.filters)
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
@ -88,8 +87,8 @@ class NotFilter(AsyncFilter):
def __init__(self, filter_: callable):
self.filter = filter_
async def check(self, *args, **kwargs):
return not await check_filter(self.filter, args, kwargs)
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(AsyncFilter):

View file

@ -10,11 +10,12 @@ class CancelHandler(BaseException):
class Handler:
def __init__(self, dispatcher, once=True):
def __init__(self, dispatcher, once=True, middleware_key=None):
self.dispatcher = dispatcher
self.once = once
self.handlers = []
self.middleware_key = middleware_key
def register(self, handler, filters=None, index=None):
"""
@ -48,20 +49,23 @@ class Handler:
return True
raise ValueError('This handler is not registered!')
async def notify(self, *args, **kwargs):
async def notify(self, *args):
"""
Notify handlers
:param args:
:param kwargs:
:return:
"""
results = []
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
for filters, handler in self.handlers:
if await check_filters(filters, args, kwargs):
if await check_filters(filters, args):
try:
response = await handler(*args, **kwargs)
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if results is not None:
results.append(response)
if self.once:
@ -70,5 +74,8 @@ class Handler:
continue
except CancelHandler:
break
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
return results

View file

@ -0,0 +1,65 @@
import logging
import typing
log = logging.getLogger('aiogram.Middleware')
class MiddlewareManager:
def __init__(self, dispatcher):
self.dispatcher = dispatcher
self.loop = dispatcher.loop
self.bot = dispatcher.bot
self.storage = dispatcher.storage
self.applications = []
def setup(self, middleware):
"""
Setup middleware
:param middleware:
:return:
"""
assert isinstance(middleware, BaseMiddleware)
if middleware.is_configured():
raise ValueError('That middleware is already used!')
self.applications.append(middleware)
middleware.setup(self)
log.debug(f"Loaded middleware '{middleware.__class__.__name__}'")
async def trigger(self, action: str, args: typing.Iterable):
"""
Call action to middlewares with args lilt.
:param action:
:param args:
:return:
"""
for app in self.applications:
await app.trigger(action, args)
class BaseMiddleware:
def __init__(self):
self._configured = False
self._manager = None
@property
def manager(self) -> MiddlewareManager:
if self._manager is None:
raise RuntimeError('Middleware is not configured!')
return self._manager
def setup(self, manager):
self._manager = manager
self._configured = True
def is_configured(self):
return self._configured
async def trigger(self, action, args):
handler_name = f"on_{action}"
handler = getattr(self, handler_name, None)
if not handler:
return None
await handler(*args)