diff --git a/aiogram/__init__.py b/aiogram/__init__.py index 8975399c..816a1f57 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -3,6 +3,7 @@ from .api.client import session from .api.client.bot import Bot from .dispatcher import filters, handler from .dispatcher.dispatcher import Dispatcher +from .dispatcher.middlewares.base import BaseMiddleware from .dispatcher.router import Router try: @@ -22,6 +23,7 @@ __all__ = ( "session", "Dispatcher", "Router", + "BaseMiddleware", "filters", "handler", ) diff --git a/aiogram/api/types/message.py b/aiogram/api/types/message.py index 19d2b036..32894498 100644 --- a/aiogram/api/types/message.py +++ b/aiogram/api/types/message.py @@ -240,6 +240,8 @@ class Message(TelegramObject): return ContentType.PASSPORT_DATA if self.poll: return ContentType.POLL + if self.dice: + return ContentType.DICE return ContentType.UNKNOWN diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index 93f4aac6..756d57f2 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -1,21 +1,12 @@ from __future__ import annotations from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Callable, - Dict, - Generator, - List, - Optional, - Type, -) +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, Generator, List, Type from pydantic import ValidationError from ..filters.base import BaseFilter +from ..middlewares.types import MiddlewareStep, UpdateType from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType if TYPE_CHECKING: # pragma: no cover @@ -95,10 +86,8 @@ class TelegramEventObserver(EventObserver): """ registry: List[Type[BaseFilter]] = [] - router: Optional[Router] = self.router - while router: + for router in self.router.chain: observer = router.observers[self.event_name] - router = router.parent_router for filter_ in observer.filters: if filter_ in registry: @@ -133,6 +122,37 @@ class TelegramEventObserver(EventObserver): return filters + async def trigger_middleware( + self, step: MiddlewareStep, event: UpdateType, data: Dict[str, Any], result: Any = None, + ) -> None: + """ + Trigger middlewares chain + + :param step: + :param event: + :param data: + :param result: + :return: + """ + reverse = step == MiddlewareStep.POST_PROCESS + recursive = self.event_name == "update" or step == MiddlewareStep.PROCESS + + if self.event_name == "update": + routers = self.router.chain + else: + routers = self.router.chain_head + for router in routers: + await router.middleware.trigger( + step=step, + event_name=self.event_name, + event=event, + data=data, + result=result, + reverse=reverse, + ) + if not recursive: + break + def register( self, callback: HandlerType, *filters: FilterType, **bound_filters: Any ) -> HandlerType: @@ -153,12 +173,24 @@ class TelegramEventObserver(EventObserver): Propagate event to handlers and stops propagation on first match. Handler will be called when all its filters is pass. """ + event = args[0] + await self.trigger_middleware(step=MiddlewareStep.PRE_PROCESS, event=event, data=kwargs) for handler in self.handlers: result, data = await handler.check(*args, **kwargs) if result: kwargs.update(data) + await self.trigger_middleware( + step=MiddlewareStep.PROCESS, event=event, data=kwargs + ) try: - yield await handler.call(*args, **kwargs) + response = await handler.call(*args, **kwargs) + await self.trigger_middleware( + step=MiddlewareStep.POST_PROCESS, + event=event, + data=kwargs, + result=response, + ) + yield response except SkipHandler: continue break diff --git a/aiogram/dispatcher/middlewares/__init__.py b/aiogram/dispatcher/middlewares/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/dispatcher/middlewares/abstract.py b/aiogram/dispatcher/middlewares/abstract.py new file mode 100644 index 00000000..eac16534 --- /dev/null +++ b/aiogram/dispatcher/middlewares/abstract.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Optional + +from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType + +if TYPE_CHECKING: # pragma: no cover + from aiogram.dispatcher.middlewares.manager import MiddlewareManager + + +class AbstractMiddleware(ABC): + """ + Abstract class for middleware. + """ + + def __init__(self) -> None: + self._manager: Optional[MiddlewareManager] = None + + @property + def manager(self) -> MiddlewareManager: + """ + Instance of MiddlewareManager + """ + if self._manager is None: + raise RuntimeError("Middleware is not configured!") + return self._manager + + def setup(self, manager: MiddlewareManager, _stack_level: int = 1) -> AbstractMiddleware: + """ + Mark middleware as configured + + :param manager: + :param _stack_level: + :return: + """ + if self.configured: + return manager.setup(self, _stack_level=_stack_level + 1) + + self._manager = manager + return self + + @property + def configured(self) -> bool: + """ + Check middleware is configured + + :return: + """ + return bool(self._manager) + + @abstractmethod + async def trigger( + self, + step: MiddlewareStep, + event_name: str, + event: UpdateType, + data: Dict[str, Any], + result: Any = None, + ) -> Any: # pragma: no cover + pass diff --git a/aiogram/dispatcher/middlewares/base.py b/aiogram/dispatcher/middlewares/base.py new file mode 100644 index 00000000..2ec921b7 --- /dev/null +++ b/aiogram/dispatcher/middlewares/base.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +from aiogram.dispatcher.middlewares.abstract import AbstractMiddleware +from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType + +if TYPE_CHECKING: # pragma: no cover + from aiogram.api.types import ( + CallbackQuery, + ChosenInlineResult, + InlineQuery, + Message, + Poll, + PollAnswer, + PreCheckoutQuery, + ShippingQuery, + Update, + ) + + +class BaseMiddleware(AbstractMiddleware): + """ + Base class for middleware. + + All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message". + """ + + async def trigger( + self, + step: MiddlewareStep, + event_name: str, + event: UpdateType, + data: Dict[str, Any], + result: Any = None, + ) -> Any: + """ + Trigger action. + + :param step: + :param event_name: + :param event: + :param data: + :param result: + :return: + """ + handler_name = f"on_{step.value}_{event_name}" + handler = getattr(self, handler_name, None) + if not handler: + return None + args = (event, result, data) if step == MiddlewareStep.POST_PROCESS else (event, data) + return await handler(*args) + + if TYPE_CHECKING: # pragma: no cover + # ============================================================================================= + # Event that triggers before process + # ============================================================================================= + async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any: + """ + Event that triggers before process update + """ + + async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any: + """ + Event that triggers before process message + """ + + async def on_pre_process_edited_message( + self, edited_message: Message, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process edited_message + """ + + async def on_pre_process_channel_post( + self, channel_post: Message, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process channel_post + """ + + async def on_pre_process_edited_channel_post( + self, edited_channel_post: Message, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process edited_channel_post + """ + + async def on_pre_process_inline_query( + self, inline_query: InlineQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process inline_query + """ + + async def on_pre_process_chosen_inline_result( + self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process chosen_inline_result + """ + + async def on_pre_process_callback_query( + self, callback_query: CallbackQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process callback_query + """ + + async def on_pre_process_shipping_query( + self, shipping_query: ShippingQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process shipping_query + """ + + async def on_pre_process_pre_checkout_query( + self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process pre_checkout_query + """ + + async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: + """ + Event that triggers before process poll + """ + + async def on_pre_process_poll_answer( + self, poll_answer: PollAnswer, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers before process poll_answer + """ + + # ============================================================================================= + # Event that triggers on process after filters. + # ============================================================================================= + async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any: + """ + Event that triggers on process update + """ + + async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any: + """ + Event that triggers on process message + """ + + async def on_process_edited_message( + self, edited_message: Message, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process edited_message + """ + + async def on_process_channel_post( + self, channel_post: Message, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process channel_post + """ + + async def on_process_edited_channel_post( + self, edited_channel_post: Message, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process edited_channel_post + """ + + async def on_process_inline_query( + self, inline_query: InlineQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process inline_query + """ + + async def on_process_chosen_inline_result( + self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process chosen_inline_result + """ + + async def on_process_callback_query( + self, callback_query: CallbackQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process callback_query + """ + + async def on_process_shipping_query( + self, shipping_query: ShippingQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process shipping_query + """ + + async def on_process_pre_checkout_query( + self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process pre_checkout_query + """ + + async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: + """ + Event that triggers on process poll + """ + + async def on_process_poll_answer( + self, poll_answer: PollAnswer, data: Dict[str, Any] + ) -> Any: + """ + Event that triggers on process poll_answer + """ + + # ============================================================================================= + # Event that triggers after process . + # ============================================================================================= + async def on_post_process_update( + self, update: Update, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing update + """ + + async def on_post_process_message( + self, message: Message, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing message + """ + + async def on_post_process_edited_message( + self, edited_message: Message, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing edited_message + """ + + async def on_post_process_channel_post( + self, channel_post: Message, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing channel_post + """ + + async def on_post_process_edited_channel_post( + self, edited_channel_post: Message, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing edited_channel_post + """ + + async def on_post_process_inline_query( + self, inline_query: InlineQuery, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing inline_query + """ + + async def on_post_process_chosen_inline_result( + self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing chosen_inline_result + """ + + async def on_post_process_callback_query( + self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing callback_query + """ + + async def on_post_process_shipping_query( + self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing shipping_query + """ + + async def on_post_process_pre_checkout_query( + self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing pre_checkout_query + """ + + async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any: + """ + Event that triggers after processing poll + """ + + async def on_post_process_poll_answer( + self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing poll_answer + """ diff --git a/aiogram/dispatcher/middlewares/manager.py b/aiogram/dispatcher/middlewares/manager.py new file mode 100644 index 00000000..39a6230d --- /dev/null +++ b/aiogram/dispatcher/middlewares/manager.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List +from warnings import warn + +from .abstract import AbstractMiddleware +from .types import MiddlewareStep, UpdateType + +if TYPE_CHECKING: # pragma: no cover + from aiogram.dispatcher.router import Router + + +class MiddlewareManager: + """ + Middleware manager. + """ + + def __init__(self, router: Router) -> None: + self.router = router + self.middlewares: List[AbstractMiddleware] = [] + + def setup(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware: + """ + Setup middleware + + :param middleware: + :param _stack_level: + :return: + """ + if not isinstance(middleware, AbstractMiddleware): + raise TypeError( + f"`middleware` should be instance of BaseMiddleware, not {type(middleware)}" + ) + if middleware.configured: + if middleware.manager is self: + warn( + f"Middleware {middleware} is already configured for this Router " + "That's mean re-installing of this middleware has no effect.", + category=RuntimeWarning, + stacklevel=_stack_level + 1, + ) + return middleware + raise ValueError( + f"Middleware is already configured for another manager {middleware.manager} " + f"in router {middleware.manager.router}!" + ) + + self.middlewares.append(middleware) + middleware.setup(self) + return middleware + + async def trigger( + self, + step: MiddlewareStep, + event_name: str, + event: UpdateType, + data: Dict[str, Any], + result: Any = None, + reverse: bool = False, + ) -> Any: + """ + Call action to middlewares with args lilt. + """ + middlewares = reversed(self.middlewares) if reverse else self.middlewares + for middleware in middlewares: + await middleware.trigger( + step=step, event_name=event_name, event=event, data=data, result=result + ) + + def __contains__(self, item: AbstractMiddleware) -> bool: + return item in self.middlewares diff --git a/aiogram/dispatcher/middlewares/types.py b/aiogram/dispatcher/middlewares/types.py new file mode 100644 index 00000000..3d1da420 --- /dev/null +++ b/aiogram/dispatcher/middlewares/types.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from enum import Enum +from typing import Union + +from aiogram.api.types import ( + CallbackQuery, + ChosenInlineResult, + InlineQuery, + Message, + Poll, + PollAnswer, + PreCheckoutQuery, + ShippingQuery, + Update, +) + +UpdateType = Union[ + CallbackQuery, + ChosenInlineResult, + InlineQuery, + Message, + Poll, + PollAnswer, + PreCheckoutQuery, + ShippingQuery, + Update, +] + + +class MiddlewareStep(Enum): + PRE_PROCESS = "pre_process" + PROCESS = "process" + POST_PROCESS = "post_process" diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index 699c82c3..888117be 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -1,13 +1,15 @@ from __future__ import annotations import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union from ..api.types import Chat, TelegramObject, Update, User from ..utils.imports import import_module from ..utils.warnings import CodeHasNoEffect from .event.observer import EventObserver, SkipHandler, TelegramEventObserver from .filters import BUILTIN_FILTERS +from .middlewares.abstract import AbstractMiddleware +from .middlewares.manager import MiddlewareManager class Router: @@ -46,6 +48,7 @@ class Router: ) self.poll_handler = TelegramEventObserver(router=self, event_name="poll") self.poll_answer_handler = TelegramEventObserver(router=self, event_name="poll_answer") + self.middleware = MiddlewareManager(router=self) self.startup = EventObserver() self.shutdown = EventObserver() @@ -74,6 +77,36 @@ class Router: for builtin_filter in BUILTIN_FILTERS.get(name, ()): observer.bind_filter(builtin_filter) + @property + def chain_head(self) -> Generator[Router, None, None]: + router: Optional[Router] = self + while router: + yield router + router = router.parent_router + + @property + def chain_tail(self) -> Generator[Router, None, None]: + yield self + for router in self.sub_routers: + yield from router.chain_tail + + @property + def chain(self) -> Generator[Router, None, None]: + yield from self.chain_head + tail = self.chain_tail + next(tail) # Skip self + yield from tail + + def use(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware: + """ + Use middleware + + :param middleware: + :param _stack_level: + :return: + """ + return self.middleware.setup(middleware, _stack_level=_stack_level + 1) + @property def parent_router(self) -> Optional[Router]: return self._parent_router diff --git a/aiogram/loggers.py b/aiogram/loggers.py index 0352c0df..5b5a8eba 100644 --- a/aiogram/loggers.py +++ b/aiogram/loggers.py @@ -1,3 +1,4 @@ import logging dispatcher = logging.getLogger("aiogram.dispatcher") +middlewares = logging.getLogger("aiogram.middlewares") diff --git a/docs/assets/images/basics_middleware.png b/docs/assets/images/basics_middleware.png new file mode 100644 index 00000000..a797fd38 Binary files /dev/null and b/docs/assets/images/basics_middleware.png differ diff --git a/docs/assets/images/middleware_pipeline.png b/docs/assets/images/middleware_pipeline.png new file mode 100644 index 00000000..dcb20d6f Binary files /dev/null and b/docs/assets/images/middleware_pipeline.png differ diff --git a/docs/assets/images/middleware_pipeline_nested.png b/docs/assets/images/middleware_pipeline_nested.png new file mode 100644 index 00000000..f7a6195a Binary files /dev/null and b/docs/assets/images/middleware_pipeline_nested.png differ diff --git a/docs/dispatcher/middlewares/basics.md b/docs/dispatcher/middlewares/basics.md new file mode 100644 index 00000000..973ffe98 --- /dev/null +++ b/docs/dispatcher/middlewares/basics.md @@ -0,0 +1,111 @@ +# Basics + +All middlewares should be made with `BaseMiddleware` (`#!python3 from aiogram import BaseMiddleware`) as base class. + +For example: + +```python3 +class MyMiddleware(BaseMiddleware): ... +``` + +And then use next pattern in naming callback functions in middleware: `on_{step}_{event}` + +Where is: + +- `#!python3 step`: + - `#!python3 pre_process` + - `#!python3 process` + - `#!python3 post_process` +- `#!python3 event`: + - `#!python3 update` + - `#!python3 message` + - `#!python3 edited_message` + - `#!python3 channel_post` + - `#!python3 edited_channel_post` + - `#!python3 inline_query` + - `#!python3 chosen_inline_result` + - `#!python3 callback_query` + - `#!python3 shipping_query` + - `#!python3 pre_checkout_query` + - `#!python3 poll` + - `#!python3 poll_answer` + +## Connecting middleware with router + +Middlewares can be connected with router by next ways: + +1. `#!python3 router.use(MyMiddleware())` (**recommended**) +1. `#!python3 router.middleware.setup(MyMiddleware())` +1. `#!python3 MyMiddleware().setup(router.middleware)` (**not recommended**) + +!!! warning + One instance of middleware **can't** be registered twice in single or many middleware managers + +## The specification of step callbacks + +### Pre-process step + +| Argument | Type | Description | +| --- | --- | --- | +| event name | Any of event type (Update, Message and etc.) | Event | +| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) | + +Returns `#!python3 Any` + +### Process step + +| Argument | Type | Description | +| --- | --- | --- | +| event name | Any of event type (Update, Message and etc.) | Event | +| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) | + +Returns `#!python3 Any` + +### Post-Process step + +| Argument | Type | Description | +| --- | --- | --- | +| event name | Any of event type (Update, Message and etc.) | Event | +| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) | +| `#!python3 result` | `#!python3 Dict[str, Any]` | Response from handlers | + +Returns `#!python3 Any` + +## Full list of available callbacks + +- `#!python3 on_pre_process_update` - will be triggered on **pre process** `#!python3 update` event +- `#!python3 on_process_update` - will be triggered on **process** `#!python3 update` event +- `#!python3 on_post_process_update` - will be triggered on **post process** `#!python3 update` event +- `#!python3 on_pre_process_message` - will be triggered on **pre process** `#!python3 message` event +- `#!python3 on_process_message` - will be triggered on **process** `#!python3 message` event +- `#!python3 on_post_process_message` - will be triggered on **post process** `#!python3 message` event +- `#!python3 on_pre_process_edited_message` - will be triggered on **pre process** `#!python3 edited_message` event +- `#!python3 on_process_edited_message` - will be triggered on **process** `#!python3 edited_message` event +- `#!python3 on_post_process_edited_message` - will be triggered on **post process** `#!python3 edited_message` event +- `#!python3 on_pre_process_channel_post` - will be triggered on **pre process** `#!python3 channel_post` event +- `#!python3 on_process_channel_post` - will be triggered on **process** `#!python3 channel_post` event +- `#!python3 on_post_process_channel_post` - will be triggered on **post process** `#!python3 channel_post` event +- `#!python3 on_pre_process_edited_channel_post` - will be triggered on **pre process** `#!python3 edited_channel_post` event +- `#!python3 on_process_edited_channel_post` - will be triggered on **process** `#!python3 edited_channel_post` event +- `#!python3 on_post_process_edited_channel_post` - will be triggered on **post process** `#!python3 edited_channel_post` event +- `#!python3 on_pre_process_inline_query` - will be triggered on **pre process** `#!python3 inline_query` event +- `#!python3 on_process_inline_query` - will be triggered on **process** `#!python3 inline_query` event +- `#!python3 on_post_process_inline_query` - will be triggered on **post process** `#!python3 inline_query` event +- `#!python3 on_pre_process_chosen_inline_result` - will be triggered on **pre process** `#!python3 chosen_inline_result` event +- `#!python3 on_process_chosen_inline_result` - will be triggered on **process** `#!python3 chosen_inline_result` event +- `#!python3 on_post_process_chosen_inline_result` - will be triggered on **post process** `#!python3 chosen_inline_result` event +- `#!python3 on_pre_process_callback_query` - will be triggered on **pre process** `#!python3 callback_query` event +- `#!python3 on_process_callback_query` - will be triggered on **process** `#!python3 callback_query` event +- `#!python3 on_post_process_callback_query` - will be triggered on **post process** `#!python3 callback_query` event +- `#!python3 on_pre_process_shipping_query` - will be triggered on **pre process** `#!python3 shipping_query` event +- `#!python3 on_process_shipping_query` - will be triggered on **process** `#!python3 shipping_query` event +- `#!python3 on_post_process_shipping_query` - will be triggered on **post process** `#!python3 shipping_query` event +- `#!python3 on_pre_process_pre_checkout_query` - will be triggered on **pre process** `#!python3 pre_checkout_query` event +- `#!python3 on_process_pre_checkout_query` - will be triggered on **process** `#!python3 pre_checkout_query` event +- `#!python3 on_post_process_pre_checkout_query` - will be triggered on **post process** `#!python3 pre_checkout_query` event +- `#!python3 on_pre_process_poll` - will be triggered on **pre process** `#!python3 poll` event +- `#!python3 on_process_poll` - will be triggered on **process** `#!python3 poll` event +- `#!python3 on_post_process_poll` - will be triggered on **post process** `#!python3 poll` event +- `#!python3 on_pre_process_poll_answer` - will be triggered on **pre process** `#!python3 poll_answer` event +- `#!python3 on_process_poll_answer` - will be triggered on **process** `#!python3 poll_answer` event +- `#!python3 on_post_process_poll_answer` - will be triggered on **post process** `#!python3 poll_answer` event diff --git a/docs/dispatcher/middlewares/index.md b/docs/dispatcher/middlewares/index.md new file mode 100644 index 00000000..114646d6 --- /dev/null +++ b/docs/dispatcher/middlewares/index.md @@ -0,0 +1,65 @@ +# Overview + +**aiogram**'s provides powerful mechanism for customizing event handlers via middlewares. + +Middlewares in bot framework seems like Middlewares mechanism in powerful web-frameworks +(like [aiohttp](https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-middlewares), +[fastapi](https://fastapi.tiangolo.com/tutorial/middleware/), +[Django](https://docs.djangoproject.com/en/3.0/topics/http/middleware/) or etc.) +with small difference - here is implemented many layers of processing +(named as [pipeline](#event-pipeline)). + +!!! info + Middleware is function that triggered on every event received from + Telegram Bot API in many points on processing pipeline. + +## Base theory + +As many books and other literature in internet says: +> Middleware is reusable software that leverages patterns and frameworks to bridge +>the gap between the functional requirements of applications and the underlying operating systems, +> network protocol stacks, and databases. + +Middleware can modify, extend or reject processing event before-, +on- or after- processing of that event. + +[![middlewares](../../assets/images/basics_middleware.png)](../../assets/images/basics_middleware.png) + +_(Click on image to zoom it)_ + +## Event pipeline + +As described below middleware an interact with event in many stages of pipeline. + +Simple workflow: + +1. Dispatcher receive an [Update](../../api/types/update.md) +1. Call **pre-process** update middleware in all routers tree +1. Filter Update over handlers +1. Call **process** update middleware in all routers tree +1. Router detects event type (Message, Callback query, etc.) +1. Router triggers **pre-process** middleware of specific type +1. Pass event over [filters](../filters/index.md) to detect specific handler +1. Call **process** middleware for specific type (only when handler for this event exists) +1. *Do magick*. Call handler (Read more [Event observers](../router.md#event-observers)) +1. Call **post-process** middleware +1. Call **post-process** update middleware in all routers tree +1. Emit response into webhook (when it needed) + +### Pipeline in pictures: + +#### Simple pipeline + +[![middlewares](../../assets/images/middleware_pipeline.png)](../../assets/images/middleware_pipeline.png) + +_(Click on image to zoom it)_ + +#### Nested routers pipeline + +[![middlewares](../../assets/images/middleware_pipeline_nested.png)](../../assets/images/middleware_pipeline_nested.png) + +_(Click on image to zoom it)_ + +## Read more + +- [Middleware Basics](basics.md) diff --git a/docs/index.md b/docs/index.md index 6961a181..57f6fa9b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -15,19 +15,19 @@ Documentation for version 3.0 [WIP] [^1] ## Features -- Asynchronous +- Asynchronous ([asyncio docs](https://docs.python.org/3/library/asyncio.html), [PEP-492](https://www.python.org/dev/peps/pep-0492/)) - [Supports Telegram Bot API v{!_api_version.md!}](api/index.md) - [Updates router](dispatcher/index.md) (Blueprints) - Finite State Machine -- Middlewares +- [Middlewares](dispatcher/middlewares/index.md) - [Replies into Webhook](https://core.telegram.org/bots/faq#how-can-i-make-requests-in-response-to-updates) !!! note Before start using **aiogram** is highly recommend to know how to work with [asyncio](https://docs.python.org/3/library/asyncio.html). - + Also if you has questions you can go to our community chats in Telegram: - + - [English language](https://t.me/aiogram) - [Russian language](https://t.me/aiogram_ru) diff --git a/docs/todo.md b/docs/todo.md index 02c99d9a..c06407f3 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -23,8 +23,8 @@ - [x] ContentTypes - [x] Text - [ ] ... - - [ ] Middlewares - - [ ] Engine + - [x] Middlewares + - [x] Engine - [ ] Builtin middlewares - [ ] ... - [ ] Webhook @@ -41,6 +41,7 @@ - [x] Dispatcher - [x] Router - [x] Observers + - [x] Middleware - [ ] Filters - [ ] Utils - [x] Helper diff --git a/mkdocs.yml b/mkdocs.yml index d64deb09..47631d6a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -17,7 +17,7 @@ theme: logo: 'assets/images/logo.png' extra: - version: 3.0.0a2 + version: 3.0.0a3 plugins: - search @@ -249,6 +249,9 @@ nav: - dispatcher/class_based_handlers/poll.md - dispatcher/class_based_handlers/pre_checkout_query.md - dispatcher/class_based_handlers/shipping_query.md + - Middlewares: + - dispatcher/middlewares/index.md + - dispatcher/middlewares/basics.md - todo.md - Build reports: - reports.md diff --git a/tests/test_api/test_types/test_message.py b/tests/test_api/test_types/test_message.py index 4b645e11..1254bd31 100644 --- a/tests/test_api/test_types/test_message.py +++ b/tests/test_api/test_types/test_message.py @@ -7,6 +7,7 @@ from aiogram.api.methods import ( SendAnimation, SendAudio, SendContact, + SendDice, SendDocument, SendGame, SendInvoice, @@ -26,6 +27,7 @@ from aiogram.api.types import ( Audio, Chat, Contact, + Dice, Document, EncryptedCredentials, Game, @@ -391,6 +393,16 @@ class TestMessage: ), ContentType.POLL, ], + [ + Message( + message_id=42, + date=datetime.datetime.now(), + chat=Chat(id=42, type="private"), + dice=Dice(value=6), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + ContentType.DICE, + ], [ Message( message_id=42, @@ -431,6 +443,7 @@ class TestMessage: ["", dict(text="test"), SendMessage], ["photo", dict(photo="photo"), SendPhoto], ["poll", dict(question="Q?", options=[]), SendPoll], + ["dice", dict(), SendDice], ["sticker", dict(sticker="sticker"), SendSticker], ["sticker", dict(sticker="sticker"), SendSticker], [ diff --git a/tests/test_dispatcher/test_middlewares/__init__.py b/tests/test_dispatcher/test_middlewares/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_dispatcher/test_middlewares/test_base.py b/tests/test_dispatcher/test_middlewares/test_base.py new file mode 100644 index 00000000..203028ec --- /dev/null +++ b/tests/test_dispatcher/test_middlewares/test_base.py @@ -0,0 +1,241 @@ +import datetime +from typing import Any, Dict, Type + +import pytest + +from aiogram.api.types import ( + CallbackQuery, + Chat, + ChosenInlineResult, + InlineQuery, + Message, + Poll, + PollAnswer, + PreCheckoutQuery, + ShippingQuery, + Update, + User, +) +from aiogram.dispatcher.middlewares.base import BaseMiddleware +from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType + +try: + from asynctest import CoroutineMock, patch +except ImportError: + from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore + + +class MyMiddleware(BaseMiddleware): + async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any: + return "update" + + async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any: + return "message" + + async def on_pre_process_edited_message( + self, edited_message: Message, data: Dict[str, Any] + ) -> Any: + return "edited_message" + + async def on_pre_process_channel_post( + self, channel_post: Message, data: Dict[str, Any] + ) -> Any: + return "channel_post" + + async def on_pre_process_edited_channel_post( + self, edited_channel_post: Message, data: Dict[str, Any] + ) -> Any: + return "edited_channel_post" + + async def on_pre_process_inline_query( + self, inline_query: InlineQuery, data: Dict[str, Any] + ) -> Any: + return "inline_query" + + async def on_pre_process_chosen_inline_result( + self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] + ) -> Any: + return "chosen_inline_result" + + async def on_pre_process_callback_query( + self, callback_query: CallbackQuery, data: Dict[str, Any] + ) -> Any: + return "callback_query" + + async def on_pre_process_shipping_query( + self, shipping_query: ShippingQuery, data: Dict[str, Any] + ) -> Any: + return "shipping_query" + + async def on_pre_process_pre_checkout_query( + self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] + ) -> Any: + return "pre_checkout_query" + + async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: + return "poll" + + async def on_pre_process_poll_answer( + self, poll_answer: PollAnswer, data: Dict[str, Any] + ) -> Any: + return "poll_answer" + + async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any: + return "update" + + async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any: + return "message" + + async def on_process_edited_message( + self, edited_message: Message, data: Dict[str, Any] + ) -> Any: + return "edited_message" + + async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any: + return "channel_post" + + async def on_process_edited_channel_post( + self, edited_channel_post: Message, data: Dict[str, Any] + ) -> Any: + return "edited_channel_post" + + async def on_process_inline_query( + self, inline_query: InlineQuery, data: Dict[str, Any] + ) -> Any: + return "inline_query" + + async def on_process_chosen_inline_result( + self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] + ) -> Any: + return "chosen_inline_result" + + async def on_process_callback_query( + self, callback_query: CallbackQuery, data: Dict[str, Any] + ) -> Any: + return "callback_query" + + async def on_process_shipping_query( + self, shipping_query: ShippingQuery, data: Dict[str, Any] + ) -> Any: + return "shipping_query" + + async def on_process_pre_checkout_query( + self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] + ) -> Any: + return "pre_checkout_query" + + async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: + return "poll" + + async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any: + return "poll_answer" + + async def on_post_process_update( + self, update: Update, data: Dict[str, Any], result: Any + ) -> Any: + return "update" + + async def on_post_process_message( + self, message: Message, data: Dict[str, Any], result: Any + ) -> Any: + return "message" + + async def on_post_process_edited_message( + self, edited_message: Message, data: Dict[str, Any], result: Any + ) -> Any: + return "edited_message" + + async def on_post_process_channel_post( + self, channel_post: Message, data: Dict[str, Any], result: Any + ) -> Any: + return "channel_post" + + async def on_post_process_edited_channel_post( + self, edited_channel_post: Message, data: Dict[str, Any], result: Any + ) -> Any: + return "edited_channel_post" + + async def on_post_process_inline_query( + self, inline_query: InlineQuery, data: Dict[str, Any], result: Any + ) -> Any: + return "inline_query" + + async def on_post_process_chosen_inline_result( + self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any + ) -> Any: + return "chosen_inline_result" + + async def on_post_process_callback_query( + self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any + ) -> Any: + return "callback_query" + + async def on_post_process_shipping_query( + self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any + ) -> Any: + return "shipping_query" + + async def on_post_process_pre_checkout_query( + self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any + ) -> Any: + return "pre_checkout_query" + + async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any: + return "poll" + + async def on_post_process_poll_answer( + self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any + ) -> Any: + return "poll_answer" + + +UPDATE = Update(update_id=42) +MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private")) +POLL_ANSWER = PollAnswer( + poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0] +) + + +class TestBaseMiddleware: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]] + ) + @pytest.mark.parametrize( + "step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS] + ) + @pytest.mark.parametrize( + "event_name,event", + [["update", UPDATE], ["message", MESSAGE], ["poll_answer", POLL_ANSWER],], + ) + async def test_trigger( + self, + step: MiddlewareStep, + event_name: str, + event: UpdateType, + middleware_cls: Type[BaseMiddleware], + should_be_awaited: bool, + ): + middleware = middleware_cls() + + with patch( + f"tests.test_dispatcher.test_middlewares.test_base." + f"MyMiddleware.on_{step.value}_{event_name}", + new_callable=CoroutineMock, + ) as mocked_call: + response = await middleware.trigger( + step=step, event_name=event_name, event=event, data={} + ) + if should_be_awaited: + mocked_call.assert_awaited() + assert response is not None + else: + mocked_call.assert_not_awaited() + assert response is None + + def test_not_configured(self): + middleware = BaseMiddleware() + assert not middleware.configured + + with pytest.raises(RuntimeError): + manager = middleware.manager diff --git a/tests/test_dispatcher/test_middlewares/test_manager.py b/tests/test_dispatcher/test_middlewares/test_manager.py new file mode 100644 index 00000000..0e23f1b2 --- /dev/null +++ b/tests/test_dispatcher/test_middlewares/test_manager.py @@ -0,0 +1,82 @@ +import pytest + +from aiogram import Router +from aiogram.api.types import Update +from aiogram.dispatcher.middlewares.base import BaseMiddleware +from aiogram.dispatcher.middlewares.manager import MiddlewareManager +from aiogram.dispatcher.middlewares.types import MiddlewareStep + +try: + from asynctest import CoroutineMock, patch +except ImportError: + from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore + + +@pytest.fixture("function") +def router(): + return Router() + + +@pytest.fixture("function") +def manager(router: Router): + return MiddlewareManager(router) + + +class TestManager: + def test_setup(self, manager: MiddlewareManager): + middleware = BaseMiddleware() + returned = manager.setup(middleware) + assert returned is middleware + assert middleware.configured + assert middleware.manager is manager + assert middleware in manager + + @pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware]) + def test_setup_invalid_type(self, manager: MiddlewareManager, obj): + with pytest.raises(TypeError): + assert manager.setup(obj) + + def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router): + middleware = BaseMiddleware() + manager.setup(middleware) + + assert middleware.configured + + new_manager = MiddlewareManager(router) + with pytest.raises(ValueError): + new_manager.setup(middleware) + with pytest.raises(ValueError): + middleware.setup(new_manager) + + def test_configure_twice(self, manager: MiddlewareManager): + middleware = BaseMiddleware() + manager.setup(middleware) + + assert middleware.configured + + with pytest.warns(RuntimeWarning, match="is already configured for this Router"): + manager.setup(middleware) + + with pytest.warns(RuntimeWarning, match="is already configured for this Router"): + middleware.setup(manager) + + @pytest.mark.asyncio + @pytest.mark.parametrize("count", range(5)) + async def test_trigger(self, manager: MiddlewareManager, count: int): + for _ in range(count): + manager.setup(BaseMiddleware()) + + with patch( + "aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger", + new_callable=CoroutineMock, + ) as mocked_call: + await manager.trigger( + step=MiddlewareStep.PROCESS, + event_name="update", + event=Update(update_id=42), + data={}, + result=None, + reverse=True, + ) + + assert mocked_call.await_count == count diff --git a/tests/test_dispatcher/test_router.py b/tests/test_dispatcher/test_router.py index ca66c1ad..eacb8d0c 100644 --- a/tests/test_dispatcher/test_router.py +++ b/tests/test_dispatcher/test_router.py @@ -18,6 +18,7 @@ from aiogram.api.types import ( User, ) from aiogram.dispatcher.event.observer import SkipHandler +from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.router import Router from aiogram.utils.warnings import CodeHasNoEffect @@ -407,3 +408,11 @@ class TestRouter: await router1.emit_shutdown() assert results == [2, 1, 2] + + def test_use(self): + router = Router() + + middleware = router.use(BaseMiddleware()) + assert isinstance(middleware, BaseMiddleware) + assert middleware.configured + assert middleware.manager == router.middleware diff --git a/tests/test_utils/test_markdown.py b/tests/test_utils/test_markdown.py index b9da8f46..792c1bb4 100644 --- a/tests/test_utils/test_markdown.py +++ b/tests/test_utils/test_markdown.py @@ -2,37 +2,54 @@ from typing import Any, Callable, Optional, Tuple import pytest -from aiogram.utils import markdown +from aiogram.utils.markdown import ( + bold, + code, + hbold, + hcode, + hide_link, + hitalic, + hlink, + hpre, + hstrikethrough, + hunderline, + italic, + link, + pre, + strikethrough, + text, + underline, +) class TestMarkdown: @pytest.mark.parametrize( "func,args,sep,result", [ - [markdown.text, ("test", "test"), " ", "test test"], - [markdown.text, ("test", "test"), "\n", "test\ntest"], - [markdown.text, ("test", "test"), None, "test test"], - [markdown.bold, ("test", "test"), " ", "*test test*"], - [markdown.hbold, ("test", "test"), " ", "test test"], - [markdown.italic, ("test", "test"), " ", "_test test_\r"], - [markdown.hitalic, ("test", "test"), " ", "test test"], - [markdown.code, ("test", "test"), " ", "`test test`"], - [markdown.hcode, ("test", "test"), " ", "test test"], - [markdown.pre, ("test", "test"), " ", "```test test```"], - [markdown.hpre, ("test", "test"), " ", "
test test
"], - [markdown.underline, ("test", "test"), " ", "__test test__"], - [markdown.hunderline, ("test", "test"), " ", "test test"], - [markdown.strikethrough, ("test", "test"), " ", "~test test~"], - [markdown.hstrikethrough, ("test", "test"), " ", "test test"], - [markdown.link, ("test", "https://aiogram.dev"), None, "[test](https://aiogram.dev)"], + [text, ("test", "test"), " ", "test test"], + [text, ("test", "test"), "\n", "test\ntest"], + [text, ("test", "test"), None, "test test"], + [bold, ("test", "test"), " ", "*test test*"], + [hbold, ("test", "test"), " ", "test test"], + [italic, ("test", "test"), " ", "_test test_\r"], + [hitalic, ("test", "test"), " ", "test test"], + [code, ("test", "test"), " ", "`test test`"], + [hcode, ("test", "test"), " ", "test test"], + [pre, ("test", "test"), " ", "```test test```"], + [hpre, ("test", "test"), " ", "
test test
"], + [underline, ("test", "test"), " ", "__test test__"], + [hunderline, ("test", "test"), " ", "test test"], + [strikethrough, ("test", "test"), " ", "~test test~"], + [hstrikethrough, ("test", "test"), " ", "test test"], + [link, ("test", "https://aiogram.dev"), None, "[test](https://aiogram.dev)"], [ - markdown.hlink, + hlink, ("test", "https://aiogram.dev"), None, 'test', ], [ - markdown.hide_link, + hide_link, ("https://aiogram.dev",), None, '',