mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
The middlewares is back in the new interpretation + Small refactoring.
This commit is contained in:
parent
4ee54ea4e6
commit
6a64573e44
4 changed files with 141 additions and 78 deletions
|
|
@ -1,12 +1,14 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
from aiogram.dispatcher.middlewares import MiddlewareManager
|
||||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
||||||
generate_default_filters
|
generate_default_filters
|
||||||
from .handler import Handler
|
from .handler import CancelHandler, Handler, SkipHandler
|
||||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
|
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
|
||||||
from .webhook import BaseResponse
|
from .webhook import BaseResponse
|
||||||
from ..bot import Bot
|
from ..bot import Bot
|
||||||
|
|
@ -52,21 +54,22 @@ class Dispatcher:
|
||||||
|
|
||||||
self.last_update_id = 0
|
self.last_update_id = 0
|
||||||
|
|
||||||
self.updates_handler = Handler(self)
|
self.updates_handler = Handler(self, middleware_key='update')
|
||||||
self.message_handlers = Handler(self)
|
self.message_handlers = Handler(self, middleware_key='message')
|
||||||
self.edited_message_handlers = Handler(self)
|
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
|
||||||
self.channel_post_handlers = Handler(self)
|
self.channel_post_handlers = Handler(self, middleware_key='channel_post')
|
||||||
self.edited_channel_post_handlers = Handler(self)
|
self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
|
||||||
self.inline_query_handlers = Handler(self)
|
self.inline_query_handlers = Handler(self, middleware_key='inline_query')
|
||||||
self.chosen_inline_result_handlers = Handler(self)
|
self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
|
||||||
self.callback_query_handlers = Handler(self)
|
self.callback_query_handlers = Handler(self, middleware_key='callback_query')
|
||||||
self.shipping_query_handlers = Handler(self)
|
self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
|
||||||
self.pre_checkout_query_handlers = Handler(self)
|
self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
|
||||||
|
self.errors_handlers = Handler(self, once=False, middleware_key='error')
|
||||||
|
|
||||||
|
self.middleware = MiddlewareManager(self)
|
||||||
|
|
||||||
self.updates_handler.register(self.process_update)
|
self.updates_handler.register(self.process_update)
|
||||||
|
|
||||||
self.errors_handlers = Handler(self, once=False)
|
|
||||||
|
|
||||||
self._polling = False
|
self._polling = False
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|
@ -111,7 +114,7 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
tasks = []
|
tasks = []
|
||||||
for update in updates:
|
for update in updates:
|
||||||
tasks.append(self.process_update(update))
|
tasks.append(self.updates_handler.notify(update))
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def process_update(self, update):
|
async def process_update(self, update):
|
||||||
|
|
@ -124,69 +127,58 @@ class Dispatcher:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
|
self.last_update_id = update.update_id
|
||||||
|
context.set_value(UPDATE_OBJECT, update)
|
||||||
try:
|
try:
|
||||||
self.last_update_id = update.update_id
|
|
||||||
has_context = context.check_configured()
|
|
||||||
if has_context:
|
|
||||||
context.set_value(UPDATE_OBJECT, update)
|
|
||||||
if update.message:
|
if update.message:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.message.chat.id,
|
||||||
state = await self.storage.get_state(chat=update.message.chat.id,
|
user=update.message.from_user.id)
|
||||||
user=update.message.from_user.id)
|
context.update_state(chat=update.message.chat.id,
|
||||||
context.update_state(chat=update.message.chat.id,
|
user=update.message.from_user.id,
|
||||||
user=update.message.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.message_handlers.notify(update.message)
|
return await self.message_handlers.notify(update.message)
|
||||||
if update.edited_message:
|
if update.edited_message:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
||||||
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
user=update.edited_message.from_user.id)
|
||||||
user=update.edited_message.from_user.id)
|
context.update_state(chat=update.edited_message.chat.id,
|
||||||
context.update_state(chat=update.edited_message.chat.id,
|
user=update.edited_message.from_user.id,
|
||||||
user=update.edited_message.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.edited_message_handlers.notify(update.edited_message)
|
return await self.edited_message_handlers.notify(update.edited_message)
|
||||||
if update.channel_post:
|
if update.channel_post:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
||||||
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
context.update_state(chat=update.channel_post.chat.id,
|
||||||
context.update_state(chat=update.channel_post.chat.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.channel_post_handlers.notify(update.channel_post)
|
return await self.channel_post_handlers.notify(update.channel_post)
|
||||||
if update.edited_channel_post:
|
if update.edited_channel_post:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
||||||
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
context.update_state(chat=update.edited_channel_post.chat.id,
|
||||||
context.update_state(chat=update.edited_channel_post.chat.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
||||||
if update.inline_query:
|
if update.inline_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
context.update_state(user=update.inline_query.from_user.id,
|
||||||
context.update_state(user=update.inline_query.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.inline_query_handlers.notify(update.inline_query)
|
return await self.inline_query_handlers.notify(update.inline_query)
|
||||||
if update.chosen_inline_result:
|
if update.chosen_inline_result:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
context.update_state(user=update.chosen_inline_result.from_user.id,
|
||||||
context.update_state(user=update.chosen_inline_result.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
||||||
if update.callback_query:
|
if update.callback_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
|
||||||
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
|
user=update.callback_query.from_user.id)
|
||||||
user=update.callback_query.from_user.id)
|
context.update_state(user=update.callback_query.from_user.id,
|
||||||
context.update_state(user=update.callback_query.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.callback_query_handlers.notify(update.callback_query)
|
return await self.callback_query_handlers.notify(update.callback_query)
|
||||||
if update.shipping_query:
|
if update.shipping_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
context.update_state(user=update.shipping_query.from_user.id,
|
||||||
context.update_state(user=update.shipping_query.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.shipping_query_handlers.notify(update.shipping_query)
|
return await self.shipping_query_handlers.notify(update.shipping_query)
|
||||||
if update.pre_checkout_query:
|
if update.pre_checkout_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
context.update_state(user=update.pre_checkout_query.from_user.id,
|
||||||
context.update_state(user=update.pre_checkout_query.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
success = False
|
success = False
|
||||||
|
|
@ -276,8 +268,8 @@ class Dispatcher:
|
||||||
:param updates: list of updates.
|
:param updates: list of updates.
|
||||||
"""
|
"""
|
||||||
need_to_call = []
|
need_to_call = []
|
||||||
for response in await self.process_updates(updates):
|
for responses in itertools.chain.from_iterable(await self.process_updates(updates)):
|
||||||
for response in response:
|
for response in responses:
|
||||||
if not isinstance(response, BaseResponse):
|
if not isinstance(response, BaseResponse):
|
||||||
continue
|
continue
|
||||||
need_to_call.append(response.execute_response(self.bot))
|
need_to_call.append(response.execute_response(self.bot))
|
||||||
|
|
@ -903,7 +895,8 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(callback):
|
def decorator(callback):
|
||||||
self.register_errors_handler(callback, func=func, exception=exception)
|
self.register_errors_handler(self._wrap_async_task(callback, run_task),
|
||||||
|
func=func, exception=exception)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
@ -948,7 +941,6 @@ class Dispatcher:
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.storage.has_bucket():
|
if not self.storage.has_bucket():
|
||||||
print(self.storage)
|
|
||||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
if no_error is None:
|
if no_error is None:
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from ..utils.helper import Helper, HelperMode, Item
|
||||||
USER_STATE = 'USER_STATE'
|
USER_STATE = 'USER_STATE'
|
||||||
|
|
||||||
|
|
||||||
async def check_filter(filter_, args, kwargs):
|
async def check_filter(filter_, args):
|
||||||
"""
|
"""
|
||||||
Helper for executing filter
|
Helper for executing filter
|
||||||
|
|
||||||
|
|
@ -22,23 +22,22 @@ async def check_filter(filter_, args, kwargs):
|
||||||
raise TypeError('Filter must be callable and/or awaitable!')
|
raise TypeError('Filter must be callable and/or awaitable!')
|
||||||
|
|
||||||
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
||||||
return await filter_(*args, **kwargs)
|
return await filter_(*args)
|
||||||
else:
|
else:
|
||||||
return filter_(*args, **kwargs)
|
return filter_(*args)
|
||||||
|
|
||||||
|
|
||||||
async def check_filters(filters, args, kwargs):
|
async def check_filters(filters, args):
|
||||||
"""
|
"""
|
||||||
Check list of filters
|
Check list of filters
|
||||||
|
|
||||||
:param filters:
|
:param filters:
|
||||||
:param args:
|
:param args:
|
||||||
:param kwargs:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if filters is not None:
|
if filters is not None:
|
||||||
for filter_ in filters:
|
for filter_ in filters:
|
||||||
f = await check_filter(filter_, args, kwargs)
|
f = await check_filter(filter_, args)
|
||||||
if not f:
|
if not f:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
@ -76,8 +75,8 @@ class AnyFilter(AsyncFilter):
|
||||||
def __init__(self, *filters: callable):
|
def __init__(self, *filters: callable):
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
|
|
||||||
async def check(self, *args, **kwargs):
|
async def check(self, *args):
|
||||||
f = (check_filter(filter_, args, kwargs) for filter_ in self.filters)
|
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||||
return any(await asyncio.gather(*f))
|
return any(await asyncio.gather(*f))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -88,8 +87,8 @@ class NotFilter(AsyncFilter):
|
||||||
def __init__(self, filter_: callable):
|
def __init__(self, filter_: callable):
|
||||||
self.filter = filter_
|
self.filter = filter_
|
||||||
|
|
||||||
async def check(self, *args, **kwargs):
|
async def check(self, *args):
|
||||||
return not await check_filter(self.filter, args, kwargs)
|
return not await check_filter(self.filter, args)
|
||||||
|
|
||||||
|
|
||||||
class CommandsFilter(AsyncFilter):
|
class CommandsFilter(AsyncFilter):
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,12 @@ class CancelHandler(BaseException):
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
def __init__(self, dispatcher, once=True):
|
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
self.once = once
|
self.once = once
|
||||||
|
|
||||||
self.handlers = []
|
self.handlers = []
|
||||||
|
self.middleware_key = middleware_key
|
||||||
|
|
||||||
def register(self, handler, filters=None, index=None):
|
def register(self, handler, filters=None, index=None):
|
||||||
"""
|
"""
|
||||||
|
|
@ -48,20 +49,23 @@ class Handler:
|
||||||
return True
|
return True
|
||||||
raise ValueError('This handler is not registered!')
|
raise ValueError('This handler is not registered!')
|
||||||
|
|
||||||
async def notify(self, *args, **kwargs):
|
async def notify(self, *args):
|
||||||
"""
|
"""
|
||||||
Notify handlers
|
Notify handlers
|
||||||
|
|
||||||
:param args:
|
:param args:
|
||||||
:param kwargs:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
if self.middleware_key:
|
||||||
|
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
|
||||||
for filters, handler in self.handlers:
|
for filters, handler in self.handlers:
|
||||||
if await check_filters(filters, args, kwargs):
|
if await check_filters(filters, args):
|
||||||
try:
|
try:
|
||||||
response = await handler(*args, **kwargs)
|
if self.middleware_key:
|
||||||
|
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||||
|
response = await handler(*args)
|
||||||
if results is not None:
|
if results is not None:
|
||||||
results.append(response)
|
results.append(response)
|
||||||
if self.once:
|
if self.once:
|
||||||
|
|
@ -70,5 +74,8 @@ class Handler:
|
||||||
continue
|
continue
|
||||||
except CancelHandler:
|
except CancelHandler:
|
||||||
break
|
break
|
||||||
|
if self.middleware_key:
|
||||||
|
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||||
|
args + (results,))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
65
aiogram/dispatcher/middlewares.py
Normal file
65
aiogram/dispatcher/middlewares.py
Normal file
|
|
@ -0,0 +1,65 @@
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
|
||||||
|
log = logging.getLogger('aiogram.Middleware')
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareManager:
|
||||||
|
def __init__(self, dispatcher):
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
self.loop = dispatcher.loop
|
||||||
|
self.bot = dispatcher.bot
|
||||||
|
self.storage = dispatcher.storage
|
||||||
|
self.applications = []
|
||||||
|
|
||||||
|
def setup(self, middleware):
|
||||||
|
"""
|
||||||
|
Setup middleware
|
||||||
|
|
||||||
|
:param middleware:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert isinstance(middleware, BaseMiddleware)
|
||||||
|
if middleware.is_configured():
|
||||||
|
raise ValueError('That middleware is already used!')
|
||||||
|
|
||||||
|
self.applications.append(middleware)
|
||||||
|
middleware.setup(self)
|
||||||
|
log.debug(f"Loaded middleware '{middleware.__class__.__name__}'")
|
||||||
|
|
||||||
|
async def trigger(self, action: str, args: typing.Iterable):
|
||||||
|
"""
|
||||||
|
Call action to middlewares with args lilt.
|
||||||
|
|
||||||
|
:param action:
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for app in self.applications:
|
||||||
|
await app.trigger(action, args)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMiddleware:
|
||||||
|
def __init__(self):
|
||||||
|
self._configured = False
|
||||||
|
self._manager = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def manager(self) -> MiddlewareManager:
|
||||||
|
if self._manager is None:
|
||||||
|
raise RuntimeError('Middleware is not configured!')
|
||||||
|
return self._manager
|
||||||
|
|
||||||
|
def setup(self, manager):
|
||||||
|
self._manager = manager
|
||||||
|
self._configured = True
|
||||||
|
|
||||||
|
def is_configured(self):
|
||||||
|
return self._configured
|
||||||
|
|
||||||
|
async def trigger(self, action, args):
|
||||||
|
handler_name = f"on_{action}"
|
||||||
|
handler = getattr(self, handler_name, None)
|
||||||
|
if not handler:
|
||||||
|
return None
|
||||||
|
await handler(*args)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue