mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 01:54:53 +00:00
The middlewares is back in the new interpretation + Small refactoring.
This commit is contained in:
parent
4ee54ea4e6
commit
6a64573e44
4 changed files with 141 additions and 78 deletions
|
|
@ -1,12 +1,14 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import time
|
||||
import typing
|
||||
|
||||
from aiogram.dispatcher.middlewares import MiddlewareManager
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
||||
generate_default_filters
|
||||
from .handler import Handler
|
||||
from .handler import CancelHandler, Handler, SkipHandler
|
||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
|
||||
from .webhook import BaseResponse
|
||||
from ..bot import Bot
|
||||
|
|
@ -52,21 +54,22 @@ class Dispatcher:
|
|||
|
||||
self.last_update_id = 0
|
||||
|
||||
self.updates_handler = Handler(self)
|
||||
self.message_handlers = Handler(self)
|
||||
self.edited_message_handlers = Handler(self)
|
||||
self.channel_post_handlers = Handler(self)
|
||||
self.edited_channel_post_handlers = Handler(self)
|
||||
self.inline_query_handlers = Handler(self)
|
||||
self.chosen_inline_result_handlers = Handler(self)
|
||||
self.callback_query_handlers = Handler(self)
|
||||
self.shipping_query_handlers = Handler(self)
|
||||
self.pre_checkout_query_handlers = Handler(self)
|
||||
self.updates_handler = Handler(self, middleware_key='update')
|
||||
self.message_handlers = Handler(self, middleware_key='message')
|
||||
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
|
||||
self.channel_post_handlers = Handler(self, middleware_key='channel_post')
|
||||
self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
|
||||
self.inline_query_handlers = Handler(self, middleware_key='inline_query')
|
||||
self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
|
||||
self.callback_query_handlers = Handler(self, middleware_key='callback_query')
|
||||
self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
|
||||
self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
|
||||
self.errors_handlers = Handler(self, once=False, middleware_key='error')
|
||||
|
||||
self.middleware = MiddlewareManager(self)
|
||||
|
||||
self.updates_handler.register(self.process_update)
|
||||
|
||||
self.errors_handlers = Handler(self, once=False)
|
||||
|
||||
self._polling = False
|
||||
|
||||
def __del__(self):
|
||||
|
|
@ -111,7 +114,7 @@ class Dispatcher:
|
|||
"""
|
||||
tasks = []
|
||||
for update in updates:
|
||||
tasks.append(self.process_update(update))
|
||||
tasks.append(self.updates_handler.notify(update))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def process_update(self, update):
|
||||
|
|
@ -124,69 +127,58 @@ class Dispatcher:
|
|||
start = time.time()
|
||||
success = True
|
||||
|
||||
self.last_update_id = update.update_id
|
||||
context.set_value(UPDATE_OBJECT, update)
|
||||
try:
|
||||
self.last_update_id = update.update_id
|
||||
has_context = context.check_configured()
|
||||
if has_context:
|
||||
context.set_value(UPDATE_OBJECT, update)
|
||||
if update.message:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id)
|
||||
context.update_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id)
|
||||
context.update_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id,
|
||||
state=state)
|
||||
return await self.message_handlers.notify(update.message)
|
||||
if update.edited_message:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id)
|
||||
context.update_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id)
|
||||
context.update_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id,
|
||||
state=state)
|
||||
return await self.edited_message_handlers.notify(update.edited_message)
|
||||
if update.channel_post:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
||||
context.update_state(chat=update.channel_post.chat.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
||||
context.update_state(chat=update.channel_post.chat.id,
|
||||
state=state)
|
||||
return await self.channel_post_handlers.notify(update.channel_post)
|
||||
if update.edited_channel_post:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
||||
context.update_state(chat=update.edited_channel_post.chat.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
||||
context.update_state(chat=update.edited_channel_post.chat.id,
|
||||
state=state)
|
||||
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
||||
if update.inline_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
||||
context.update_state(user=update.inline_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
||||
context.update_state(user=update.inline_query.from_user.id,
|
||||
state=state)
|
||||
return await self.inline_query_handlers.notify(update.inline_query)
|
||||
if update.chosen_inline_result:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
||||
context.update_state(user=update.chosen_inline_result.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
||||
context.update_state(user=update.chosen_inline_result.from_user.id,
|
||||
state=state)
|
||||
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
||||
if update.callback_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
|
||||
user=update.callback_query.from_user.id)
|
||||
context.update_state(user=update.callback_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
|
||||
user=update.callback_query.from_user.id)
|
||||
context.update_state(user=update.callback_query.from_user.id,
|
||||
state=state)
|
||||
return await self.callback_query_handlers.notify(update.callback_query)
|
||||
if update.shipping_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
||||
context.update_state(user=update.shipping_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
||||
context.update_state(user=update.shipping_query.from_user.id,
|
||||
state=state)
|
||||
return await self.shipping_query_handlers.notify(update.shipping_query)
|
||||
if update.pre_checkout_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
||||
context.update_state(user=update.pre_checkout_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
||||
context.update_state(user=update.pre_checkout_query.from_user.id,
|
||||
state=state)
|
||||
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
||||
except Exception as e:
|
||||
success = False
|
||||
|
|
@ -276,8 +268,8 @@ class Dispatcher:
|
|||
:param updates: list of updates.
|
||||
"""
|
||||
need_to_call = []
|
||||
for response in await self.process_updates(updates):
|
||||
for response in response:
|
||||
for responses in itertools.chain.from_iterable(await self.process_updates(updates)):
|
||||
for response in responses:
|
||||
if not isinstance(response, BaseResponse):
|
||||
continue
|
||||
need_to_call.append(response.execute_response(self.bot))
|
||||
|
|
@ -903,7 +895,8 @@ class Dispatcher:
|
|||
"""
|
||||
|
||||
def decorator(callback):
|
||||
self.register_errors_handler(callback, func=func, exception=exception)
|
||||
self.register_errors_handler(self._wrap_async_task(callback, run_task),
|
||||
func=func, exception=exception)
|
||||
return callback
|
||||
|
||||
return decorator
|
||||
|
|
@ -948,7 +941,6 @@ class Dispatcher:
|
|||
:return: bool
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
print(self.storage)
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if no_error is None:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from ..utils.helper import Helper, HelperMode, Item
|
|||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
async def check_filter(filter_, args, kwargs):
|
||||
async def check_filter(filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
|
|
@ -22,23 +22,22 @@ async def check_filter(filter_, args, kwargs):
|
|||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
||||
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
||||
return await filter_(*args, **kwargs)
|
||||
return await filter_(*args)
|
||||
else:
|
||||
return filter_(*args, **kwargs)
|
||||
return filter_(*args)
|
||||
|
||||
|
||||
async def check_filters(filters, args, kwargs):
|
||||
async def check_filters(filters, args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param filters:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(filter_, args, kwargs)
|
||||
f = await check_filter(filter_, args)
|
||||
if not f:
|
||||
return False
|
||||
return True
|
||||
|
|
@ -76,8 +75,8 @@ class AnyFilter(AsyncFilter):
|
|||
def __init__(self, *filters: callable):
|
||||
self.filters = filters
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
f = (check_filter(filter_, args, kwargs) for filter_ in self.filters)
|
||||
async def check(self, *args):
|
||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||
return any(await asyncio.gather(*f))
|
||||
|
||||
|
||||
|
|
@ -88,8 +87,8 @@ class NotFilter(AsyncFilter):
|
|||
def __init__(self, filter_: callable):
|
||||
self.filter = filter_
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
return not await check_filter(self.filter, args, kwargs)
|
||||
async def check(self, *args):
|
||||
return not await check_filter(self.filter, args)
|
||||
|
||||
|
||||
class CommandsFilter(AsyncFilter):
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ class CancelHandler(BaseException):
|
|||
|
||||
|
||||
class Handler:
|
||||
def __init__(self, dispatcher, once=True):
|
||||
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||
self.dispatcher = dispatcher
|
||||
self.once = once
|
||||
|
||||
self.handlers = []
|
||||
self.middleware_key = middleware_key
|
||||
|
||||
def register(self, handler, filters=None, index=None):
|
||||
"""
|
||||
|
|
@ -48,20 +49,23 @@ class Handler:
|
|||
return True
|
||||
raise ValueError('This handler is not registered!')
|
||||
|
||||
async def notify(self, *args, **kwargs):
|
||||
async def notify(self, *args):
|
||||
"""
|
||||
Notify handlers
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
results = []
|
||||
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
|
||||
for filters, handler in self.handlers:
|
||||
if await check_filters(filters, args, kwargs):
|
||||
if await check_filters(filters, args):
|
||||
try:
|
||||
response = await handler(*args, **kwargs)
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||
response = await handler(*args)
|
||||
if results is not None:
|
||||
results.append(response)
|
||||
if self.once:
|
||||
|
|
@ -70,5 +74,8 @@ class Handler:
|
|||
continue
|
||||
except CancelHandler:
|
||||
break
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||
args + (results,))
|
||||
|
||||
return results
|
||||
|
|
|
|||
65
aiogram/dispatcher/middlewares.py
Normal file
65
aiogram/dispatcher/middlewares.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import logging
|
||||
import typing
|
||||
|
||||
log = logging.getLogger('aiogram.Middleware')
|
||||
|
||||
|
||||
class MiddlewareManager:
|
||||
def __init__(self, dispatcher):
|
||||
self.dispatcher = dispatcher
|
||||
self.loop = dispatcher.loop
|
||||
self.bot = dispatcher.bot
|
||||
self.storage = dispatcher.storage
|
||||
self.applications = []
|
||||
|
||||
def setup(self, middleware):
|
||||
"""
|
||||
Setup middleware
|
||||
|
||||
:param middleware:
|
||||
:return:
|
||||
"""
|
||||
assert isinstance(middleware, BaseMiddleware)
|
||||
if middleware.is_configured():
|
||||
raise ValueError('That middleware is already used!')
|
||||
|
||||
self.applications.append(middleware)
|
||||
middleware.setup(self)
|
||||
log.debug(f"Loaded middleware '{middleware.__class__.__name__}'")
|
||||
|
||||
async def trigger(self, action: str, args: typing.Iterable):
|
||||
"""
|
||||
Call action to middlewares with args lilt.
|
||||
|
||||
:param action:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
for app in self.applications:
|
||||
await app.trigger(action, args)
|
||||
|
||||
|
||||
class BaseMiddleware:
|
||||
def __init__(self):
|
||||
self._configured = False
|
||||
self._manager = None
|
||||
|
||||
@property
|
||||
def manager(self) -> MiddlewareManager:
|
||||
if self._manager is None:
|
||||
raise RuntimeError('Middleware is not configured!')
|
||||
return self._manager
|
||||
|
||||
def setup(self, manager):
|
||||
self._manager = manager
|
||||
self._configured = True
|
||||
|
||||
def is_configured(self):
|
||||
return self._configured
|
||||
|
||||
async def trigger(self, action, args):
|
||||
handler_name = f"on_{action}"
|
||||
handler = getattr(self, handler_name, None)
|
||||
if not handler:
|
||||
return None
|
||||
await handler(*args)
|
||||
Loading…
Add table
Add a link
Reference in a new issue