Merge pull request #337 from aiogram/dev-3.x-middlewares

Implement new middlewares
This commit is contained in:
Alex Root Junior 2020-05-26 22:06:21 +03:00 committed by GitHub
commit a627c75bab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 618 additions and 1434 deletions

View file

@ -6,7 +6,6 @@ python := $(py) python
reports_dir := reports
.PHONY: help
help:
@echo "======================================================================================="
@echo " aiogram build tools "
@ -45,12 +44,10 @@ help:
# Environment
# =================================================================================================
.PHONY: install
install:
$(base_python) -m pip install --user -U poetry
poetry install
.PHONY: clean
clean:
rm -rf `find . -name __pycache__`
rm -f `find . -type f -name '*.py[co]' `
@ -68,65 +65,56 @@ clean:
# Code quality
# =================================================================================================
.PHONY: isort
isort:
$(py) isort -rc aiogram tests
.PHONY: black
black:
$(py) black aiogram tests
.PHONY: flake8
flake8:
$(py) flake8 aiogram test
.PHONY: flake8-report
flake8-report:
mkdir -p $(reports_dir)/flake8
$(py) flake8 --format=html --htmldir=$(reports_dir)/flake8 aiogram test
.PHONY: mypy
mypy:
$(py) mypy aiogram
.PHONY: mypy-report
mypy-report:
$(py) mypy aiogram --html-report $(reports_dir)/typechecking
.PHONY: lint
lint: isort black flake8 mypy
# =================================================================================================
# Tests
# =================================================================================================
.PHONY: test
test:
$(py) pytest --cov=aiogram --cov-config .coveragerc tests/
.PHONY: test-coverage
test-coverage:
mkdir -p $(reports_dir)/tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/
test-coverage-report:
$(py) coverage html -d $(reports_dir)/coverage
.PHONY: test-coverage-report
test-coverage-report:
test-coverage-view:
$(py) coverage html -d $(reports_dir)/coverage
python -c "import webbrowser; webbrowser.open('file://$(shell pwd)/reports/coverage/index.html')"
# =================================================================================================
# Docs
# =================================================================================================
.PHONY: docs
docs:
$(py) mkdocs build
.PHONY: docs-serve
docs-serve:
$(py) mkdocs serve
.PHONY: docs-copy-reports
docs-copy-reports:
mv $(reports_dir)/* site/reports
@ -134,9 +122,7 @@ docs-copy-reports:
# Project
# =================================================================================================
.PHONY: build
build: clean flake8-report mypy-report test-coverage docs docs-copy-reports
mkdir -p site/simple
poetry build
mv dist site/simple/aiogram

View file

@ -1,3 +1,5 @@
from pkg_resources import get_distribution
from .api import methods, types
from .api.client import session
from .api.client.bot import Bot
@ -28,5 +30,5 @@ __all__ = (
"handler",
)
__version__ = "3.0.0a4"
__version__ = get_distribution(dist=__package__).version
__api_version__ = "4.8"

View file

@ -11,6 +11,8 @@ from ..api.client.bot import Bot
from ..api.methods import TelegramMethod
from ..api.types import Update, User
from ..utils.exceptions import TelegramAPIError
from .event.bases import NOT_HANDLED
from .middlewares.user_context import UserContextMiddleware
from .router import Router
@ -23,6 +25,9 @@ class Dispatcher(Router):
super(Dispatcher, self).__init__(**kwargs)
self._running_lock = Lock()
# Default middleware is needed for contextual features
self.update.outer_middleware(UserContextMiddleware())
@property
def parent_router(self) -> None:
"""
@ -42,9 +47,7 @@ class Dispatcher(Router):
"""
raise RuntimeError("Dispatcher can not be attached to another Router.")
async def feed_update(
self, bot: Bot, update: Update, **kwargs: Any
) -> AsyncGenerator[Any, None]:
async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
"""
Main entry point for incoming updates
@ -57,9 +60,9 @@ class Dispatcher(Router):
Bot.set_current(bot)
try:
async for result in self.update.trigger(update, bot=bot, **kwargs):
handled = True
yield result
response = await self.update.trigger(update, bot=bot, **kwargs)
handled = response is not NOT_HANDLED
return response
finally:
finish_time = loop.time()
duration = (finish_time - start_time) * 1000
@ -71,9 +74,7 @@ class Dispatcher(Router):
bot.id,
)
async def feed_raw_update(
self, bot: Bot, update: Dict[str, Any], **kwargs: Any
) -> AsyncGenerator[Any, None]:
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
"""
Main entry point for incoming updates with automatic Dict->Update serializer
@ -82,8 +83,7 @@ class Dispatcher(Router):
:param kwargs:
"""
parsed_update = Update(**update)
async for result in self.feed_update(bot=bot, update=parsed_update, **kwargs):
yield result
return await self.feed_update(bot=bot, update=parsed_update, **kwargs)
@classmethod
async def _listen_updates(cls, bot: Bot) -> AsyncGenerator[Update, None]:
@ -114,7 +114,7 @@ class Dispatcher(Router):
# For debugging here is added logging.
loggers.dispatcher.error("Failed to make answer: %s: %s", e.__class__.__name__, e)
async def process_update(
async def _process_update(
self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any
) -> bool:
"""
@ -126,11 +126,13 @@ class Dispatcher(Router):
:param kwargs: contextual data for middlewares, filters and handlers
:return: status
"""
handled = False
try:
async for result in self.feed_update(bot, update, **kwargs):
if call_answer and isinstance(result, TelegramMethod):
await self._silent_call_request(bot=bot, result=result)
return True
response = await self.feed_update(bot, update, **kwargs)
handled = handled is not NOT_HANDLED
if call_answer and isinstance(response, TelegramMethod):
await self._silent_call_request(bot=bot, result=response)
return handled
except Exception as e:
loggers.dispatcher.exception(
@ -142,8 +144,6 @@ class Dispatcher(Router):
)
return True # because update was processed but unsuccessful
return False
async def _polling(self, bot: Bot, **kwargs: Any) -> None:
"""
Internal polling process
@ -153,16 +153,14 @@ class Dispatcher(Router):
:return:
"""
async for update in self._listen_updates(bot):
await self.process_update(bot=bot, update=update, **kwargs)
await self._process_update(bot=bot, update=update, **kwargs)
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
"""
The same with `Dispatcher.process_update()` but returns real response instead of bool
"""
try:
async for result in self.feed_update(bot, update, **kwargs):
return result
return await self.feed_update(bot, update, **kwargs)
except Exception as e:
loggers.dispatcher.exception(
"Cause exception while process update id=%d by bot id=%d\n%s: %s",
@ -196,10 +194,10 @@ class Dispatcher(Router):
def process_response(task: Future[Any]) -> None:
warnings.warn(
f"Detected slow response into webhook.\n"
f"Telegram is waiting for response only first 60 seconds and then re-send update.\n"
f"For preventing this situation response into webhook returned immediately "
f"and handler is moved to background and still processing update.",
"Detected slow response into webhook.\n"
"Telegram is waiting for response only first 60 seconds and then re-send update.\n"
"For preventing this situation response into webhook returned immediately "
"and handler is moved to background and still processing update.",
RuntimeWarning,
)
try:

View file

@ -0,0 +1,29 @@
from __future__ import annotations
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
from unittest.mock import sentinel
from ...api.types import TelegramObject
from ..middlewares.base import BaseMiddleware
NextMiddlewareType = Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]]
MiddlewareType = Union[
BaseMiddleware, Callable[[NextMiddlewareType, TelegramObject, Dict[str, Any]], Awaitable[Any]]
]
NOT_HANDLED = sentinel.NOT_HANDLED
class SkipHandler(Exception):
pass
class CancelHandler(Exception):
pass
def skip(message: Optional[str] = None) -> NoReturn:
"""
Raise an SkipHandler
"""
raise SkipHandler(message or "Event skipped")

View file

@ -0,0 +1,39 @@
from __future__ import annotations
from typing import Any, Callable, List
from .handler import CallbackType, HandlerObject, HandlerType
class EventObserver:
"""
Simple events observer
"""
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType) -> None:
"""
Register callback with filters
"""
self.handlers.append(HandlerObject(callback=callback))
async def trigger(self, *args: Any, **kwargs: Any) -> None:
"""
Propagate event to handlers.
Handler will be called when all its filters is pass.
"""
for handler in self.handlers:
await handler.call(*args, **kwargs)
def __call__(self) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback)
return callback
return wrapper

View file

@ -1,3 +1,5 @@
import asyncio
import contextvars
import inspect
from dataclasses import dataclass, field
from functools import partial
@ -6,9 +8,9 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type,
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
CallbackType = Callable[[Any], Awaitable[Any]]
SyncFilter = Callable[[Any], Any]
AsyncFilter = Callable[[Any], Awaitable[Any]]
CallbackType = Callable[..., Awaitable[Any]]
SyncFilter = Callable[..., Any]
AsyncFilter = Callable[..., Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[FilterType, Type[BaseHandler]]
@ -40,7 +42,11 @@ class CallableMixin:
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
if self.awaitable:
return await wrapped()
return wrapped()
loop = asyncio.get_event_loop()
context = contextvars.copy_context()
wrapped = partial(context.run, wrapped)
return await loop.run_in_executor(None, wrapped)
@dataclass
@ -60,11 +66,11 @@ class HandlerObject(CallableMixin):
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
if not self.filters:
return True, {}
return True, kwargs
for event_filter in self.filters:
check = await event_filter.call(*args, **kwargs)
if not check:
return False, {}
return False, kwargs
if isinstance(check, dict):
kwargs.update(check)
return True, kwargs

View file

@ -1,93 +1,33 @@
from __future__ import annotations
import functools
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
List,
NoReturn,
Optional,
Type,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
from pydantic import ValidationError
from ...api.types import TelegramObject
from ..filters.base import BaseFilter
from ..middlewares.types import MiddlewareStep, UpdateType
from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.router import Router
class SkipHandler(Exception):
pass
class CancelHandler(Exception):
pass
def skip(message: Optional[str] = None) -> NoReturn:
"""
Raise an SkipHandler
"""
raise SkipHandler(message or "Event skipped")
class EventObserver:
"""
Base events observer
"""
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType) -> HandlerType:
"""
Register callback with filters
"""
self.handlers.append(HandlerObject(callback=callback))
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
"""
Propagate event to handlers.
Handler will be called when all its filters is pass.
"""
for handler in self.handlers:
try:
yield await handler.call(*args, **kwargs)
except SkipHandler:
continue
def __call__(self) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback)
return callback
return wrapper
class TelegramEventObserver(EventObserver):
class TelegramEventObserver:
"""
Event observer for Telegram events
"""
def __init__(self, router: Router, event_name: str) -> None:
super().__init__()
self.router: Router = router
self.event_name: str = event_name
self.handlers: List[HandlerObject] = []
self.filters: List[Type[BaseFilter]] = []
self.outer_middlewares: List[MiddlewareType] = []
self.middlewares: List[MiddlewareType] = []
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
"""
@ -144,37 +84,6 @@ class TelegramEventObserver(EventObserver):
return filters
async def trigger_middleware(
self, step: MiddlewareStep, event: UpdateType, data: Dict[str, Any], result: Any = None,
) -> None:
"""
Trigger middlewares chain
:param step:
:param event:
:param data:
:param result:
:return:
"""
reverse = step == MiddlewareStep.POST_PROCESS
recursive = self.event_name == "update" or step == MiddlewareStep.PROCESS
if self.event_name == "update":
routers = self.router.chain
else:
routers = self.router.chain_head
for router in routers:
await router.middleware.trigger(
step=step,
event_name=self.event_name,
event=event,
data=data,
result=result,
reverse=reverse,
)
if not recursive:
break
def register(
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
) -> HandlerType:
@ -190,32 +99,39 @@ class TelegramEventObserver(EventObserver):
)
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
@classmethod
def _wrap_middleware(
cls, middlewares: List[MiddlewareType], handler: HandlerType
) -> NextMiddlewareType:
@functools.wraps(handler)
def mapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
return handler(event, **kwargs)
middleware = mapper
for m in reversed(middlewares):
middleware = functools.partial(m, middleware)
return middleware
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
"""
Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass.
"""
event = args[0]
await self.trigger_middleware(step=MiddlewareStep.PRE_PROCESS, event=event, data=kwargs)
wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger)
return await wrapped_outer(event, kwargs)
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
for handler in self.handlers:
result, data = await handler.check(*args, **kwargs)
result, data = await handler.check(event, **kwargs)
if result:
kwargs.update(data)
await self.trigger_middleware(
step=MiddlewareStep.PROCESS, event=event, data=kwargs
)
try:
response = await handler.call(*args, **kwargs)
await self.trigger_middleware(
step=MiddlewareStep.POST_PROCESS,
event=event,
data=kwargs,
result=response,
)
yield response
wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
return await wrapped_inner(event, kwargs)
except SkipHandler:
continue
break
return NOT_HANDLED
def __call__(
self, *args: FilterType, **bound_filters: BaseFilter
@ -229,3 +145,45 @@ class TelegramEventObserver(EventObserver):
return callback
return wrapper
def middleware(
self, middleware: Optional[MiddlewareType] = None,
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
"""
Decorator for registering inner middlewares
Usage:
>>> @<event>.middleware() # via decorator (variant 1)
>>> @<event>.middleware # via decorator (variant 2)
>>> async def my_middleware(handler, event, data): ...
>>> <event>.middleware(middleware) # via method
"""
def wrapper(m: MiddlewareType) -> MiddlewareType:
self.middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)
def outer_middleware(
self, middleware: Optional[MiddlewareType] = None,
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
"""
Decorator for registering outer middlewares
Usage:
>>> @<event>.outer_middleware() # via decorator (variant 1)
>>> @<event>.outer_middleware # via decorator (variant 2)
>>> async def my_middleware(handler, event, data): ...
>>> <event>.outer_middleware(my_middleware) # via method
"""
def wrapper(m: MiddlewareType) -> MiddlewareType:
self.outer_middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)

View file

@ -1,61 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
class AbstractMiddleware(ABC):
"""
Abstract class for middleware.
"""
def __init__(self) -> None:
self._manager: Optional[MiddlewareManager] = None
@property
def manager(self) -> MiddlewareManager:
"""
Instance of MiddlewareManager
"""
if self._manager is None:
raise RuntimeError("Middleware is not configured!")
return self._manager
def setup(self, manager: MiddlewareManager, _stack_level: int = 1) -> AbstractMiddleware:
"""
Mark middleware as configured
:param manager:
:param _stack_level:
:return:
"""
if self.configured:
return manager.setup(self, _stack_level=_stack_level + 1)
self._manager = manager
return self
@property
def configured(self) -> bool:
"""
Check middleware is configured
:return:
"""
return bool(self._manager)
@abstractmethod
async def trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
data: Dict[str, Any],
result: Any = None,
) -> Any: # pragma: no cover
pass

View file

@ -1,317 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Dict
from aiogram.dispatcher.middlewares.abstract import AbstractMiddleware
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.api.types import (
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
)
T = TypeVar("T")
class BaseMiddleware(AbstractMiddleware):
"""
Base class for middleware.
All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message".
"""
async def trigger(
class BaseMiddleware(ABC, Generic[T]):
@abstractmethod
async def __call__(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
handler: Callable[[T, Dict[str, Any]], Awaitable[Any]],
event: T,
data: Dict[str, Any],
result: Any = None,
) -> Any:
"""
Trigger action.
:param step:
:param event_name:
:param event:
:param data:
:param result:
:return:
"""
handler_name = f"on_{step.value}_{event_name}"
handler = getattr(self, handler_name, None)
if not handler:
return None
args = (event, result, data) if step == MiddlewareStep.POST_PROCESS else (event, data)
return await handler(*args)
if TYPE_CHECKING: # pragma: no cover
# =============================================================================================
# Event that triggers before process <event>
# =============================================================================================
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process update
"""
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process message
"""
async def on_pre_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process edited_message
"""
async def on_pre_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process channel_post
"""
async def on_pre_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process edited_channel_post
"""
async def on_pre_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process inline_query
"""
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process chosen_inline_result
"""
async def on_pre_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process callback_query
"""
async def on_pre_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process shipping_query
"""
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process pre_checkout_query
"""
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process poll
"""
async def on_pre_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process poll_answer
"""
async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process error
"""
# =============================================================================================
# Event that triggers on process <event> after filters.
# =============================================================================================
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process update
"""
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process message
"""
async def on_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process edited_message
"""
async def on_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process channel_post
"""
async def on_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process edited_channel_post
"""
async def on_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process inline_query
"""
async def on_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process chosen_inline_result
"""
async def on_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process callback_query
"""
async def on_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process shipping_query
"""
async def on_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process pre_checkout_query
"""
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process poll
"""
async def on_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process poll_answer
"""
async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process error
"""
# =============================================================================================
# Event that triggers after process <event>.
# =============================================================================================
async def on_post_process_update(
self, update: Update, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing update
"""
async def on_post_process_message(
self, message: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing message
"""
async def on_post_process_edited_message(
self, edited_message: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing edited_message
"""
async def on_post_process_channel_post(
self, channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing channel_post
"""
async def on_post_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing edited_channel_post
"""
async def on_post_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing inline_query
"""
async def on_post_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing chosen_inline_result
"""
async def on_post_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing callback_query
"""
async def on_post_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing shipping_query
"""
async def on_post_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing pre_checkout_query
"""
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
"""
Event that triggers after processing poll
"""
async def on_post_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing poll_answer
"""
async def on_post_process_error(
self, exception: Exception, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing error
"""
) -> Any: # pragma: no cover
pass

View file

@ -0,0 +1,31 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
from ...api.types import Update
from ..event.bases import NOT_HANDLED, CancelHandler, SkipHandler
from .base import BaseMiddleware
if TYPE_CHECKING: # pragma: no cover
from ..router import Router
class ErrorsMiddleware(BaseMiddleware[Update]):
def __init__(self, router: Router):
self.router = router
async def __call__(
self,
handler: Callable[[Any, Dict[str, Any]], Awaitable[Any]],
event: Any,
data: Dict[str, Any],
) -> Any:
try:
return await handler(event, data)
except (SkipHandler, CancelHandler): # pragma: no cover
raise
except Exception as e:
response = await self.router.errors.trigger(event, exception=e, **data)
if response is NOT_HANDLED:
raise
return response

View file

@ -1,71 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List
from warnings import warn
from .abstract import AbstractMiddleware
from .types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.router import Router
class MiddlewareManager:
"""
Middleware manager.
"""
def __init__(self, router: Router) -> None:
self.router = router
self.middlewares: List[AbstractMiddleware] = []
def setup(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
"""
Setup middleware
:param middleware:
:param _stack_level:
:return:
"""
if not isinstance(middleware, AbstractMiddleware):
raise TypeError(
f"`middleware` should be instance of BaseMiddleware, not {type(middleware)}"
)
if middleware.configured:
if middleware.manager is self:
warn(
f"Middleware {middleware} is already configured for this Router "
"That's mean re-installing of this middleware has no effect.",
category=RuntimeWarning,
stacklevel=_stack_level + 1,
)
return middleware
raise ValueError(
f"Middleware is already configured for another manager {middleware.manager} "
f"in router {middleware.manager.router}!"
)
self.middlewares.append(middleware)
middleware.setup(self)
return middleware
async def trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
data: Dict[str, Any],
result: Any = None,
reverse: bool = False,
) -> Any:
"""
Call action to middlewares with args lilt.
"""
middlewares = reversed(self.middlewares) if reverse else self.middlewares
for middleware in middlewares:
await middleware.trigger(
step=step, event_name=event_name, event=event, data=data, result=result
)
def __contains__(self, item: AbstractMiddleware) -> bool:
return item in self.middlewares

View file

@ -1,35 +0,0 @@
from __future__ import annotations
from enum import Enum
from typing import Union
from aiogram.api.types import (
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
)
UpdateType = Union[
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
BaseException,
]
class MiddlewareStep(Enum):
PRE_PROCESS = "pre_process"
PROCESS = "process"
POST_PROCESS = "post_process"

View file

@ -0,0 +1,62 @@
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple
from aiogram.api.types import Chat, Update, User
from aiogram.dispatcher.middlewares.base import BaseMiddleware
class UserContextMiddleware(BaseMiddleware[Update]):
async def __call__(
self,
handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]],
event: Update,
data: Dict[str, Any],
) -> Any:
chat, user = self.resolve_event_context(event=event)
with self.context(chat=chat, user=user):
return await handler(event, data)
@contextmanager
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
chat_token = None
user_token = None
if chat:
chat_token = chat.set_current(chat)
if user:
user_token = user.set_current(user)
try:
yield
finally:
if chat and chat_token:
chat.reset_current(chat_token)
if user and user_token:
user.reset_current(user_token)
@classmethod
def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]:
"""
Resolve chat and user instance from Update object
"""
if event.message:
return event.message.chat, event.message.from_user
if event.edited_message:
return event.edited_message.chat, event.edited_message.from_user
if event.channel_post:
return event.channel_post.chat, None
if event.edited_channel_post:
return event.edited_channel_post.chat, None
if event.inline_query:
return None, event.inline_query.from_user
if event.chosen_inline_result:
return None, event.chosen_inline_result.from_user
if event.callback_query:
if event.callback_query.message:
return event.callback_query.message.chat, event.callback_query.from_user
return None, event.callback_query.from_user
if event.shipping_query:
return None, event.shipping_query.from_user
if event.pre_checkout_query:
return None, event.pre_checkout_query.from_user
if event.poll_answer:
return None, event.poll_answer.user
return None, None

View file

@ -3,13 +3,14 @@ from __future__ import annotations
import warnings
from typing import Any, Dict, Generator, List, Optional, Union
from ..api.types import Chat, TelegramObject, Update, User
from ..api.types import TelegramObject, Update
from ..utils.imports import import_module
from ..utils.warnings import CodeHasNoEffect
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
from .event.bases import NOT_HANDLED, SkipHandler
from .event.event import EventObserver
from .event.telegram import TelegramEventObserver
from .filters import BUILTIN_FILTERS
from .middlewares.abstract import AbstractMiddleware
from .middlewares.manager import MiddlewareManager
from .middlewares.error import ErrorsMiddleware
class Router:
@ -44,8 +45,6 @@ class Router:
self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer")
self.errors = TelegramEventObserver(router=self, event_name="error")
self.middleware = MiddlewareManager(router=self)
self.startup = EventObserver()
self.shutdown = EventObserver()
@ -68,6 +67,8 @@ class Router:
# Root handler
self.update.register(self._listen_update)
self.update.outer_middleware(ErrorsMiddleware(self))
# Builtin filters
if use_builtin_filters:
for name, observer in self.observers.items():
@ -94,16 +95,6 @@ class Router:
next(tail) # Skip self
yield from tail
def use(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
"""
Use middleware
:param middleware:
:param _stack_level:
:return:
"""
return self.middleware.setup(middleware, _stack_level=_stack_level + 1)
@property
def parent_router(self) -> Optional[Router]:
return self._parent_router
@ -176,53 +167,40 @@ class Router:
:param kwargs:
:return:
"""
chat: Optional[Chat] = None
from_user: Optional[User] = None
event: TelegramObject
if update.message:
update_type = "message"
from_user = update.message.from_user
chat = update.message.chat
event = update.message
elif update.edited_message:
update_type = "edited_message"
from_user = update.edited_message.from_user
chat = update.edited_message.chat
event = update.edited_message
elif update.channel_post:
update_type = "channel_post"
chat = update.channel_post.chat
event = update.channel_post
elif update.edited_channel_post:
update_type = "edited_channel_post"
chat = update.edited_channel_post.chat
event = update.edited_channel_post
elif update.inline_query:
update_type = "inline_query"
from_user = update.inline_query.from_user
event = update.inline_query
elif update.chosen_inline_result:
update_type = "chosen_inline_result"
from_user = update.chosen_inline_result.from_user
event = update.chosen_inline_result
elif update.callback_query:
update_type = "callback_query"
if update.callback_query.message:
chat = update.callback_query.message.chat
from_user = update.callback_query.from_user
event = update.callback_query
elif update.shipping_query:
update_type = "shipping_query"
from_user = update.shipping_query.from_user
event = update.shipping_query
elif update.pre_checkout_query:
update_type = "pre_checkout_query"
from_user = update.pre_checkout_query.from_user
event = update.pre_checkout_query
elif update.poll:
update_type = "poll"
event = update.poll
elif update.poll_answer:
update_type = "poll_answer"
event = update.poll_answer
else:
warnings.warn(
"Detected unknown update type.\n"
@ -232,76 +210,17 @@ class Router:
)
raise SkipHandler
return await self.listen_update(
update_type=update_type,
update=update,
event=event,
from_user=from_user,
chat=chat,
**kwargs,
)
async def listen_update(
self,
update_type: str,
update: Update,
event: TelegramObject,
from_user: Optional[User] = None,
chat: Optional[Chat] = None,
**kwargs: Any,
) -> Any:
"""
Listen update by current and child routers
:param update_type:
:param update:
:param event:
:param from_user:
:param chat:
:param kwargs:
:return:
"""
user_token = None
if from_user:
user_token = User.set_current(from_user)
chat_token = None
if chat:
chat_token = Chat.set_current(chat)
kwargs.update(event_update=update, event_router=self)
observer = self.observers[update_type]
try:
async for result in observer.trigger(event, update=update, **kwargs):
return result
response = await observer.trigger(event, update=update, **kwargs)
if response is NOT_HANDLED: # Resolve nested routers
for router in self.sub_routers:
try:
return await router.listen_update(
update_type=update_type,
update=update,
event=event,
from_user=from_user,
chat=chat,
**kwargs,
)
except SkipHandler:
response = await router.update.trigger(event=update, **kwargs)
if response is NOT_HANDLED:
continue
raise SkipHandler
except SkipHandler:
raise
except Exception as e:
async for result in self.errors.trigger(e, **kwargs):
return result
raise
finally:
if user_token:
User.reset_current(user_token)
if chat_token:
Chat.reset_current(chat_token)
return response
async def emit_startup(self, *args: Any, **kwargs: Any) -> None:
"""
@ -312,8 +231,7 @@ class Router:
:return:
"""
kwargs.update(router=self)
async for _ in self.startup.trigger(*args, **kwargs): # pragma: no cover
pass
await self.startup.trigger(*args, **kwargs)
for router in self.sub_routers:
await router.emit_startup(*args, **kwargs)
@ -326,8 +244,7 @@ class Router:
:return:
"""
kwargs.update(router=self)
async for _ in self.shutdown.trigger(*args, **kwargs): # pragma: no cover
pass
await self.shutdown.trigger(*args, **kwargs)
for router in self.sub_routers:
await router.emit_shutdown(*args, **kwargs)

View file

@ -1,9 +1,10 @@
from __future__ import annotations
import contextvars
from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload
from typing_extensions import Literal
if TYPE_CHECKING: # pragma: no cover
from typing_extensions import Literal
__all__ = ("ContextInstanceMixin", "DataMixin")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

After

Width:  |  Height:  |  Size: 51 KiB

Before After
Before After

View file

@ -39,7 +39,7 @@ dp.include_router(router1)
## Handling updates
All updates can be propagated to the dispatcher by `feed_update` method:
```
```python3
bot = Bot(...)
dp = Dispathcher()

View file

@ -0,0 +1,95 @@
# Middlewares
**aiogram** provides powerful mechanism for customizing event handlers via middlewares.
Middlewares in bot framework seems like Middlewares mechanism in web-frameworks
(like [aiohttp](https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-middlewares),
[fastapi](https://fastapi.tiangolo.com/tutorial/middleware/),
[Django](https://docs.djangoproject.com/en/3.0/topics/http/middleware/) or etc.)
with small difference - here is implemented two layers of middlewares (before and after filters).
!!! info
Middleware is function that triggered on every event received from
Telegram Bot API in many points on processing pipeline.
## Base theory
As many books and other literature in internet says:
> Middleware is reusable software that leverages patterns and frameworks to bridge
> the gap between the functional requirements of applications and the underlying operating systems,
> network protocol stacks, and databases.
Middleware can modify, extend or reject processing event in many places of pipeline.
## Basics
Middleware instance can be applied for every type of Telegram Event (Update, Message, etc.) in two places
1. Outer scope - before processing filters (`#!python <router>.<event>.outer_middleware(...)`)
2. Inner scope - after processing filters but before handler (`#!python <router>.<event>.middleware(...)`)
[![middlewares](../assets/images/basics_middleware.png)](../assets/images/basics_middleware.png)
_(Click on image to zoom it)_
!!! warning
Middleware should be subclass of `BaseMiddleware` (`#!python3 from aiogram import BaseMiddleware`) or any async callable
## Arguments specification
| Argument | Type | Description |
| - | - | - |
| `handler` | `#!python Callable[[T, Dict[str, Any]], Awaitable[Any]]` | Wrapped handler in middlewares chain |
| `event` | `#!python T` | Incoming event (Subclass of `TelegramObject`) |
| `data` | `#!python Dict[str, Any]` | Contextual data. Will be mapped to handler arguments |
## Examples
!!! danger
Middleware should always call `#!python await handler(event, data)` to propagate event for next middleware/handler
### Class-based
```python3
from aiogram import BaseMiddleware
from aiogram.api.types import Message
class CounterMiddleware(BaseMiddleware[Message]):
def __init__(self) -> None:
self.counter = 0
async def __call__(
self,
handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]],
event: Message,
data: Dict[str, Any]
) -> Any:
self.counter += 1
data['counter'] = self.counter
return await handler(event, data)
```
and then
```python3
router = Router()
router.message.middleware(CounterMiddleware())
```
### Function-based
```python3
@dispatcher.update.outer_middleware()
async def database_transaction_middleware(
handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]],
event: Update,
data: Dict[str, Any]
) -> Any:
async with database.transaction():
return await handler(event, data)
```
## Facts
1. Middlewares from outer scope will be called on every incoming event
1. Middlewares from inner scope will be called only when filters pass
1. Inner middlewares is always calls for `Update` event type in due to all incoming updates going to specific event type handler through built in update handler

View file

@ -1,115 +0,0 @@
# Basics
All middlewares should be made with `BaseMiddleware` (`#!python3 from aiogram import BaseMiddleware`) as base class.
For example:
```python3
class MyMiddleware(BaseMiddleware): ...
```
And then use next pattern in naming callback functions in middleware: `on_{step}_{event}`
Where is:
- `#!python3 step`:
- `#!python3 pre_process`
- `#!python3 process`
- `#!python3 post_process`
- `#!python3 event`:
- `#!python3 update`
- `#!python3 message`
- `#!python3 edited_message`
- `#!python3 channel_post`
- `#!python3 edited_channel_post`
- `#!python3 inline_query`
- `#!python3 chosen_inline_result`
- `#!python3 callback_query`
- `#!python3 shipping_query`
- `#!python3 pre_checkout_query`
- `#!python3 poll`
- `#!python3 poll_answer`
- `#!python3 error`
## Connecting middleware with router
Middlewares can be connected with router by next ways:
1. `#!python3 router.use(MyMiddleware())` (**recommended**)
1. `#!python3 router.middleware.setup(MyMiddleware())`
1. `#!python3 MyMiddleware().setup(router.middleware)` (**not recommended**)
!!! warning
One instance of middleware **can't** be registered twice in single or many middleware managers
## The specification of step callbacks
### Pre-process step
| Argument | Type | Description |
| --- | --- | --- |
| event name | Any of event type (Update, Message and etc.) | Event |
| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) |
Returns `#!python3 Any`
### Process step
| Argument | Type | Description |
| --- | --- | --- |
| event name | Any of event type (Update, Message and etc.) | Event |
| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) |
Returns `#!python3 Any`
### Post-Process step
| Argument | Type | Description |
| --- | --- | --- |
| event name | Any of event type (Update, Message and etc.) | Event |
| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) |
| `#!python3 result` | `#!python3 Dict[str, Any]` | Response from handlers |
Returns `#!python3 Any`
## Full list of available callbacks
- `#!python3 on_pre_process_update` - will be triggered on **pre process** `#!python3 update` event
- `#!python3 on_process_update` - will be triggered on **process** `#!python3 update` event
- `#!python3 on_post_process_update` - will be triggered on **post process** `#!python3 update` event
- `#!python3 on_pre_process_message` - will be triggered on **pre process** `#!python3 message` event
- `#!python3 on_process_message` - will be triggered on **process** `#!python3 message` event
- `#!python3 on_post_process_message` - will be triggered on **post process** `#!python3 message` event
- `#!python3 on_pre_process_edited_message` - will be triggered on **pre process** `#!python3 edited_message` event
- `#!python3 on_process_edited_message` - will be triggered on **process** `#!python3 edited_message` event
- `#!python3 on_post_process_edited_message` - will be triggered on **post process** `#!python3 edited_message` event
- `#!python3 on_pre_process_channel_post` - will be triggered on **pre process** `#!python3 channel_post` event
- `#!python3 on_process_channel_post` - will be triggered on **process** `#!python3 channel_post` event
- `#!python3 on_post_process_channel_post` - will be triggered on **post process** `#!python3 channel_post` event
- `#!python3 on_pre_process_edited_channel_post` - will be triggered on **pre process** `#!python3 edited_channel_post` event
- `#!python3 on_process_edited_channel_post` - will be triggered on **process** `#!python3 edited_channel_post` event
- `#!python3 on_post_process_edited_channel_post` - will be triggered on **post process** `#!python3 edited_channel_post` event
- `#!python3 on_pre_process_inline_query` - will be triggered on **pre process** `#!python3 inline_query` event
- `#!python3 on_process_inline_query` - will be triggered on **process** `#!python3 inline_query` event
- `#!python3 on_post_process_inline_query` - will be triggered on **post process** `#!python3 inline_query` event
- `#!python3 on_pre_process_chosen_inline_result` - will be triggered on **pre process** `#!python3 chosen_inline_result` event
- `#!python3 on_process_chosen_inline_result` - will be triggered on **process** `#!python3 chosen_inline_result` event
- `#!python3 on_post_process_chosen_inline_result` - will be triggered on **post process** `#!python3 chosen_inline_result` event
- `#!python3 on_pre_process_callback_query` - will be triggered on **pre process** `#!python3 callback_query` event
- `#!python3 on_process_callback_query` - will be triggered on **process** `#!python3 callback_query` event
- `#!python3 on_post_process_callback_query` - will be triggered on **post process** `#!python3 callback_query` event
- `#!python3 on_pre_process_shipping_query` - will be triggered on **pre process** `#!python3 shipping_query` event
- `#!python3 on_process_shipping_query` - will be triggered on **process** `#!python3 shipping_query` event
- `#!python3 on_post_process_shipping_query` - will be triggered on **post process** `#!python3 shipping_query` event
- `#!python3 on_pre_process_pre_checkout_query` - will be triggered on **pre process** `#!python3 pre_checkout_query` event
- `#!python3 on_process_pre_checkout_query` - will be triggered on **process** `#!python3 pre_checkout_query` event
- `#!python3 on_post_process_pre_checkout_query` - will be triggered on **post process** `#!python3 pre_checkout_query` event
- `#!python3 on_pre_process_poll` - will be triggered on **pre process** `#!python3 poll` event
- `#!python3 on_process_poll` - will be triggered on **process** `#!python3 poll` event
- `#!python3 on_post_process_poll` - will be triggered on **post process** `#!python3 poll` event
- `#!python3 on_pre_process_poll_answer` - will be triggered on **pre process** `#!python3 poll_answer` event
- `#!python3 on_process_poll_answer` - will be triggered on **process** `#!python3 poll_answer` event
- `#!python3 on_post_process_poll_answer` - will be triggered on **post process** `#!python3 poll_answer` event
- `#!python3 on_pre_process_error` - will be triggered on **pre process** `#!python3 error` event
- `#!python3 on_process_error` - will be triggered on **process** `#!python3 error` event
- `#!python3 on_post_process_error` - will be triggered on **post process** `#!python3 error` event

View file

@ -1,77 +0,0 @@
# Overview
**aiogram** provides powerful mechanism for customizing event handlers via middlewares.
Middlewares in bot framework seems like Middlewares mechanism in web-frameworks
(like [aiohttp](https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-middlewares),
[fastapi](https://fastapi.tiangolo.com/tutorial/middleware/),
[Django](https://docs.djangoproject.com/en/3.0/topics/http/middleware/) or etc.)
with small difference - here is implemented many layers of processing
(named as [pipeline](#event-pipeline)).
!!! info
Middleware is function that triggered on every event received from
Telegram Bot API in many points on processing pipeline.
## Base theory
As many books and other literature in internet says:
> Middleware is reusable software that leverages patterns and frameworks to bridge
>the gap between the functional requirements of applications and the underlying operating systems,
> network protocol stacks, and databases.
Middleware can modify, extend or reject processing event before-,
on- or after- processing of that event.
[![middlewares](../../assets/images/basics_middleware.png)](../../assets/images/basics_middleware.png)
_(Click on image to zoom it)_
## Event pipeline
As described below middleware an interact with event in many stages of pipeline.
Simple workflow:
1. Dispatcher receive an [Update](../../api/types/update.md)
1. Call **pre-process** update middleware in all routers tree
1. Filter Update over handlers
1. Call **process** update middleware in all routers tree
1. Router detects event type (Message, Callback query, etc.)
1. Router triggers **pre-process** <event> middleware of specific type
1. Pass event over [filters](../filters/index.md) to detect specific handler
1. Call **process** <event> middleware for specific type (only when handler for this event exists)
1. *Do magick*. Call handler (Read more [Event observers](../router.md#event-observers))
1. Call **post-process** <event> middleware
1. Call **post-process** update middleware in all routers tree
1. Emit response into webhook (when it needed)
!!! warning
When filters does not match any handler with this event the `#!python3 process`
step will not be called.
!!! warning
When exception will be caused in handlers pipeline will be stopped immediately
and then start processing error via errors handler and it own middleware callbacks.
!!! warning
Middlewares for updates will be called for all routers in tree but callbacks for events
will be called only for specific branch of routers.
### Pipeline in pictures:
#### Simple pipeline
[![middlewares](../../assets/images/middleware_pipeline.png)](../../assets/images/middleware_pipeline.png)
_(Click on image to zoom it)_
#### Nested routers pipeline
[![middlewares](../../assets/images/middleware_pipeline_nested.png)](../../assets/images/middleware_pipeline_nested.png)
_(Click on image to zoom it)_
## Read more
- [Middleware Basics](basics.md)

View file

@ -20,7 +20,7 @@ Documentation for version 3.0 [WIP] [^1]
- [Supports Telegram Bot API v{!_api_version.md!}](api/index.md)
- [Updates router](dispatcher/index.md) (Blueprints)
- Finite State Machine
- [Middlewares](dispatcher/middlewares/index.md)
- [Middlewares](dispatcher/middlewares.md)
- [Replies into Webhook](https://core.telegram.org/bots/faq#how-can-i-make-requests-in-response-to-updates)

View file

@ -0,0 +1,12 @@
@font-face {
font-family: 'JetBrainsMono';
src: url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/web/woff2/JetBrainsMono-Regular.woff2') format('woff2'),
url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/web/woff/JetBrainsMono-Regular.woff') format('woff'),
url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/ttf/JetBrainsMono-Regular.ttf') format('truetype');
font-weight: 400;
font-style: normal;
}
code, kbd, pre {
font-family: "JetBrainsMono", "Roboto Mono", "Courier New", Courier, monospace;
}

View file

@ -16,6 +16,9 @@ theme:
favicon: 'assets/images/logo.png'
logo: 'assets/images/logo.png'
extra_css:
- stylesheets/extra.css
extra:
version: 3.0.0a3
@ -255,9 +258,8 @@ nav:
- dispatcher/class_based_handlers/pre_checkout_query.md
- dispatcher/class_based_handlers/shipping_query.md
- dispatcher/class_based_handlers/error.md
- Middlewares:
- dispatcher/middlewares/index.md
- dispatcher/middlewares/basics.md
- dispatcher/middlewares.md
- todo.md
- Build reports:
- reports.md

12
poetry.lock generated
View file

@ -884,7 +884,7 @@ category = "dev"
description = "YAML parser and emitter for Python"
name = "pyyaml"
optional = false
python-versions = "*"
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "5.3.1"
[[package]]
@ -955,7 +955,7 @@ python-versions = "*"
version = "1.4.1"
[[package]]
category = "main"
category = "dev"
description = "Backported and Experimental Type Hints for Python 3.5+"
name = "typing-extensions"
optional = false
@ -1031,8 +1031,7 @@ fast = ["uvloop"]
proxy = ["aiohttp-socks"]
[metadata]
content-hash = "57137b60a539ba01e8df533db976e2f3eadec37e717cbefbe775dc021a8c2714"
content-hash = "768759359beca8b84811bfc21adac9649925cd22b87427a10608c9d1e16a0923"
python-versions = "^3.7"
[metadata.files]
@ -1284,11 +1283,6 @@ markupsafe = [
{file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"},
{file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"},
{file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"},
]
mccabe = [

View file

@ -40,7 +40,6 @@ aiofiles = "^0.4.0"
uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'", optional = true}
async_lru = "^1.0"
aiohttp-socks = {version = "^0.3.8", optional = true}
typing-extensions = "^3.7.4"
[tool.poetry.dev-dependencies]
uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'"}
@ -68,6 +67,7 @@ markdown-include = "^0.5.1"
aiohttp-socks = "^0.3.4"
pre-commit = "^2.3.0"
packaging = "^20.3"
typing-extensions = "^3.7.4"
[tool.poetry.extras]
fast = ["uvloop"]
@ -94,7 +94,7 @@ include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
line_length = 99
known_third_party = ["aiofiles", "aiohttp", "aiohttp_socks", "aresponses", "async_lru", "packaging", "pydantic", "pytest", "typing_extensions"]
known_third_party = ["aiofiles", "aiohttp", "aiohttp_socks", "aresponses", "async_lru", "packaging", "pkg_resources", "pydantic", "pytest"]
[build-system]
requires = ["poetry>=0.12"]

View file

@ -11,7 +11,13 @@ class TestInlineQuery:
offset="",
)
kwargs = dict(results=[], cache_time=123, next_offset="123", switch_pm_text="foo", switch_pm_parameter="foo")
kwargs = dict(
results=[],
cache_time=123,
next_offset="123",
switch_pm_text="foo",
switch_pm_parameter="foo",
)
api_method = inline_query.answer(**kwargs)

View file

@ -1,5 +1,5 @@
from aiogram.api.methods import AnswerShippingQuery
from aiogram.api.types import ShippingAddress, ShippingQuery, User, ShippingOption, LabeledPrice
from aiogram.api.types import LabeledPrice, ShippingAddress, ShippingOption, ShippingQuery, User
class TestInlineQuery:
@ -19,7 +19,8 @@ class TestInlineQuery:
)
shipping_options = [
ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)])]
ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)])
]
kwargs = dict(ok=True, shipping_options=shipping_options, error_message="foo")

View file

@ -1,6 +1,6 @@
import pytest
from aiogram.dispatcher.event.observer import TelegramEventObserver
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.router import Router
from tests.deprecated import check_deprecated

View file

@ -9,6 +9,7 @@ from aiogram import Bot
from aiogram.api.methods import GetMe, GetUpdates, SendMessage
from aiogram.api.types import Chat, Message, Update, User
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NOT_HANDLED
from aiogram.dispatcher.router import Router
from tests.mocked_bot import MockedBot
@ -63,7 +64,7 @@ class TestDispatcher:
return message.text
results_count = 0
async for result in dp.feed_update(
result = await dp.feed_update(
bot=bot,
update=Update(
update_id=42,
@ -75,11 +76,9 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"),
),
),
):
results_count += 1
assert result == "test"
assert results_count == 1
)
results_count += 1
assert result == "test"
@pytest.mark.asyncio
async def test_feed_raw_update(self):
@ -91,8 +90,7 @@ class TestDispatcher:
assert message.text == "test"
return message.text
handled = False
async for result in dp.feed_raw_update(
result = await dp.feed_raw_update(
bot=bot,
update={
"update_id": 42,
@ -101,13 +99,11 @@ class TestDispatcher:
"date": int(time.time()),
"text": "test",
"chat": {"id": 42, "type": "private"},
"user": {"id": 42, "is_bot": False, "first_name": "Test"},
"from": {"id": 42, "is_bot": False, "first_name": "Test"},
},
},
):
handled = True
assert result == "test"
assert handled
)
assert result == "test"
@pytest.mark.asyncio
async def test_listen_updates(self, bot: MockedBot):
@ -136,7 +132,8 @@ class TestDispatcher:
async def test_process_update_empty(self, bot: MockedBot):
dispatcher = Dispatcher()
assert not await dispatcher.process_update(bot=bot, update=Update(update_id=42))
result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
assert result
@pytest.mark.asyncio
async def test_process_update_handled(self, bot: MockedBot):
@ -146,22 +143,25 @@ class TestDispatcher:
async def update_handler(update: Update):
pass
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
@pytest.mark.asyncio
async def test_process_update_call_request(self, bot: MockedBot):
dispatcher = Dispatcher()
@dispatcher.update()
async def update_handler(update: Update):
async def message_handler(update: Update):
return GetMe()
dispatcher.update.handlers.reverse()
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher._silent_call_request",
new_callable=CoroutineMock,
) as mocked_silent_call_request:
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
mocked_silent_call_request.assert_awaited_once()
result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
print(result)
mocked_silent_call_request.assert_awaited()
@pytest.mark.asyncio
async def test_process_update_exception(self, bot: MockedBot, caplog):
@ -171,7 +171,7 @@ class TestDispatcher:
async def update_handler(update: Update):
raise Exception("Kaboom!")
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
log_records = [rec.message for rec in caplog.records]
assert len(log_records) == 1
assert "Cause exception while process update" in log_records[0]
@ -184,7 +184,7 @@ class TestDispatcher:
yield Update(update_id=42)
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
) as mocked_process_update, patch(
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
) as patched_listen_updates:
@ -203,7 +203,7 @@ class TestDispatcher:
yield Update(update_id=42)
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
) as mocked_process_update, patch(
"aiogram.dispatcher.router.Router.emit_startup", new_callable=CoroutineMock
) as mocked_emit_startup, patch(

View file

@ -0,0 +1,59 @@
import functools
from typing import Any
import pytest
from aiogram.dispatcher.event.event import EventObserver
from aiogram.dispatcher.event.handler import HandlerObject
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
async def my_handler(value: str, index: int = 0) -> Any:
return value
class TestEventObserver:
@pytest.mark.parametrize("via_decorator", [True, False])
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, via_decorator, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
if via_decorator:
register_result = observer()(wrapped_handler)
assert register_result == wrapped_handler
else:
register_result = observer.register(wrapped_handler)
assert register_result is None
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger(self):
observer = EventObserver()
observer.register(my_handler)
observer.register(lambda e: True)
observer.register(my_handler)
assert observer.handlers[0].awaitable
assert not observer.handlers[1].awaitable
assert observer.handlers[2].awaitable
with patch(
"aiogram.dispatcher.event.handler.CallableMixin.call", new_callable=CoroutineMock,
) as mocked_my_handler:
results = await observer.trigger("test")
assert results is None
mocked_my_handler.assert_awaited_with("test")
assert mocked_my_handler.call_count == 3

View file

@ -5,11 +5,14 @@ from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
import pytest
from aiogram.api.types import Chat, Message, User
from aiogram.dispatcher.event.bases import SkipHandler
from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.observer import EventObserver, SkipHandler, TelegramEventObserver
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.router import Router
# TODO: Test middlewares in routers tree
async def my_handler(event: Any, index: int = 0) -> Any:
return event
@ -38,54 +41,6 @@ class MyFilter3(MyFilter1):
pass
class TestEventObserver:
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer.register(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters_via_decorator(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer()(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger_accepted_bool(self):
observer = EventObserver()
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42]
@pytest.mark.asyncio
async def test_trigger_with_skip(self):
observer = EventObserver()
observer.register(skip_my_handler)
observer.register(my_handler)
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42, 42]
class TestTelegramEventObserver:
def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test")
@ -198,8 +153,8 @@ class TestTelegramEventObserver:
from_user=User(id=42, is_bot=False, first_name="Test"),
)
results = [result async for result in observer.trigger(message)]
assert results == [message]
results = await observer.trigger(message)
assert results is message
@pytest.mark.parametrize(
"count,handler,filters",
@ -223,15 +178,58 @@ class TestTelegramEventObserver:
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
#
@pytest.mark.asyncio
async def test_trigger_right_context_in_handlers(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.register(
pipe_handler, lambda event: {"a": 1}, lambda event: False
) # {"a": 1} should not be in result
observer.register(pipe_handler, lambda event: {"b": 2})
results = [result async for result in observer.trigger(42)]
assert results == [((42,), {"b": 2})]
async def mix_unnecessary_data(event):
return {"a": 1}
async def mix_data(event):
return {"b": 2}
async def handler(event, **kwargs):
return False
observer.register(
pipe_handler, mix_unnecessary_data, handler
) # {"a": 1} should not be in result
observer.register(pipe_handler, mix_data)
results = await observer.trigger(42)
assert results == ((42,), {"b": 2})
@pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware"))
def test_register_middleware(self, middleware_type):
event_observer = TelegramEventObserver(Router(), "test")
middlewares = getattr(event_observer, f"{middleware_type}s")
decorator = getattr(event_observer, middleware_type)
@decorator
async def my_middleware1(handler, event, data):
pass
assert my_middleware1 is not None
assert my_middleware1.__name__ == "my_middleware1"
assert my_middleware1 in middlewares
@decorator()
async def my_middleware2(handler, event, data):
pass
assert my_middleware2 is not None
assert my_middleware2.__name__ == "my_middleware2"
assert my_middleware2 in middlewares
async def my_middleware3(handler, event, data):
pass
decorator(my_middleware3)
assert my_middleware3 is not None
assert my_middleware3.__name__ == "my_middleware3"
assert my_middleware3 in middlewares
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]

View file

@ -1,257 +0,0 @@
import datetime
from typing import Any, Dict, Type
import pytest
from aiogram.api.types import (
CallbackQuery,
Chat,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
User,
)
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
class MyMiddleware(BaseMiddleware):
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_pre_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_pre_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
return "channel_post"
async def on_pre_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_pre_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_pre_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_pre_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_pre_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
return "poll_answer"
async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
return "error"
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any:
return "channel_post"
async def on_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any:
return "poll_answer"
async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
return "error"
async def on_post_process_update(
self, update: Update, data: Dict[str, Any], result: Any
) -> Any:
return "update"
async def on_post_process_message(
self, message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "message"
async def on_post_process_edited_message(
self, edited_message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_message"
async def on_post_process_channel_post(
self, channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "channel_post"
async def on_post_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_channel_post"
async def on_post_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
) -> Any:
return "inline_query"
async def on_post_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
) -> Any:
return "chosen_inline_result"
async def on_post_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
) -> Any:
return "callback_query"
async def on_post_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
) -> Any:
return "shipping_query"
async def on_post_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
) -> Any:
return "pre_checkout_query"
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
return "poll"
async def on_post_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
) -> Any:
return "poll_answer"
async def on_post_process_error(
self, exception: Exception, data: Dict[str, Any], result: Any
) -> Any:
return "error"
UPDATE = Update(update_id=42)
MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private"))
POLL_ANSWER = PollAnswer(
poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0]
)
class TestBaseMiddleware:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]]
)
@pytest.mark.parametrize(
"step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS]
)
@pytest.mark.parametrize(
"event_name,event",
[
["update", UPDATE],
["message", MESSAGE],
["poll_answer", POLL_ANSWER],
["error", Exception("KABOOM")],
],
)
async def test_trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
middleware_cls: Type[BaseMiddleware],
should_be_awaited: bool,
):
middleware = middleware_cls()
with patch(
f"tests.test_dispatcher.test_middlewares.test_base."
f"MyMiddleware.on_{step.value}_{event_name}",
new_callable=CoroutineMock,
) as mocked_call:
response = await middleware.trigger(
step=step, event_name=event_name, event=event, data={}
)
if should_be_awaited:
mocked_call.assert_awaited()
assert response is not None
else:
mocked_call.assert_not_awaited()
assert response is None
def test_not_configured(self):
middleware = BaseMiddleware()
assert not middleware.configured
with pytest.raises(RuntimeError):
manager = middleware.manager

View file

@ -1,82 +0,0 @@
import pytest
from aiogram import Router
from aiogram.api.types import Update
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from aiogram.dispatcher.middlewares.types import MiddlewareStep
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
@pytest.fixture("function")
def router():
return Router()
@pytest.fixture("function")
def manager(router: Router):
return MiddlewareManager(router)
class TestManager:
def test_setup(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
returned = manager.setup(middleware)
assert returned is middleware
assert middleware.configured
assert middleware.manager is manager
assert middleware in manager
@pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware])
def test_setup_invalid_type(self, manager: MiddlewareManager, obj):
with pytest.raises(TypeError):
assert manager.setup(obj)
def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
new_manager = MiddlewareManager(router)
with pytest.raises(ValueError):
new_manager.setup(middleware)
with pytest.raises(ValueError):
middleware.setup(new_manager)
def test_configure_twice(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
manager.setup(middleware)
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
middleware.setup(manager)
@pytest.mark.asyncio
@pytest.mark.parametrize("count", range(5))
async def test_trigger(self, manager: MiddlewareManager, count: int):
for _ in range(count):
manager.setup(BaseMiddleware())
with patch(
"aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger",
new_callable=CoroutineMock,
) as mocked_call:
await manager.trigger(
step=MiddlewareStep.PROCESS,
event_name="update",
event=Update(update_id=42),
data={},
result=None,
reverse=True,
)
assert mocked_call.await_count == count

View file

@ -10,6 +10,7 @@ from aiogram.api.types import (
InlineQuery,
Message,
Poll,
PollAnswer,
PollOption,
PreCheckoutQuery,
ShippingAddress,
@ -17,8 +18,8 @@ from aiogram.api.types import (
Update,
User,
)
from aiogram.dispatcher.event.observer import SkipHandler, skip
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.event.bases import NOT_HANDLED, SkipHandler, skip
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.dispatcher.router import Router
from aiogram.utils.warnings import CodeHasNoEffect
@ -274,12 +275,26 @@ class TestRouter:
False,
False,
),
pytest.param(
"poll_answer",
Update(
update_id=42,
poll_answer=PollAnswer(
poll_id="poll id",
user=User(id=42, is_bot=False, first_name="Test"),
option_ids=[42],
),
),
False,
True,
),
],
)
async def test_listen_update(
self, event_type: str, update: Update, has_chat: bool, has_user: bool
):
router = Router()
router.update.outer_middleware(UserContextMiddleware())
observer = router.observers[event_type]
@observer()
@ -291,7 +306,7 @@ class TestRouter:
assert User.get_current(False)
return kwargs
result = await router._listen_update(update, test="PASS")
result = await router.update.trigger(update, test="PASS")
assert isinstance(result, dict)
assert result["event_update"] == update
assert result["event_router"] == router
@ -313,26 +328,26 @@ class TestRouter:
async def handler(event: Any):
pass
with pytest.raises(SkipHandler):
await router._listen_update(
Update(
update_id=42,
poll=Poll(
id="poll id",
question="Q?",
options=[
PollOption(text="A1", voter_count=2),
PollOption(text="A2", voter_count=3),
],
is_closed=False,
is_anonymous=False,
type="quiz",
allows_multiple_answers=False,
total_voter_count=0,
correct_option_id=0,
),
)
response = await router._listen_update(
Update(
update_id=42,
poll=Poll(
id="poll id",
question="Q?",
options=[
PollOption(text="A1", voter_count=2),
PollOption(text="A2", voter_count=3),
],
is_closed=False,
is_anonymous=False,
type="quiz",
allows_multiple_answers=False,
total_voter_count=0,
correct_option_id=0,
),
)
)
assert response is NOT_HANDLED
@pytest.mark.asyncio
async def test_nested_router_listen_update(self):
@ -345,8 +360,6 @@ class TestRouter:
@observer()
async def my_handler(event: Message, **kwargs: Any):
assert Chat.get_current(False)
assert User.get_current(False)
return kwargs
update = Update(
@ -409,14 +422,6 @@ class TestRouter:
await router1.emit_shutdown()
assert results == [2, 1, 2]
def test_use(self):
router = Router()
middleware = router.use(BaseMiddleware())
assert isinstance(middleware, BaseMiddleware)
assert middleware.configured
assert middleware.manager == router.middleware
def test_skip(self):
with pytest.raises(SkipHandler):
skip()
@ -444,37 +449,20 @@ class TestRouter:
),
)
with pytest.raises(Exception, match="KABOOM"):
await root_router.listen_update(
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
await root_router.update.trigger(update)
@root_router.errors()
async def root_error_handler(exception: Exception):
async def root_error_handler(event: Update, exception: Exception):
return exception
response = await root_router.listen_update(
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
response = await root_router.update.trigger(update)
assert isinstance(response, Exception)
assert str(response) == "KABOOM"
@router.errors()
async def error_handler(exception: Exception):
async def error_handler(event: Update, exception: Exception):
return "KABOOM"
response = await root_router.listen_update(
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
response = await root_router.update.trigger(update)
assert response == "KABOOM"