Merge pull request #337 from aiogram/dev-3.x-middlewares

Implement new middlewares
This commit is contained in:
Alex Root Junior 2020-05-26 22:06:21 +03:00 committed by GitHub
commit a627c75bab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 618 additions and 1434 deletions

View file

@ -6,7 +6,6 @@ python := $(py) python
reports_dir := reports reports_dir := reports
.PHONY: help
help: help:
@echo "=======================================================================================" @echo "======================================================================================="
@echo " aiogram build tools " @echo " aiogram build tools "
@ -45,12 +44,10 @@ help:
# Environment # Environment
# ================================================================================================= # =================================================================================================
.PHONY: install
install: install:
$(base_python) -m pip install --user -U poetry $(base_python) -m pip install --user -U poetry
poetry install poetry install
.PHONY: clean
clean: clean:
rm -rf `find . -name __pycache__` rm -rf `find . -name __pycache__`
rm -f `find . -type f -name '*.py[co]' ` rm -f `find . -type f -name '*.py[co]' `
@ -68,65 +65,56 @@ clean:
# Code quality # Code quality
# ================================================================================================= # =================================================================================================
.PHONY: isort
isort: isort:
$(py) isort -rc aiogram tests $(py) isort -rc aiogram tests
.PHONY: black
black: black:
$(py) black aiogram tests $(py) black aiogram tests
.PHONY: flake8
flake8: flake8:
$(py) flake8 aiogram test $(py) flake8 aiogram test
.PHONY: flake8-report
flake8-report: flake8-report:
mkdir -p $(reports_dir)/flake8 mkdir -p $(reports_dir)/flake8
$(py) flake8 --format=html --htmldir=$(reports_dir)/flake8 aiogram test $(py) flake8 --format=html --htmldir=$(reports_dir)/flake8 aiogram test
.PHONY: mypy
mypy: mypy:
$(py) mypy aiogram $(py) mypy aiogram
.PHONY: mypy-report
mypy-report: mypy-report:
$(py) mypy aiogram --html-report $(reports_dir)/typechecking $(py) mypy aiogram --html-report $(reports_dir)/typechecking
.PHONY: lint
lint: isort black flake8 mypy lint: isort black flake8 mypy
# ================================================================================================= # =================================================================================================
# Tests # Tests
# ================================================================================================= # =================================================================================================
.PHONY: test
test: test:
$(py) pytest --cov=aiogram --cov-config .coveragerc tests/ $(py) pytest --cov=aiogram --cov-config .coveragerc tests/
.PHONY: test-coverage
test-coverage: test-coverage:
mkdir -p $(reports_dir)/tests/ mkdir -p $(reports_dir)/tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ $(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/
test-coverage-report:
$(py) coverage html -d $(reports_dir)/coverage $(py) coverage html -d $(reports_dir)/coverage
.PHONY: test-coverage-report test-coverage-view:
test-coverage-report: $(py) coverage html -d $(reports_dir)/coverage
python -c "import webbrowser; webbrowser.open('file://$(shell pwd)/reports/coverage/index.html')" python -c "import webbrowser; webbrowser.open('file://$(shell pwd)/reports/coverage/index.html')"
# ================================================================================================= # =================================================================================================
# Docs # Docs
# ================================================================================================= # =================================================================================================
.PHONY: docs
docs: docs:
$(py) mkdocs build $(py) mkdocs build
.PHONY: docs-serve
docs-serve: docs-serve:
$(py) mkdocs serve $(py) mkdocs serve
.PHONY: docs-copy-reports
docs-copy-reports: docs-copy-reports:
mv $(reports_dir)/* site/reports mv $(reports_dir)/* site/reports
@ -134,9 +122,7 @@ docs-copy-reports:
# Project # Project
# ================================================================================================= # =================================================================================================
.PHONY: build
build: clean flake8-report mypy-report test-coverage docs docs-copy-reports build: clean flake8-report mypy-report test-coverage docs docs-copy-reports
mkdir -p site/simple mkdir -p site/simple
poetry build poetry build
mv dist site/simple/aiogram mv dist site/simple/aiogram

View file

@ -1,3 +1,5 @@
from pkg_resources import get_distribution
from .api import methods, types from .api import methods, types
from .api.client import session from .api.client import session
from .api.client.bot import Bot from .api.client.bot import Bot
@ -28,5 +30,5 @@ __all__ = (
"handler", "handler",
) )
__version__ = "3.0.0a4" __version__ = get_distribution(dist=__package__).version
__api_version__ = "4.8" __api_version__ = "4.8"

View file

@ -11,6 +11,8 @@ from ..api.client.bot import Bot
from ..api.methods import TelegramMethod from ..api.methods import TelegramMethod
from ..api.types import Update, User from ..api.types import Update, User
from ..utils.exceptions import TelegramAPIError from ..utils.exceptions import TelegramAPIError
from .event.bases import NOT_HANDLED
from .middlewares.user_context import UserContextMiddleware
from .router import Router from .router import Router
@ -23,6 +25,9 @@ class Dispatcher(Router):
super(Dispatcher, self).__init__(**kwargs) super(Dispatcher, self).__init__(**kwargs)
self._running_lock = Lock() self._running_lock = Lock()
# Default middleware is needed for contextual features
self.update.outer_middleware(UserContextMiddleware())
@property @property
def parent_router(self) -> None: def parent_router(self) -> None:
""" """
@ -42,9 +47,7 @@ class Dispatcher(Router):
""" """
raise RuntimeError("Dispatcher can not be attached to another Router.") raise RuntimeError("Dispatcher can not be attached to another Router.")
async def feed_update( async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
self, bot: Bot, update: Update, **kwargs: Any
) -> AsyncGenerator[Any, None]:
""" """
Main entry point for incoming updates Main entry point for incoming updates
@ -57,9 +60,9 @@ class Dispatcher(Router):
Bot.set_current(bot) Bot.set_current(bot)
try: try:
async for result in self.update.trigger(update, bot=bot, **kwargs): response = await self.update.trigger(update, bot=bot, **kwargs)
handled = True handled = response is not NOT_HANDLED
yield result return response
finally: finally:
finish_time = loop.time() finish_time = loop.time()
duration = (finish_time - start_time) * 1000 duration = (finish_time - start_time) * 1000
@ -71,9 +74,7 @@ class Dispatcher(Router):
bot.id, bot.id,
) )
async def feed_raw_update( async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
self, bot: Bot, update: Dict[str, Any], **kwargs: Any
) -> AsyncGenerator[Any, None]:
""" """
Main entry point for incoming updates with automatic Dict->Update serializer Main entry point for incoming updates with automatic Dict->Update serializer
@ -82,8 +83,7 @@ class Dispatcher(Router):
:param kwargs: :param kwargs:
""" """
parsed_update = Update(**update) parsed_update = Update(**update)
async for result in self.feed_update(bot=bot, update=parsed_update, **kwargs): return await self.feed_update(bot=bot, update=parsed_update, **kwargs)
yield result
@classmethod @classmethod
async def _listen_updates(cls, bot: Bot) -> AsyncGenerator[Update, None]: async def _listen_updates(cls, bot: Bot) -> AsyncGenerator[Update, None]:
@ -114,7 +114,7 @@ class Dispatcher(Router):
# For debugging here is added logging. # For debugging here is added logging.
loggers.dispatcher.error("Failed to make answer: %s: %s", e.__class__.__name__, e) loggers.dispatcher.error("Failed to make answer: %s: %s", e.__class__.__name__, e)
async def process_update( async def _process_update(
self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any
) -> bool: ) -> bool:
""" """
@ -126,11 +126,13 @@ class Dispatcher(Router):
:param kwargs: contextual data for middlewares, filters and handlers :param kwargs: contextual data for middlewares, filters and handlers
:return: status :return: status
""" """
handled = False
try: try:
async for result in self.feed_update(bot, update, **kwargs): response = await self.feed_update(bot, update, **kwargs)
if call_answer and isinstance(result, TelegramMethod): handled = handled is not NOT_HANDLED
await self._silent_call_request(bot=bot, result=result) if call_answer and isinstance(response, TelegramMethod):
return True await self._silent_call_request(bot=bot, result=response)
return handled
except Exception as e: except Exception as e:
loggers.dispatcher.exception( loggers.dispatcher.exception(
@ -142,8 +144,6 @@ class Dispatcher(Router):
) )
return True # because update was processed but unsuccessful return True # because update was processed but unsuccessful
return False
async def _polling(self, bot: Bot, **kwargs: Any) -> None: async def _polling(self, bot: Bot, **kwargs: Any) -> None:
""" """
Internal polling process Internal polling process
@ -153,16 +153,14 @@ class Dispatcher(Router):
:return: :return:
""" """
async for update in self._listen_updates(bot): async for update in self._listen_updates(bot):
await self.process_update(bot=bot, update=update, **kwargs) await self._process_update(bot=bot, update=update, **kwargs)
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
""" """
The same with `Dispatcher.process_update()` but returns real response instead of bool The same with `Dispatcher.process_update()` but returns real response instead of bool
""" """
try: try:
async for result in self.feed_update(bot, update, **kwargs): return await self.feed_update(bot, update, **kwargs)
return result
except Exception as e: except Exception as e:
loggers.dispatcher.exception( loggers.dispatcher.exception(
"Cause exception while process update id=%d by bot id=%d\n%s: %s", "Cause exception while process update id=%d by bot id=%d\n%s: %s",
@ -196,10 +194,10 @@ class Dispatcher(Router):
def process_response(task: Future[Any]) -> None: def process_response(task: Future[Any]) -> None:
warnings.warn( warnings.warn(
f"Detected slow response into webhook.\n" "Detected slow response into webhook.\n"
f"Telegram is waiting for response only first 60 seconds and then re-send update.\n" "Telegram is waiting for response only first 60 seconds and then re-send update.\n"
f"For preventing this situation response into webhook returned immediately " "For preventing this situation response into webhook returned immediately "
f"and handler is moved to background and still processing update.", "and handler is moved to background and still processing update.",
RuntimeWarning, RuntimeWarning,
) )
try: try:

View file

@ -0,0 +1,29 @@
from __future__ import annotations
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
from unittest.mock import sentinel
from ...api.types import TelegramObject
from ..middlewares.base import BaseMiddleware
NextMiddlewareType = Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]]
MiddlewareType = Union[
BaseMiddleware, Callable[[NextMiddlewareType, TelegramObject, Dict[str, Any]], Awaitable[Any]]
]
NOT_HANDLED = sentinel.NOT_HANDLED
class SkipHandler(Exception):
pass
class CancelHandler(Exception):
pass
def skip(message: Optional[str] = None) -> NoReturn:
"""
Raise an SkipHandler
"""
raise SkipHandler(message or "Event skipped")

View file

@ -0,0 +1,39 @@
from __future__ import annotations
from typing import Any, Callable, List
from .handler import CallbackType, HandlerObject, HandlerType
class EventObserver:
"""
Simple events observer
"""
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType) -> None:
"""
Register callback with filters
"""
self.handlers.append(HandlerObject(callback=callback))
async def trigger(self, *args: Any, **kwargs: Any) -> None:
"""
Propagate event to handlers.
Handler will be called when all its filters is pass.
"""
for handler in self.handlers:
await handler.call(*args, **kwargs)
def __call__(self) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback)
return callback
return wrapper

View file

@ -1,3 +1,5 @@
import asyncio
import contextvars
import inspect import inspect
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial from functools import partial
@ -6,9 +8,9 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type,
from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler from aiogram.dispatcher.handler.base import BaseHandler
CallbackType = Callable[[Any], Awaitable[Any]] CallbackType = Callable[..., Awaitable[Any]]
SyncFilter = Callable[[Any], Any] SyncFilter = Callable[..., Any]
AsyncFilter = Callable[[Any], Awaitable[Any]] AsyncFilter = Callable[..., Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter] FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[FilterType, Type[BaseHandler]] HandlerType = Union[FilterType, Type[BaseHandler]]
@ -40,7 +42,11 @@ class CallableMixin:
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
if self.awaitable: if self.awaitable:
return await wrapped() return await wrapped()
return wrapped()
loop = asyncio.get_event_loop()
context = contextvars.copy_context()
wrapped = partial(context.run, wrapped)
return await loop.run_in_executor(None, wrapped)
@dataclass @dataclass
@ -60,11 +66,11 @@ class HandlerObject(CallableMixin):
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
if not self.filters: if not self.filters:
return True, {} return True, kwargs
for event_filter in self.filters: for event_filter in self.filters:
check = await event_filter.call(*args, **kwargs) check = await event_filter.call(*args, **kwargs)
if not check: if not check:
return False, {} return False, kwargs
if isinstance(check, dict): if isinstance(check, dict):
kwargs.update(check) kwargs.update(check)
return True, kwargs return True, kwargs

View file

@ -1,93 +1,33 @@
from __future__ import annotations from __future__ import annotations
import functools
from itertools import chain from itertools import chain
from typing import ( from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
List,
NoReturn,
Optional,
Type,
)
from pydantic import ValidationError from pydantic import ValidationError
from ...api.types import TelegramObject
from ..filters.base import BaseFilter from ..filters.base import BaseFilter
from ..middlewares.types import MiddlewareStep, UpdateType from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
class SkipHandler(Exception): class TelegramEventObserver:
pass
class CancelHandler(Exception):
pass
def skip(message: Optional[str] = None) -> NoReturn:
"""
Raise an SkipHandler
"""
raise SkipHandler(message or "Event skipped")
class EventObserver:
"""
Base events observer
"""
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType) -> HandlerType:
"""
Register callback with filters
"""
self.handlers.append(HandlerObject(callback=callback))
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
"""
Propagate event to handlers.
Handler will be called when all its filters is pass.
"""
for handler in self.handlers:
try:
yield await handler.call(*args, **kwargs)
except SkipHandler:
continue
def __call__(self) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback)
return callback
return wrapper
class TelegramEventObserver(EventObserver):
""" """
Event observer for Telegram events Event observer for Telegram events
""" """
def __init__(self, router: Router, event_name: str) -> None: def __init__(self, router: Router, event_name: str) -> None:
super().__init__()
self.router: Router = router self.router: Router = router
self.event_name: str = event_name self.event_name: str = event_name
self.handlers: List[HandlerObject] = []
self.filters: List[Type[BaseFilter]] = [] self.filters: List[Type[BaseFilter]] = []
self.outer_middlewares: List[MiddlewareType] = []
self.middlewares: List[MiddlewareType] = []
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None: def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
""" """
@ -144,37 +84,6 @@ class TelegramEventObserver(EventObserver):
return filters return filters
async def trigger_middleware(
self, step: MiddlewareStep, event: UpdateType, data: Dict[str, Any], result: Any = None,
) -> None:
"""
Trigger middlewares chain
:param step:
:param event:
:param data:
:param result:
:return:
"""
reverse = step == MiddlewareStep.POST_PROCESS
recursive = self.event_name == "update" or step == MiddlewareStep.PROCESS
if self.event_name == "update":
routers = self.router.chain
else:
routers = self.router.chain_head
for router in routers:
await router.middleware.trigger(
step=step,
event_name=self.event_name,
event=event,
data=data,
result=result,
reverse=reverse,
)
if not recursive:
break
def register( def register(
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
) -> HandlerType: ) -> HandlerType:
@ -190,32 +99,39 @@ class TelegramEventObserver(EventObserver):
) )
return callback return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: @classmethod
def _wrap_middleware(
cls, middlewares: List[MiddlewareType], handler: HandlerType
) -> NextMiddlewareType:
@functools.wraps(handler)
def mapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
return handler(event, **kwargs)
middleware = mapper
for m in reversed(middlewares):
middleware = functools.partial(m, middleware)
return middleware
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
""" """
Propagate event to handlers and stops propagation on first match. Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass. Handler will be called when all its filters is pass.
""" """
event = args[0] wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger)
await self.trigger_middleware(step=MiddlewareStep.PRE_PROCESS, event=event, data=kwargs) return await wrapped_outer(event, kwargs)
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
for handler in self.handlers: for handler in self.handlers:
result, data = await handler.check(*args, **kwargs) result, data = await handler.check(event, **kwargs)
if result: if result:
kwargs.update(data) kwargs.update(data)
await self.trigger_middleware(
step=MiddlewareStep.PROCESS, event=event, data=kwargs
)
try: try:
response = await handler.call(*args, **kwargs) wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
await self.trigger_middleware( return await wrapped_inner(event, kwargs)
step=MiddlewareStep.POST_PROCESS,
event=event,
data=kwargs,
result=response,
)
yield response
except SkipHandler: except SkipHandler:
continue continue
break
return NOT_HANDLED
def __call__( def __call__(
self, *args: FilterType, **bound_filters: BaseFilter self, *args: FilterType, **bound_filters: BaseFilter
@ -229,3 +145,45 @@ class TelegramEventObserver(EventObserver):
return callback return callback
return wrapper return wrapper
def middleware(
self, middleware: Optional[MiddlewareType] = None,
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
"""
Decorator for registering inner middlewares
Usage:
>>> @<event>.middleware() # via decorator (variant 1)
>>> @<event>.middleware # via decorator (variant 2)
>>> async def my_middleware(handler, event, data): ...
>>> <event>.middleware(middleware) # via method
"""
def wrapper(m: MiddlewareType) -> MiddlewareType:
self.middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)
def outer_middleware(
self, middleware: Optional[MiddlewareType] = None,
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
"""
Decorator for registering outer middlewares
Usage:
>>> @<event>.outer_middleware() # via decorator (variant 1)
>>> @<event>.outer_middleware # via decorator (variant 2)
>>> async def my_middleware(handler, event, data): ...
>>> <event>.outer_middleware(my_middleware) # via method
"""
def wrapper(m: MiddlewareType) -> MiddlewareType:
self.outer_middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)

View file

@ -1,61 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
class AbstractMiddleware(ABC):
"""
Abstract class for middleware.
"""
def __init__(self) -> None:
self._manager: Optional[MiddlewareManager] = None
@property
def manager(self) -> MiddlewareManager:
"""
Instance of MiddlewareManager
"""
if self._manager is None:
raise RuntimeError("Middleware is not configured!")
return self._manager
def setup(self, manager: MiddlewareManager, _stack_level: int = 1) -> AbstractMiddleware:
"""
Mark middleware as configured
:param manager:
:param _stack_level:
:return:
"""
if self.configured:
return manager.setup(self, _stack_level=_stack_level + 1)
self._manager = manager
return self
@property
def configured(self) -> bool:
"""
Check middleware is configured
:return:
"""
return bool(self._manager)
@abstractmethod
async def trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
data: Dict[str, Any],
result: Any = None,
) -> Any: # pragma: no cover
pass

View file

@ -1,317 +1,15 @@
from __future__ import annotations from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Dict T = TypeVar("T")
from aiogram.dispatcher.middlewares.abstract import AbstractMiddleware
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.api.types import (
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
)
class BaseMiddleware(AbstractMiddleware): class BaseMiddleware(ABC, Generic[T]):
""" @abstractmethod
Base class for middleware. async def __call__(
All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message".
"""
async def trigger(
self, self,
step: MiddlewareStep, handler: Callable[[T, Dict[str, Any]], Awaitable[Any]],
event_name: str, event: T,
event: UpdateType,
data: Dict[str, Any], data: Dict[str, Any],
result: Any = None, ) -> Any: # pragma: no cover
) -> Any: pass
"""
Trigger action.
:param step:
:param event_name:
:param event:
:param data:
:param result:
:return:
"""
handler_name = f"on_{step.value}_{event_name}"
handler = getattr(self, handler_name, None)
if not handler:
return None
args = (event, result, data) if step == MiddlewareStep.POST_PROCESS else (event, data)
return await handler(*args)
if TYPE_CHECKING: # pragma: no cover
# =============================================================================================
# Event that triggers before process <event>
# =============================================================================================
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process update
"""
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process message
"""
async def on_pre_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process edited_message
"""
async def on_pre_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process channel_post
"""
async def on_pre_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process edited_channel_post
"""
async def on_pre_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process inline_query
"""
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process chosen_inline_result
"""
async def on_pre_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process callback_query
"""
async def on_pre_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process shipping_query
"""
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process pre_checkout_query
"""
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process poll
"""
async def on_pre_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process poll_answer
"""
async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process error
"""
# =============================================================================================
# Event that triggers on process <event> after filters.
# =============================================================================================
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process update
"""
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process message
"""
async def on_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process edited_message
"""
async def on_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process channel_post
"""
async def on_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process edited_channel_post
"""
async def on_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process inline_query
"""
async def on_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process chosen_inline_result
"""
async def on_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process callback_query
"""
async def on_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process shipping_query
"""
async def on_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process pre_checkout_query
"""
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process poll
"""
async def on_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process poll_answer
"""
async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process error
"""
# =============================================================================================
# Event that triggers after process <event>.
# =============================================================================================
async def on_post_process_update(
self, update: Update, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing update
"""
async def on_post_process_message(
self, message: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing message
"""
async def on_post_process_edited_message(
self, edited_message: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing edited_message
"""
async def on_post_process_channel_post(
self, channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing channel_post
"""
async def on_post_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing edited_channel_post
"""
async def on_post_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing inline_query
"""
async def on_post_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing chosen_inline_result
"""
async def on_post_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing callback_query
"""
async def on_post_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing shipping_query
"""
async def on_post_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing pre_checkout_query
"""
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
"""
Event that triggers after processing poll
"""
async def on_post_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing poll_answer
"""
async def on_post_process_error(
self, exception: Exception, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing error
"""

View file

@ -0,0 +1,31 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
from ...api.types import Update
from ..event.bases import NOT_HANDLED, CancelHandler, SkipHandler
from .base import BaseMiddleware
if TYPE_CHECKING: # pragma: no cover
from ..router import Router
class ErrorsMiddleware(BaseMiddleware[Update]):
def __init__(self, router: Router):
self.router = router
async def __call__(
self,
handler: Callable[[Any, Dict[str, Any]], Awaitable[Any]],
event: Any,
data: Dict[str, Any],
) -> Any:
try:
return await handler(event, data)
except (SkipHandler, CancelHandler): # pragma: no cover
raise
except Exception as e:
response = await self.router.errors.trigger(event, exception=e, **data)
if response is NOT_HANDLED:
raise
return response

View file

@ -1,71 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List
from warnings import warn
from .abstract import AbstractMiddleware
from .types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.router import Router
class MiddlewareManager:
"""
Middleware manager.
"""
def __init__(self, router: Router) -> None:
self.router = router
self.middlewares: List[AbstractMiddleware] = []
def setup(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
"""
Setup middleware
:param middleware:
:param _stack_level:
:return:
"""
if not isinstance(middleware, AbstractMiddleware):
raise TypeError(
f"`middleware` should be instance of BaseMiddleware, not {type(middleware)}"
)
if middleware.configured:
if middleware.manager is self:
warn(
f"Middleware {middleware} is already configured for this Router "
"That's mean re-installing of this middleware has no effect.",
category=RuntimeWarning,
stacklevel=_stack_level + 1,
)
return middleware
raise ValueError(
f"Middleware is already configured for another manager {middleware.manager} "
f"in router {middleware.manager.router}!"
)
self.middlewares.append(middleware)
middleware.setup(self)
return middleware
async def trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
data: Dict[str, Any],
result: Any = None,
reverse: bool = False,
) -> Any:
"""
Call action to middlewares with args lilt.
"""
middlewares = reversed(self.middlewares) if reverse else self.middlewares
for middleware in middlewares:
await middleware.trigger(
step=step, event_name=event_name, event=event, data=data, result=result
)
def __contains__(self, item: AbstractMiddleware) -> bool:
return item in self.middlewares

View file

@ -1,35 +0,0 @@
from __future__ import annotations
from enum import Enum
from typing import Union
from aiogram.api.types import (
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
)
UpdateType = Union[
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
BaseException,
]
class MiddlewareStep(Enum):
PRE_PROCESS = "pre_process"
PROCESS = "process"
POST_PROCESS = "post_process"

View file

@ -0,0 +1,62 @@
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple
from aiogram.api.types import Chat, Update, User
from aiogram.dispatcher.middlewares.base import BaseMiddleware
class UserContextMiddleware(BaseMiddleware[Update]):
async def __call__(
self,
handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]],
event: Update,
data: Dict[str, Any],
) -> Any:
chat, user = self.resolve_event_context(event=event)
with self.context(chat=chat, user=user):
return await handler(event, data)
@contextmanager
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
chat_token = None
user_token = None
if chat:
chat_token = chat.set_current(chat)
if user:
user_token = user.set_current(user)
try:
yield
finally:
if chat and chat_token:
chat.reset_current(chat_token)
if user and user_token:
user.reset_current(user_token)
@classmethod
def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]:
"""
Resolve chat and user instance from Update object
"""
if event.message:
return event.message.chat, event.message.from_user
if event.edited_message:
return event.edited_message.chat, event.edited_message.from_user
if event.channel_post:
return event.channel_post.chat, None
if event.edited_channel_post:
return event.edited_channel_post.chat, None
if event.inline_query:
return None, event.inline_query.from_user
if event.chosen_inline_result:
return None, event.chosen_inline_result.from_user
if event.callback_query:
if event.callback_query.message:
return event.callback_query.message.chat, event.callback_query.from_user
return None, event.callback_query.from_user
if event.shipping_query:
return None, event.shipping_query.from_user
if event.pre_checkout_query:
return None, event.pre_checkout_query.from_user
if event.poll_answer:
return None, event.poll_answer.user
return None, None

View file

@ -3,13 +3,14 @@ from __future__ import annotations
import warnings import warnings
from typing import Any, Dict, Generator, List, Optional, Union from typing import Any, Dict, Generator, List, Optional, Union
from ..api.types import Chat, TelegramObject, Update, User from ..api.types import TelegramObject, Update
from ..utils.imports import import_module from ..utils.imports import import_module
from ..utils.warnings import CodeHasNoEffect from ..utils.warnings import CodeHasNoEffect
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver from .event.bases import NOT_HANDLED, SkipHandler
from .event.event import EventObserver
from .event.telegram import TelegramEventObserver
from .filters import BUILTIN_FILTERS from .filters import BUILTIN_FILTERS
from .middlewares.abstract import AbstractMiddleware from .middlewares.error import ErrorsMiddleware
from .middlewares.manager import MiddlewareManager
class Router: class Router:
@ -44,8 +45,6 @@ class Router:
self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer") self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer")
self.errors = TelegramEventObserver(router=self, event_name="error") self.errors = TelegramEventObserver(router=self, event_name="error")
self.middleware = MiddlewareManager(router=self)
self.startup = EventObserver() self.startup = EventObserver()
self.shutdown = EventObserver() self.shutdown = EventObserver()
@ -68,6 +67,8 @@ class Router:
# Root handler # Root handler
self.update.register(self._listen_update) self.update.register(self._listen_update)
self.update.outer_middleware(ErrorsMiddleware(self))
# Builtin filters # Builtin filters
if use_builtin_filters: if use_builtin_filters:
for name, observer in self.observers.items(): for name, observer in self.observers.items():
@ -94,16 +95,6 @@ class Router:
next(tail) # Skip self next(tail) # Skip self
yield from tail yield from tail
def use(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
"""
Use middleware
:param middleware:
:param _stack_level:
:return:
"""
return self.middleware.setup(middleware, _stack_level=_stack_level + 1)
@property @property
def parent_router(self) -> Optional[Router]: def parent_router(self) -> Optional[Router]:
return self._parent_router return self._parent_router
@ -176,53 +167,40 @@ class Router:
:param kwargs: :param kwargs:
:return: :return:
""" """
chat: Optional[Chat] = None
from_user: Optional[User] = None
event: TelegramObject event: TelegramObject
if update.message: if update.message:
update_type = "message" update_type = "message"
from_user = update.message.from_user
chat = update.message.chat
event = update.message event = update.message
elif update.edited_message: elif update.edited_message:
update_type = "edited_message" update_type = "edited_message"
from_user = update.edited_message.from_user
chat = update.edited_message.chat
event = update.edited_message event = update.edited_message
elif update.channel_post: elif update.channel_post:
update_type = "channel_post" update_type = "channel_post"
chat = update.channel_post.chat
event = update.channel_post event = update.channel_post
elif update.edited_channel_post: elif update.edited_channel_post:
update_type = "edited_channel_post" update_type = "edited_channel_post"
chat = update.edited_channel_post.chat
event = update.edited_channel_post event = update.edited_channel_post
elif update.inline_query: elif update.inline_query:
update_type = "inline_query" update_type = "inline_query"
from_user = update.inline_query.from_user
event = update.inline_query event = update.inline_query
elif update.chosen_inline_result: elif update.chosen_inline_result:
update_type = "chosen_inline_result" update_type = "chosen_inline_result"
from_user = update.chosen_inline_result.from_user
event = update.chosen_inline_result event = update.chosen_inline_result
elif update.callback_query: elif update.callback_query:
update_type = "callback_query" update_type = "callback_query"
if update.callback_query.message:
chat = update.callback_query.message.chat
from_user = update.callback_query.from_user
event = update.callback_query event = update.callback_query
elif update.shipping_query: elif update.shipping_query:
update_type = "shipping_query" update_type = "shipping_query"
from_user = update.shipping_query.from_user
event = update.shipping_query event = update.shipping_query
elif update.pre_checkout_query: elif update.pre_checkout_query:
update_type = "pre_checkout_query" update_type = "pre_checkout_query"
from_user = update.pre_checkout_query.from_user
event = update.pre_checkout_query event = update.pre_checkout_query
elif update.poll: elif update.poll:
update_type = "poll" update_type = "poll"
event = update.poll event = update.poll
elif update.poll_answer:
update_type = "poll_answer"
event = update.poll_answer
else: else:
warnings.warn( warnings.warn(
"Detected unknown update type.\n" "Detected unknown update type.\n"
@ -232,76 +210,17 @@ class Router:
) )
raise SkipHandler raise SkipHandler
return await self.listen_update(
update_type=update_type,
update=update,
event=event,
from_user=from_user,
chat=chat,
**kwargs,
)
async def listen_update(
self,
update_type: str,
update: Update,
event: TelegramObject,
from_user: Optional[User] = None,
chat: Optional[Chat] = None,
**kwargs: Any,
) -> Any:
"""
Listen update by current and child routers
:param update_type:
:param update:
:param event:
:param from_user:
:param chat:
:param kwargs:
:return:
"""
user_token = None
if from_user:
user_token = User.set_current(from_user)
chat_token = None
if chat:
chat_token = Chat.set_current(chat)
kwargs.update(event_update=update, event_router=self) kwargs.update(event_update=update, event_router=self)
observer = self.observers[update_type] observer = self.observers[update_type]
try: response = await observer.trigger(event, update=update, **kwargs)
async for result in observer.trigger(event, update=update, **kwargs):
return result
if response is NOT_HANDLED: # Resolve nested routers
for router in self.sub_routers: for router in self.sub_routers:
try: response = await router.update.trigger(event=update, **kwargs)
return await router.listen_update( if response is NOT_HANDLED:
update_type=update_type,
update=update,
event=event,
from_user=from_user,
chat=chat,
**kwargs,
)
except SkipHandler:
continue continue
raise SkipHandler return response
except SkipHandler:
raise
except Exception as e:
async for result in self.errors.trigger(e, **kwargs):
return result
raise
finally:
if user_token:
User.reset_current(user_token)
if chat_token:
Chat.reset_current(chat_token)
async def emit_startup(self, *args: Any, **kwargs: Any) -> None: async def emit_startup(self, *args: Any, **kwargs: Any) -> None:
""" """
@ -312,8 +231,7 @@ class Router:
:return: :return:
""" """
kwargs.update(router=self) kwargs.update(router=self)
async for _ in self.startup.trigger(*args, **kwargs): # pragma: no cover await self.startup.trigger(*args, **kwargs)
pass
for router in self.sub_routers: for router in self.sub_routers:
await router.emit_startup(*args, **kwargs) await router.emit_startup(*args, **kwargs)
@ -326,8 +244,7 @@ class Router:
:return: :return:
""" """
kwargs.update(router=self) kwargs.update(router=self)
async for _ in self.shutdown.trigger(*args, **kwargs): # pragma: no cover await self.shutdown.trigger(*args, **kwargs)
pass
for router in self.sub_routers: for router in self.sub_routers:
await router.emit_shutdown(*args, **kwargs) await router.emit_shutdown(*args, **kwargs)

View file

@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
import contextvars import contextvars
from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload
if TYPE_CHECKING: # pragma: no cover
from typing_extensions import Literal from typing_extensions import Literal
__all__ = ("ContextInstanceMixin", "DataMixin") __all__ = ("ContextInstanceMixin", "DataMixin")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

After

Width:  |  Height:  |  Size: 51 KiB

Before After
Before After

View file

@ -39,7 +39,7 @@ dp.include_router(router1)
## Handling updates ## Handling updates
All updates can be propagated to the dispatcher by `feed_update` method: All updates can be propagated to the dispatcher by `feed_update` method:
``` ```python3
bot = Bot(...) bot = Bot(...)
dp = Dispathcher() dp = Dispathcher()

View file

@ -0,0 +1,95 @@
# Middlewares
**aiogram** provides powerful mechanism for customizing event handlers via middlewares.
Middlewares in bot framework seems like Middlewares mechanism in web-frameworks
(like [aiohttp](https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-middlewares),
[fastapi](https://fastapi.tiangolo.com/tutorial/middleware/),
[Django](https://docs.djangoproject.com/en/3.0/topics/http/middleware/) or etc.)
with small difference - here is implemented two layers of middlewares (before and after filters).
!!! info
Middleware is function that triggered on every event received from
Telegram Bot API in many points on processing pipeline.
## Base theory
As many books and other literature in internet says:
> Middleware is reusable software that leverages patterns and frameworks to bridge
> the gap between the functional requirements of applications and the underlying operating systems,
> network protocol stacks, and databases.
Middleware can modify, extend or reject processing event in many places of pipeline.
## Basics
Middleware instance can be applied for every type of Telegram Event (Update, Message, etc.) in two places
1. Outer scope - before processing filters (`#!python <router>.<event>.outer_middleware(...)`)
2. Inner scope - after processing filters but before handler (`#!python <router>.<event>.middleware(...)`)
[![middlewares](../assets/images/basics_middleware.png)](../assets/images/basics_middleware.png)
_(Click on image to zoom it)_
!!! warning
Middleware should be subclass of `BaseMiddleware` (`#!python3 from aiogram import BaseMiddleware`) or any async callable
## Arguments specification
| Argument | Type | Description |
| - | - | - |
| `handler` | `#!python Callable[[T, Dict[str, Any]], Awaitable[Any]]` | Wrapped handler in middlewares chain |
| `event` | `#!python T` | Incoming event (Subclass of `TelegramObject`) |
| `data` | `#!python Dict[str, Any]` | Contextual data. Will be mapped to handler arguments |
## Examples
!!! danger
Middleware should always call `#!python await handler(event, data)` to propagate event for next middleware/handler
### Class-based
```python3
from aiogram import BaseMiddleware
from aiogram.api.types import Message
class CounterMiddleware(BaseMiddleware[Message]):
def __init__(self) -> None:
self.counter = 0
async def __call__(
self,
handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]],
event: Message,
data: Dict[str, Any]
) -> Any:
self.counter += 1
data['counter'] = self.counter
return await handler(event, data)
```
and then
```python3
router = Router()
router.message.middleware(CounterMiddleware())
```
### Function-based
```python3
@dispatcher.update.outer_middleware()
async def database_transaction_middleware(
handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]],
event: Update,
data: Dict[str, Any]
) -> Any:
async with database.transaction():
return await handler(event, data)
```
## Facts
1. Middlewares from outer scope will be called on every incoming event
1. Middlewares from inner scope will be called only when filters pass
1. Inner middlewares is always calls for `Update` event type in due to all incoming updates going to specific event type handler through built in update handler

View file

@ -1,115 +0,0 @@
# Basics
All middlewares should be made with `BaseMiddleware` (`#!python3 from aiogram import BaseMiddleware`) as base class.
For example:
```python3
class MyMiddleware(BaseMiddleware): ...
```
And then use next pattern in naming callback functions in middleware: `on_{step}_{event}`
Where is:
- `#!python3 step`:
- `#!python3 pre_process`
- `#!python3 process`
- `#!python3 post_process`
- `#!python3 event`:
- `#!python3 update`
- `#!python3 message`
- `#!python3 edited_message`
- `#!python3 channel_post`
- `#!python3 edited_channel_post`
- `#!python3 inline_query`
- `#!python3 chosen_inline_result`
- `#!python3 callback_query`
- `#!python3 shipping_query`
- `#!python3 pre_checkout_query`
- `#!python3 poll`
- `#!python3 poll_answer`
- `#!python3 error`
## Connecting middleware with router
Middlewares can be connected with router by next ways:
1. `#!python3 router.use(MyMiddleware())` (**recommended**)
1. `#!python3 router.middleware.setup(MyMiddleware())`
1. `#!python3 MyMiddleware().setup(router.middleware)` (**not recommended**)
!!! warning
One instance of middleware **can't** be registered twice in single or many middleware managers
## The specification of step callbacks
### Pre-process step
| Argument | Type | Description |
| --- | --- | --- |
| event name | Any of event type (Update, Message and etc.) | Event |
| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) |
Returns `#!python3 Any`
### Process step
| Argument | Type | Description |
| --- | --- | --- |
| event name | Any of event type (Update, Message and etc.) | Event |
| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) |
Returns `#!python3 Any`
### Post-Process step
| Argument | Type | Description |
| --- | --- | --- |
| event name | Any of event type (Update, Message and etc.) | Event |
| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) |
| `#!python3 result` | `#!python3 Dict[str, Any]` | Response from handlers |
Returns `#!python3 Any`
## Full list of available callbacks
- `#!python3 on_pre_process_update` - will be triggered on **pre process** `#!python3 update` event
- `#!python3 on_process_update` - will be triggered on **process** `#!python3 update` event
- `#!python3 on_post_process_update` - will be triggered on **post process** `#!python3 update` event
- `#!python3 on_pre_process_message` - will be triggered on **pre process** `#!python3 message` event
- `#!python3 on_process_message` - will be triggered on **process** `#!python3 message` event
- `#!python3 on_post_process_message` - will be triggered on **post process** `#!python3 message` event
- `#!python3 on_pre_process_edited_message` - will be triggered on **pre process** `#!python3 edited_message` event
- `#!python3 on_process_edited_message` - will be triggered on **process** `#!python3 edited_message` event
- `#!python3 on_post_process_edited_message` - will be triggered on **post process** `#!python3 edited_message` event
- `#!python3 on_pre_process_channel_post` - will be triggered on **pre process** `#!python3 channel_post` event
- `#!python3 on_process_channel_post` - will be triggered on **process** `#!python3 channel_post` event
- `#!python3 on_post_process_channel_post` - will be triggered on **post process** `#!python3 channel_post` event
- `#!python3 on_pre_process_edited_channel_post` - will be triggered on **pre process** `#!python3 edited_channel_post` event
- `#!python3 on_process_edited_channel_post` - will be triggered on **process** `#!python3 edited_channel_post` event
- `#!python3 on_post_process_edited_channel_post` - will be triggered on **post process** `#!python3 edited_channel_post` event
- `#!python3 on_pre_process_inline_query` - will be triggered on **pre process** `#!python3 inline_query` event
- `#!python3 on_process_inline_query` - will be triggered on **process** `#!python3 inline_query` event
- `#!python3 on_post_process_inline_query` - will be triggered on **post process** `#!python3 inline_query` event
- `#!python3 on_pre_process_chosen_inline_result` - will be triggered on **pre process** `#!python3 chosen_inline_result` event
- `#!python3 on_process_chosen_inline_result` - will be triggered on **process** `#!python3 chosen_inline_result` event
- `#!python3 on_post_process_chosen_inline_result` - will be triggered on **post process** `#!python3 chosen_inline_result` event
- `#!python3 on_pre_process_callback_query` - will be triggered on **pre process** `#!python3 callback_query` event
- `#!python3 on_process_callback_query` - will be triggered on **process** `#!python3 callback_query` event
- `#!python3 on_post_process_callback_query` - will be triggered on **post process** `#!python3 callback_query` event
- `#!python3 on_pre_process_shipping_query` - will be triggered on **pre process** `#!python3 shipping_query` event
- `#!python3 on_process_shipping_query` - will be triggered on **process** `#!python3 shipping_query` event
- `#!python3 on_post_process_shipping_query` - will be triggered on **post process** `#!python3 shipping_query` event
- `#!python3 on_pre_process_pre_checkout_query` - will be triggered on **pre process** `#!python3 pre_checkout_query` event
- `#!python3 on_process_pre_checkout_query` - will be triggered on **process** `#!python3 pre_checkout_query` event
- `#!python3 on_post_process_pre_checkout_query` - will be triggered on **post process** `#!python3 pre_checkout_query` event
- `#!python3 on_pre_process_poll` - will be triggered on **pre process** `#!python3 poll` event
- `#!python3 on_process_poll` - will be triggered on **process** `#!python3 poll` event
- `#!python3 on_post_process_poll` - will be triggered on **post process** `#!python3 poll` event
- `#!python3 on_pre_process_poll_answer` - will be triggered on **pre process** `#!python3 poll_answer` event
- `#!python3 on_process_poll_answer` - will be triggered on **process** `#!python3 poll_answer` event
- `#!python3 on_post_process_poll_answer` - will be triggered on **post process** `#!python3 poll_answer` event
- `#!python3 on_pre_process_error` - will be triggered on **pre process** `#!python3 error` event
- `#!python3 on_process_error` - will be triggered on **process** `#!python3 error` event
- `#!python3 on_post_process_error` - will be triggered on **post process** `#!python3 error` event

View file

@ -1,77 +0,0 @@
# Overview
**aiogram** provides powerful mechanism for customizing event handlers via middlewares.
Middlewares in bot framework seems like Middlewares mechanism in web-frameworks
(like [aiohttp](https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-middlewares),
[fastapi](https://fastapi.tiangolo.com/tutorial/middleware/),
[Django](https://docs.djangoproject.com/en/3.0/topics/http/middleware/) or etc.)
with small difference - here is implemented many layers of processing
(named as [pipeline](#event-pipeline)).
!!! info
Middleware is function that triggered on every event received from
Telegram Bot API in many points on processing pipeline.
## Base theory
As many books and other literature in internet says:
> Middleware is reusable software that leverages patterns and frameworks to bridge
>the gap between the functional requirements of applications and the underlying operating systems,
> network protocol stacks, and databases.
Middleware can modify, extend or reject processing event before-,
on- or after- processing of that event.
[![middlewares](../../assets/images/basics_middleware.png)](../../assets/images/basics_middleware.png)
_(Click on image to zoom it)_
## Event pipeline
As described below middleware an interact with event in many stages of pipeline.
Simple workflow:
1. Dispatcher receive an [Update](../../api/types/update.md)
1. Call **pre-process** update middleware in all routers tree
1. Filter Update over handlers
1. Call **process** update middleware in all routers tree
1. Router detects event type (Message, Callback query, etc.)
1. Router triggers **pre-process** <event> middleware of specific type
1. Pass event over [filters](../filters/index.md) to detect specific handler
1. Call **process** <event> middleware for specific type (only when handler for this event exists)
1. *Do magick*. Call handler (Read more [Event observers](../router.md#event-observers))
1. Call **post-process** <event> middleware
1. Call **post-process** update middleware in all routers tree
1. Emit response into webhook (when it needed)
!!! warning
When filters does not match any handler with this event the `#!python3 process`
step will not be called.
!!! warning
When exception will be caused in handlers pipeline will be stopped immediately
and then start processing error via errors handler and it own middleware callbacks.
!!! warning
Middlewares for updates will be called for all routers in tree but callbacks for events
will be called only for specific branch of routers.
### Pipeline in pictures:
#### Simple pipeline
[![middlewares](../../assets/images/middleware_pipeline.png)](../../assets/images/middleware_pipeline.png)
_(Click on image to zoom it)_
#### Nested routers pipeline
[![middlewares](../../assets/images/middleware_pipeline_nested.png)](../../assets/images/middleware_pipeline_nested.png)
_(Click on image to zoom it)_
## Read more
- [Middleware Basics](basics.md)

View file

@ -20,7 +20,7 @@ Documentation for version 3.0 [WIP] [^1]
- [Supports Telegram Bot API v{!_api_version.md!}](api/index.md) - [Supports Telegram Bot API v{!_api_version.md!}](api/index.md)
- [Updates router](dispatcher/index.md) (Blueprints) - [Updates router](dispatcher/index.md) (Blueprints)
- Finite State Machine - Finite State Machine
- [Middlewares](dispatcher/middlewares/index.md) - [Middlewares](dispatcher/middlewares.md)
- [Replies into Webhook](https://core.telegram.org/bots/faq#how-can-i-make-requests-in-response-to-updates) - [Replies into Webhook](https://core.telegram.org/bots/faq#how-can-i-make-requests-in-response-to-updates)

View file

@ -0,0 +1,12 @@
@font-face {
font-family: 'JetBrainsMono';
src: url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/web/woff2/JetBrainsMono-Regular.woff2') format('woff2'),
url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/web/woff/JetBrainsMono-Regular.woff') format('woff'),
url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/ttf/JetBrainsMono-Regular.ttf') format('truetype');
font-weight: 400;
font-style: normal;
}
code, kbd, pre {
font-family: "JetBrainsMono", "Roboto Mono", "Courier New", Courier, monospace;
}

View file

@ -16,6 +16,9 @@ theme:
favicon: 'assets/images/logo.png' favicon: 'assets/images/logo.png'
logo: 'assets/images/logo.png' logo: 'assets/images/logo.png'
extra_css:
- stylesheets/extra.css
extra: extra:
version: 3.0.0a3 version: 3.0.0a3
@ -255,9 +258,8 @@ nav:
- dispatcher/class_based_handlers/pre_checkout_query.md - dispatcher/class_based_handlers/pre_checkout_query.md
- dispatcher/class_based_handlers/shipping_query.md - dispatcher/class_based_handlers/shipping_query.md
- dispatcher/class_based_handlers/error.md - dispatcher/class_based_handlers/error.md
- Middlewares: - dispatcher/middlewares.md
- dispatcher/middlewares/index.md
- dispatcher/middlewares/basics.md
- todo.md - todo.md
- Build reports: - Build reports:
- reports.md - reports.md

12
poetry.lock generated
View file

@ -884,7 +884,7 @@ category = "dev"
description = "YAML parser and emitter for Python" description = "YAML parser and emitter for Python"
name = "pyyaml" name = "pyyaml"
optional = false optional = false
python-versions = "*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "5.3.1" version = "5.3.1"
[[package]] [[package]]
@ -955,7 +955,7 @@ python-versions = "*"
version = "1.4.1" version = "1.4.1"
[[package]] [[package]]
category = "main" category = "dev"
description = "Backported and Experimental Type Hints for Python 3.5+" description = "Backported and Experimental Type Hints for Python 3.5+"
name = "typing-extensions" name = "typing-extensions"
optional = false optional = false
@ -1031,8 +1031,7 @@ fast = ["uvloop"]
proxy = ["aiohttp-socks"] proxy = ["aiohttp-socks"]
[metadata] [metadata]
content-hash = "57137b60a539ba01e8df533db976e2f3eadec37e717cbefbe775dc021a8c2714" content-hash = "768759359beca8b84811bfc21adac9649925cd22b87427a10608c9d1e16a0923"
python-versions = "^3.7" python-versions = "^3.7"
[metadata.files] [metadata.files]
@ -1284,11 +1283,6 @@ markupsafe = [
{file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"},
{file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"},
{file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"},
] ]
mccabe = [ mccabe = [

View file

@ -40,7 +40,6 @@ aiofiles = "^0.4.0"
uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'", optional = true} uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'", optional = true}
async_lru = "^1.0" async_lru = "^1.0"
aiohttp-socks = {version = "^0.3.8", optional = true} aiohttp-socks = {version = "^0.3.8", optional = true}
typing-extensions = "^3.7.4"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'"} uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'"}
@ -68,6 +67,7 @@ markdown-include = "^0.5.1"
aiohttp-socks = "^0.3.4" aiohttp-socks = "^0.3.4"
pre-commit = "^2.3.0" pre-commit = "^2.3.0"
packaging = "^20.3" packaging = "^20.3"
typing-extensions = "^3.7.4"
[tool.poetry.extras] [tool.poetry.extras]
fast = ["uvloop"] fast = ["uvloop"]
@ -94,7 +94,7 @@ include_trailing_comma = true
force_grid_wrap = 0 force_grid_wrap = 0
use_parentheses = true use_parentheses = true
line_length = 99 line_length = 99
known_third_party = ["aiofiles", "aiohttp", "aiohttp_socks", "aresponses", "async_lru", "packaging", "pydantic", "pytest", "typing_extensions"] known_third_party = ["aiofiles", "aiohttp", "aiohttp_socks", "aresponses", "async_lru", "packaging", "pkg_resources", "pydantic", "pytest"]
[build-system] [build-system]
requires = ["poetry>=0.12"] requires = ["poetry>=0.12"]

View file

@ -11,7 +11,13 @@ class TestInlineQuery:
offset="", offset="",
) )
kwargs = dict(results=[], cache_time=123, next_offset="123", switch_pm_text="foo", switch_pm_parameter="foo") kwargs = dict(
results=[],
cache_time=123,
next_offset="123",
switch_pm_text="foo",
switch_pm_parameter="foo",
)
api_method = inline_query.answer(**kwargs) api_method = inline_query.answer(**kwargs)

View file

@ -1,5 +1,5 @@
from aiogram.api.methods import AnswerShippingQuery from aiogram.api.methods import AnswerShippingQuery
from aiogram.api.types import ShippingAddress, ShippingQuery, User, ShippingOption, LabeledPrice from aiogram.api.types import LabeledPrice, ShippingAddress, ShippingOption, ShippingQuery, User
class TestInlineQuery: class TestInlineQuery:
@ -19,7 +19,8 @@ class TestInlineQuery:
) )
shipping_options = [ shipping_options = [
ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)])] ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)])
]
kwargs = dict(ok=True, shipping_options=shipping_options, error_message="foo") kwargs = dict(ok=True, shipping_options=shipping_options, error_message="foo")

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aiogram.dispatcher.event.observer import TelegramEventObserver from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from tests.deprecated import check_deprecated from tests.deprecated import check_deprecated

View file

@ -9,6 +9,7 @@ from aiogram import Bot
from aiogram.api.methods import GetMe, GetUpdates, SendMessage from aiogram.api.methods import GetMe, GetUpdates, SendMessage
from aiogram.api.types import Chat, Message, Update, User from aiogram.api.types import Chat, Message, Update, User
from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NOT_HANDLED
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot
@ -63,7 +64,7 @@ class TestDispatcher:
return message.text return message.text
results_count = 0 results_count = 0
async for result in dp.feed_update( result = await dp.feed_update(
bot=bot, bot=bot,
update=Update( update=Update(
update_id=42, update_id=42,
@ -75,12 +76,10 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"), from_user=User(id=42, is_bot=False, first_name="Test"),
), ),
), ),
): )
results_count += 1 results_count += 1
assert result == "test" assert result == "test"
assert results_count == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_feed_raw_update(self): async def test_feed_raw_update(self):
dp = Dispatcher() dp = Dispatcher()
@ -91,8 +90,7 @@ class TestDispatcher:
assert message.text == "test" assert message.text == "test"
return message.text return message.text
handled = False result = await dp.feed_raw_update(
async for result in dp.feed_raw_update(
bot=bot, bot=bot,
update={ update={
"update_id": 42, "update_id": 42,
@ -101,13 +99,11 @@ class TestDispatcher:
"date": int(time.time()), "date": int(time.time()),
"text": "test", "text": "test",
"chat": {"id": 42, "type": "private"}, "chat": {"id": 42, "type": "private"},
"user": {"id": 42, "is_bot": False, "first_name": "Test"}, "from": {"id": 42, "is_bot": False, "first_name": "Test"},
}, },
}, },
): )
handled = True
assert result == "test" assert result == "test"
assert handled
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_updates(self, bot: MockedBot): async def test_listen_updates(self, bot: MockedBot):
@ -136,7 +132,8 @@ class TestDispatcher:
async def test_process_update_empty(self, bot: MockedBot): async def test_process_update_empty(self, bot: MockedBot):
dispatcher = Dispatcher() dispatcher = Dispatcher()
assert not await dispatcher.process_update(bot=bot, update=Update(update_id=42)) result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
assert result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_update_handled(self, bot: MockedBot): async def test_process_update_handled(self, bot: MockedBot):
@ -146,22 +143,25 @@ class TestDispatcher:
async def update_handler(update: Update): async def update_handler(update: Update):
pass pass
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42)) assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_update_call_request(self, bot: MockedBot): async def test_process_update_call_request(self, bot: MockedBot):
dispatcher = Dispatcher() dispatcher = Dispatcher()
@dispatcher.update() @dispatcher.update()
async def update_handler(update: Update): async def message_handler(update: Update):
return GetMe() return GetMe()
dispatcher.update.handlers.reverse()
with patch( with patch(
"aiogram.dispatcher.dispatcher.Dispatcher._silent_call_request", "aiogram.dispatcher.dispatcher.Dispatcher._silent_call_request",
new_callable=CoroutineMock, new_callable=CoroutineMock,
) as mocked_silent_call_request: ) as mocked_silent_call_request:
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42)) result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
mocked_silent_call_request.assert_awaited_once() print(result)
mocked_silent_call_request.assert_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_update_exception(self, bot: MockedBot, caplog): async def test_process_update_exception(self, bot: MockedBot, caplog):
@ -171,7 +171,7 @@ class TestDispatcher:
async def update_handler(update: Update): async def update_handler(update: Update):
raise Exception("Kaboom!") raise Exception("Kaboom!")
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42)) assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
log_records = [rec.message for rec in caplog.records] log_records = [rec.message for rec in caplog.records]
assert len(log_records) == 1 assert len(log_records) == 1
assert "Cause exception while process update" in log_records[0] assert "Cause exception while process update" in log_records[0]
@ -184,7 +184,7 @@ class TestDispatcher:
yield Update(update_id=42) yield Update(update_id=42)
with patch( with patch(
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock "aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
) as mocked_process_update, patch( ) as mocked_process_update, patch(
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates" "aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
) as patched_listen_updates: ) as patched_listen_updates:
@ -203,7 +203,7 @@ class TestDispatcher:
yield Update(update_id=42) yield Update(update_id=42)
with patch( with patch(
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock "aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
) as mocked_process_update, patch( ) as mocked_process_update, patch(
"aiogram.dispatcher.router.Router.emit_startup", new_callable=CoroutineMock "aiogram.dispatcher.router.Router.emit_startup", new_callable=CoroutineMock
) as mocked_emit_startup, patch( ) as mocked_emit_startup, patch(

View file

@ -0,0 +1,59 @@
import functools
from typing import Any
import pytest
from aiogram.dispatcher.event.event import EventObserver
from aiogram.dispatcher.event.handler import HandlerObject
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
async def my_handler(value: str, index: int = 0) -> Any:
return value
class TestEventObserver:
@pytest.mark.parametrize("via_decorator", [True, False])
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, via_decorator, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
if via_decorator:
register_result = observer()(wrapped_handler)
assert register_result == wrapped_handler
else:
register_result = observer.register(wrapped_handler)
assert register_result is None
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger(self):
observer = EventObserver()
observer.register(my_handler)
observer.register(lambda e: True)
observer.register(my_handler)
assert observer.handlers[0].awaitable
assert not observer.handlers[1].awaitable
assert observer.handlers[2].awaitable
with patch(
"aiogram.dispatcher.event.handler.CallableMixin.call", new_callable=CoroutineMock,
) as mocked_my_handler:
results = await observer.trigger("test")
assert results is None
mocked_my_handler.assert_awaited_with("test")
assert mocked_my_handler.call_count == 3

View file

@ -5,11 +5,14 @@ from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
import pytest import pytest
from aiogram.api.types import Chat, Message, User from aiogram.api.types import Chat, Message, User
from aiogram.dispatcher.event.bases import SkipHandler
from aiogram.dispatcher.event.handler import HandlerObject from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.observer import EventObserver, SkipHandler, TelegramEventObserver from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
# TODO: Test middlewares in routers tree
async def my_handler(event: Any, index: int = 0) -> Any: async def my_handler(event: Any, index: int = 0) -> Any:
return event return event
@ -38,54 +41,6 @@ class MyFilter3(MyFilter1):
pass pass
class TestEventObserver:
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer.register(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters_via_decorator(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer()(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger_accepted_bool(self):
observer = EventObserver()
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42]
@pytest.mark.asyncio
async def test_trigger_with_skip(self):
observer = EventObserver()
observer.register(skip_my_handler)
observer.register(my_handler)
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42, 42]
class TestTelegramEventObserver: class TestTelegramEventObserver:
def test_bind_filter(self): def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test") event_observer = TelegramEventObserver(Router(), "test")
@ -198,8 +153,8 @@ class TestTelegramEventObserver:
from_user=User(id=42, is_bot=False, first_name="Test"), from_user=User(id=42, is_bot=False, first_name="Test"),
) )
results = [result async for result in observer.trigger(message)] results = await observer.trigger(message)
assert results == [message] assert results is message
@pytest.mark.parametrize( @pytest.mark.parametrize(
"count,handler,filters", "count,handler,filters",
@ -223,15 +178,58 @@ class TestTelegramEventObserver:
assert registered_handler.callback == wrapped_handler assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters) assert len(registered_handler.filters) == len(filters)
#
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_right_context_in_handlers(self): async def test_trigger_right_context_in_handlers(self):
router = Router(use_builtin_filters=False) router = Router(use_builtin_filters=False)
observer = router.message observer = router.message
observer.register(
pipe_handler, lambda event: {"a": 1}, lambda event: False
) # {"a": 1} should not be in result
observer.register(pipe_handler, lambda event: {"b": 2})
results = [result async for result in observer.trigger(42)] async def mix_unnecessary_data(event):
assert results == [((42,), {"b": 2})] return {"a": 1}
async def mix_data(event):
return {"b": 2}
async def handler(event, **kwargs):
return False
observer.register(
pipe_handler, mix_unnecessary_data, handler
) # {"a": 1} should not be in result
observer.register(pipe_handler, mix_data)
results = await observer.trigger(42)
assert results == ((42,), {"b": 2})
@pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware"))
def test_register_middleware(self, middleware_type):
event_observer = TelegramEventObserver(Router(), "test")
middlewares = getattr(event_observer, f"{middleware_type}s")
decorator = getattr(event_observer, middleware_type)
@decorator
async def my_middleware1(handler, event, data):
pass
assert my_middleware1 is not None
assert my_middleware1.__name__ == "my_middleware1"
assert my_middleware1 in middlewares
@decorator()
async def my_middleware2(handler, event, data):
pass
assert my_middleware2 is not None
assert my_middleware2.__name__ == "my_middleware2"
assert my_middleware2 in middlewares
async def my_middleware3(handler, event, data):
pass
decorator(my_middleware3)
assert my_middleware3 is not None
assert my_middleware3.__name__ == "my_middleware3"
assert my_middleware3 in middlewares
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]

View file

@ -1,257 +0,0 @@
import datetime
from typing import Any, Dict, Type
import pytest
from aiogram.api.types import (
CallbackQuery,
Chat,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
User,
)
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
class MyMiddleware(BaseMiddleware):
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_pre_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_pre_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
return "channel_post"
async def on_pre_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_pre_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_pre_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_pre_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_pre_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
return "poll_answer"
async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
return "error"
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any:
return "channel_post"
async def on_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any:
return "poll_answer"
async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
return "error"
async def on_post_process_update(
self, update: Update, data: Dict[str, Any], result: Any
) -> Any:
return "update"
async def on_post_process_message(
self, message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "message"
async def on_post_process_edited_message(
self, edited_message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_message"
async def on_post_process_channel_post(
self, channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "channel_post"
async def on_post_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_channel_post"
async def on_post_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
) -> Any:
return "inline_query"
async def on_post_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
) -> Any:
return "chosen_inline_result"
async def on_post_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
) -> Any:
return "callback_query"
async def on_post_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
) -> Any:
return "shipping_query"
async def on_post_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
) -> Any:
return "pre_checkout_query"
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
return "poll"
async def on_post_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
) -> Any:
return "poll_answer"
async def on_post_process_error(
self, exception: Exception, data: Dict[str, Any], result: Any
) -> Any:
return "error"
UPDATE = Update(update_id=42)
MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private"))
POLL_ANSWER = PollAnswer(
poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0]
)
class TestBaseMiddleware:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]]
)
@pytest.mark.parametrize(
"step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS]
)
@pytest.mark.parametrize(
"event_name,event",
[
["update", UPDATE],
["message", MESSAGE],
["poll_answer", POLL_ANSWER],
["error", Exception("KABOOM")],
],
)
async def test_trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
middleware_cls: Type[BaseMiddleware],
should_be_awaited: bool,
):
middleware = middleware_cls()
with patch(
f"tests.test_dispatcher.test_middlewares.test_base."
f"MyMiddleware.on_{step.value}_{event_name}",
new_callable=CoroutineMock,
) as mocked_call:
response = await middleware.trigger(
step=step, event_name=event_name, event=event, data={}
)
if should_be_awaited:
mocked_call.assert_awaited()
assert response is not None
else:
mocked_call.assert_not_awaited()
assert response is None
def test_not_configured(self):
middleware = BaseMiddleware()
assert not middleware.configured
with pytest.raises(RuntimeError):
manager = middleware.manager

View file

@ -1,82 +0,0 @@
import pytest
from aiogram import Router
from aiogram.api.types import Update
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from aiogram.dispatcher.middlewares.types import MiddlewareStep
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
@pytest.fixture("function")
def router():
return Router()
@pytest.fixture("function")
def manager(router: Router):
return MiddlewareManager(router)
class TestManager:
def test_setup(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
returned = manager.setup(middleware)
assert returned is middleware
assert middleware.configured
assert middleware.manager is manager
assert middleware in manager
@pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware])
def test_setup_invalid_type(self, manager: MiddlewareManager, obj):
with pytest.raises(TypeError):
assert manager.setup(obj)
def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
new_manager = MiddlewareManager(router)
with pytest.raises(ValueError):
new_manager.setup(middleware)
with pytest.raises(ValueError):
middleware.setup(new_manager)
def test_configure_twice(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
manager.setup(middleware)
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
middleware.setup(manager)
@pytest.mark.asyncio
@pytest.mark.parametrize("count", range(5))
async def test_trigger(self, manager: MiddlewareManager, count: int):
for _ in range(count):
manager.setup(BaseMiddleware())
with patch(
"aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger",
new_callable=CoroutineMock,
) as mocked_call:
await manager.trigger(
step=MiddlewareStep.PROCESS,
event_name="update",
event=Update(update_id=42),
data={},
result=None,
reverse=True,
)
assert mocked_call.await_count == count

View file

@ -10,6 +10,7 @@ from aiogram.api.types import (
InlineQuery, InlineQuery,
Message, Message,
Poll, Poll,
PollAnswer,
PollOption, PollOption,
PreCheckoutQuery, PreCheckoutQuery,
ShippingAddress, ShippingAddress,
@ -17,8 +18,8 @@ from aiogram.api.types import (
Update, Update,
User, User,
) )
from aiogram.dispatcher.event.observer import SkipHandler, skip from aiogram.dispatcher.event.bases import NOT_HANDLED, SkipHandler, skip
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from aiogram.utils.warnings import CodeHasNoEffect from aiogram.utils.warnings import CodeHasNoEffect
@ -274,12 +275,26 @@ class TestRouter:
False, False,
False, False,
), ),
pytest.param(
"poll_answer",
Update(
update_id=42,
poll_answer=PollAnswer(
poll_id="poll id",
user=User(id=42, is_bot=False, first_name="Test"),
option_ids=[42],
),
),
False,
True,
),
], ],
) )
async def test_listen_update( async def test_listen_update(
self, event_type: str, update: Update, has_chat: bool, has_user: bool self, event_type: str, update: Update, has_chat: bool, has_user: bool
): ):
router = Router() router = Router()
router.update.outer_middleware(UserContextMiddleware())
observer = router.observers[event_type] observer = router.observers[event_type]
@observer() @observer()
@ -291,7 +306,7 @@ class TestRouter:
assert User.get_current(False) assert User.get_current(False)
return kwargs return kwargs
result = await router._listen_update(update, test="PASS") result = await router.update.trigger(update, test="PASS")
assert isinstance(result, dict) assert isinstance(result, dict)
assert result["event_update"] == update assert result["event_update"] == update
assert result["event_router"] == router assert result["event_router"] == router
@ -313,8 +328,7 @@ class TestRouter:
async def handler(event: Any): async def handler(event: Any):
pass pass
with pytest.raises(SkipHandler): response = await router._listen_update(
await router._listen_update(
Update( Update(
update_id=42, update_id=42,
poll=Poll( poll=Poll(
@ -333,6 +347,7 @@ class TestRouter:
), ),
) )
) )
assert response is NOT_HANDLED
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_nested_router_listen_update(self): async def test_nested_router_listen_update(self):
@ -345,8 +360,6 @@ class TestRouter:
@observer() @observer()
async def my_handler(event: Message, **kwargs: Any): async def my_handler(event: Message, **kwargs: Any):
assert Chat.get_current(False)
assert User.get_current(False)
return kwargs return kwargs
update = Update( update = Update(
@ -409,14 +422,6 @@ class TestRouter:
await router1.emit_shutdown() await router1.emit_shutdown()
assert results == [2, 1, 2] assert results == [2, 1, 2]
def test_use(self):
router = Router()
middleware = router.use(BaseMiddleware())
assert isinstance(middleware, BaseMiddleware)
assert middleware.configured
assert middleware.manager == router.middleware
def test_skip(self): def test_skip(self):
with pytest.raises(SkipHandler): with pytest.raises(SkipHandler):
skip() skip()
@ -444,37 +449,20 @@ class TestRouter:
), ),
) )
with pytest.raises(Exception, match="KABOOM"): with pytest.raises(Exception, match="KABOOM"):
await root_router.listen_update( await root_router.update.trigger(update)
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
@root_router.errors() @root_router.errors()
async def root_error_handler(exception: Exception): async def root_error_handler(event: Update, exception: Exception):
return exception return exception
response = await root_router.listen_update( response = await root_router.update.trigger(update)
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
assert isinstance(response, Exception) assert isinstance(response, Exception)
assert str(response) == "KABOOM" assert str(response) == "KABOOM"
@router.errors() @router.errors()
async def error_handler(exception: Exception): async def error_handler(event: Update, exception: Exception):
return "KABOOM" return "KABOOM"
response = await root_router.listen_update( response = await root_router.update.trigger(update)
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
assert response == "KABOOM" assert response == "KABOOM"