diff --git a/Makefile b/Makefile index c14a3781..cf2e405c 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,6 @@ python := $(py) python reports_dir := reports -.PHONY: help help: @echo "=======================================================================================" @echo " aiogram build tools " @@ -45,12 +44,10 @@ help: # Environment # ================================================================================================= -.PHONY: install install: $(base_python) -m pip install --user -U poetry poetry install -.PHONY: clean clean: rm -rf `find . -name __pycache__` rm -f `find . -type f -name '*.py[co]' ` @@ -68,65 +65,56 @@ clean: # Code quality # ================================================================================================= -.PHONY: isort isort: $(py) isort -rc aiogram tests -.PHONY: black black: $(py) black aiogram tests -.PHONY: flake8 flake8: $(py) flake8 aiogram test -.PHONY: flake8-report flake8-report: mkdir -p $(reports_dir)/flake8 $(py) flake8 --format=html --htmldir=$(reports_dir)/flake8 aiogram test -.PHONY: mypy mypy: $(py) mypy aiogram -.PHONY: mypy-report mypy-report: $(py) mypy aiogram --html-report $(reports_dir)/typechecking -.PHONY: lint lint: isort black flake8 mypy # ================================================================================================= # Tests # ================================================================================================= -.PHONY: test test: $(py) pytest --cov=aiogram --cov-config .coveragerc tests/ -.PHONY: test-coverage test-coverage: mkdir -p $(reports_dir)/tests/ $(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ + + +test-coverage-report: $(py) coverage html -d $(reports_dir)/coverage -.PHONY: test-coverage-report -test-coverage-report: +test-coverage-view: + $(py) coverage html -d $(reports_dir)/coverage python -c "import webbrowser; webbrowser.open('file://$(shell pwd)/reports/coverage/index.html')" # ================================================================================================= # Docs # ================================================================================================= -.PHONY: docs docs: $(py) mkdocs build -.PHONY: docs-serve docs-serve: $(py) mkdocs serve -.PHONY: docs-copy-reports docs-copy-reports: mv $(reports_dir)/* site/reports @@ -134,9 +122,7 @@ docs-copy-reports: # Project # ================================================================================================= -.PHONY: build build: clean flake8-report mypy-report test-coverage docs docs-copy-reports mkdir -p site/simple poetry build mv dist site/simple/aiogram - diff --git a/aiogram/__init__.py b/aiogram/__init__.py index 0b2a8cc7..3df77008 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -1,3 +1,5 @@ +from pkg_resources import get_distribution + from .api import methods, types from .api.client import session from .api.client.bot import Bot @@ -28,5 +30,5 @@ __all__ = ( "handler", ) -__version__ = "3.0.0a4" +__version__ = get_distribution(dist=__package__).version __api_version__ = "4.8" diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 8960769d..1c4b08aa 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -11,6 +11,8 @@ from ..api.client.bot import Bot from ..api.methods import TelegramMethod from ..api.types import Update, User from ..utils.exceptions import TelegramAPIError +from .event.bases import NOT_HANDLED +from .middlewares.user_context import UserContextMiddleware from .router import Router @@ -23,6 +25,9 @@ class Dispatcher(Router): super(Dispatcher, self).__init__(**kwargs) self._running_lock = Lock() + # Default middleware is needed for contextual features + self.update.outer_middleware(UserContextMiddleware()) + @property def parent_router(self) -> None: """ @@ -42,9 +47,7 @@ class Dispatcher(Router): """ raise RuntimeError("Dispatcher can not be attached to another Router.") - async def feed_update( - self, bot: Bot, update: Update, **kwargs: Any - ) -> AsyncGenerator[Any, None]: + async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: """ Main entry point for incoming updates @@ -57,9 +60,9 @@ class Dispatcher(Router): Bot.set_current(bot) try: - async for result in self.update.trigger(update, bot=bot, **kwargs): - handled = True - yield result + response = await self.update.trigger(update, bot=bot, **kwargs) + handled = response is not NOT_HANDLED + return response finally: finish_time = loop.time() duration = (finish_time - start_time) * 1000 @@ -71,9 +74,7 @@ class Dispatcher(Router): bot.id, ) - async def feed_raw_update( - self, bot: Bot, update: Dict[str, Any], **kwargs: Any - ) -> AsyncGenerator[Any, None]: + async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any: """ Main entry point for incoming updates with automatic Dict->Update serializer @@ -82,8 +83,7 @@ class Dispatcher(Router): :param kwargs: """ parsed_update = Update(**update) - async for result in self.feed_update(bot=bot, update=parsed_update, **kwargs): - yield result + return await self.feed_update(bot=bot, update=parsed_update, **kwargs) @classmethod async def _listen_updates(cls, bot: Bot) -> AsyncGenerator[Update, None]: @@ -114,7 +114,7 @@ class Dispatcher(Router): # For debugging here is added logging. loggers.dispatcher.error("Failed to make answer: %s: %s", e.__class__.__name__, e) - async def process_update( + async def _process_update( self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any ) -> bool: """ @@ -126,11 +126,13 @@ class Dispatcher(Router): :param kwargs: contextual data for middlewares, filters and handlers :return: status """ + handled = False try: - async for result in self.feed_update(bot, update, **kwargs): - if call_answer and isinstance(result, TelegramMethod): - await self._silent_call_request(bot=bot, result=result) - return True + response = await self.feed_update(bot, update, **kwargs) + handled = handled is not NOT_HANDLED + if call_answer and isinstance(response, TelegramMethod): + await self._silent_call_request(bot=bot, result=response) + return handled except Exception as e: loggers.dispatcher.exception( @@ -142,8 +144,6 @@ class Dispatcher(Router): ) return True # because update was processed but unsuccessful - return False - async def _polling(self, bot: Bot, **kwargs: Any) -> None: """ Internal polling process @@ -153,16 +153,14 @@ class Dispatcher(Router): :return: """ async for update in self._listen_updates(bot): - await self.process_update(bot=bot, update=update, **kwargs) + await self._process_update(bot=bot, update=update, **kwargs) async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: """ The same with `Dispatcher.process_update()` but returns real response instead of bool """ try: - async for result in self.feed_update(bot, update, **kwargs): - return result - + return await self.feed_update(bot, update, **kwargs) except Exception as e: loggers.dispatcher.exception( "Cause exception while process update id=%d by bot id=%d\n%s: %s", @@ -196,10 +194,10 @@ class Dispatcher(Router): def process_response(task: Future[Any]) -> None: warnings.warn( - f"Detected slow response into webhook.\n" - f"Telegram is waiting for response only first 60 seconds and then re-send update.\n" - f"For preventing this situation response into webhook returned immediately " - f"and handler is moved to background and still processing update.", + "Detected slow response into webhook.\n" + "Telegram is waiting for response only first 60 seconds and then re-send update.\n" + "For preventing this situation response into webhook returned immediately " + "and handler is moved to background and still processing update.", RuntimeWarning, ) try: diff --git a/aiogram/dispatcher/event/bases.py b/aiogram/dispatcher/event/bases.py new file mode 100644 index 00000000..d255e3ae --- /dev/null +++ b/aiogram/dispatcher/event/bases.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union +from unittest.mock import sentinel + +from ...api.types import TelegramObject +from ..middlewares.base import BaseMiddleware + +NextMiddlewareType = Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]] +MiddlewareType = Union[ + BaseMiddleware, Callable[[NextMiddlewareType, TelegramObject, Dict[str, Any]], Awaitable[Any]] +] + +NOT_HANDLED = sentinel.NOT_HANDLED + + +class SkipHandler(Exception): + pass + + +class CancelHandler(Exception): + pass + + +def skip(message: Optional[str] = None) -> NoReturn: + """ + Raise an SkipHandler + """ + raise SkipHandler(message or "Event skipped") diff --git a/aiogram/dispatcher/event/event.py b/aiogram/dispatcher/event/event.py new file mode 100644 index 00000000..29aa4580 --- /dev/null +++ b/aiogram/dispatcher/event/event.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any, Callable, List + +from .handler import CallbackType, HandlerObject, HandlerType + + +class EventObserver: + """ + Simple events observer + """ + + def __init__(self) -> None: + self.handlers: List[HandlerObject] = [] + + def register(self, callback: HandlerType) -> None: + """ + Register callback with filters + """ + self.handlers.append(HandlerObject(callback=callback)) + + async def trigger(self, *args: Any, **kwargs: Any) -> None: + """ + Propagate event to handlers. + Handler will be called when all its filters is pass. + """ + for handler in self.handlers: + await handler.call(*args, **kwargs) + + def __call__(self) -> Callable[[CallbackType], CallbackType]: + """ + Decorator for registering event handlers + """ + + def wrapper(callback: CallbackType) -> CallbackType: + self.register(callback) + return callback + + return wrapper diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index d5a59277..6c2a5e57 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -1,3 +1,5 @@ +import asyncio +import contextvars import inspect from dataclasses import dataclass, field from functools import partial @@ -6,9 +8,9 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.handler.base import BaseHandler -CallbackType = Callable[[Any], Awaitable[Any]] -SyncFilter = Callable[[Any], Any] -AsyncFilter = Callable[[Any], Awaitable[Any]] +CallbackType = Callable[..., Awaitable[Any]] +SyncFilter = Callable[..., Any] +AsyncFilter = Callable[..., Awaitable[Any]] FilterType = Union[SyncFilter, AsyncFilter, BaseFilter] HandlerType = Union[FilterType, Type[BaseHandler]] @@ -40,7 +42,11 @@ class CallableMixin: wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) if self.awaitable: return await wrapped() - return wrapped() + + loop = asyncio.get_event_loop() + context = contextvars.copy_context() + wrapped = partial(context.run, wrapped) + return await loop.run_in_executor(None, wrapped) @dataclass @@ -60,11 +66,11 @@ class HandlerObject(CallableMixin): async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: if not self.filters: - return True, {} + return True, kwargs for event_filter in self.filters: check = await event_filter.call(*args, **kwargs) if not check: - return False, {} + return False, kwargs if isinstance(check, dict): kwargs.update(check) return True, kwargs diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/telegram.py similarity index 53% rename from aiogram/dispatcher/event/observer.py rename to aiogram/dispatcher/event/telegram.py index cea2eb6a..e72d5db3 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/telegram.py @@ -1,93 +1,33 @@ from __future__ import annotations +import functools from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Callable, - Dict, - Generator, - List, - NoReturn, - Optional, - Type, -) +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union from pydantic import ValidationError +from ...api.types import TelegramObject from ..filters.base import BaseFilter -from ..middlewares.types import MiddlewareStep, UpdateType +from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType if TYPE_CHECKING: # pragma: no cover from aiogram.dispatcher.router import Router -class SkipHandler(Exception): - pass - - -class CancelHandler(Exception): - pass - - -def skip(message: Optional[str] = None) -> NoReturn: - """ - Raise an SkipHandler - """ - raise SkipHandler(message or "Event skipped") - - -class EventObserver: - """ - Base events observer - """ - - def __init__(self) -> None: - self.handlers: List[HandlerObject] = [] - - def register(self, callback: HandlerType) -> HandlerType: - """ - Register callback with filters - """ - self.handlers.append(HandlerObject(callback=callback)) - return callback - - async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: - """ - Propagate event to handlers. - Handler will be called when all its filters is pass. - """ - for handler in self.handlers: - try: - yield await handler.call(*args, **kwargs) - except SkipHandler: - continue - - def __call__(self) -> Callable[[CallbackType], CallbackType]: - """ - Decorator for registering event handlers - """ - - def wrapper(callback: CallbackType) -> CallbackType: - self.register(callback) - return callback - - return wrapper - - -class TelegramEventObserver(EventObserver): +class TelegramEventObserver: """ Event observer for Telegram events """ def __init__(self, router: Router, event_name: str) -> None: - super().__init__() - self.router: Router = router self.event_name: str = event_name + + self.handlers: List[HandlerObject] = [] self.filters: List[Type[BaseFilter]] = [] + self.outer_middlewares: List[MiddlewareType] = [] + self.middlewares: List[MiddlewareType] = [] def bind_filter(self, bound_filter: Type[BaseFilter]) -> None: """ @@ -144,37 +84,6 @@ class TelegramEventObserver(EventObserver): return filters - async def trigger_middleware( - self, step: MiddlewareStep, event: UpdateType, data: Dict[str, Any], result: Any = None, - ) -> None: - """ - Trigger middlewares chain - - :param step: - :param event: - :param data: - :param result: - :return: - """ - reverse = step == MiddlewareStep.POST_PROCESS - recursive = self.event_name == "update" or step == MiddlewareStep.PROCESS - - if self.event_name == "update": - routers = self.router.chain - else: - routers = self.router.chain_head - for router in routers: - await router.middleware.trigger( - step=step, - event_name=self.event_name, - event=event, - data=data, - result=result, - reverse=reverse, - ) - if not recursive: - break - def register( self, callback: HandlerType, *filters: FilterType, **bound_filters: Any ) -> HandlerType: @@ -190,32 +99,39 @@ class TelegramEventObserver(EventObserver): ) return callback - async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: + @classmethod + def _wrap_middleware( + cls, middlewares: List[MiddlewareType], handler: HandlerType + ) -> NextMiddlewareType: + @functools.wraps(handler) + def mapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any: + return handler(event, **kwargs) + + middleware = mapper + for m in reversed(middlewares): + middleware = functools.partial(m, middleware) + return middleware + + async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any: """ Propagate event to handlers and stops propagation on first match. Handler will be called when all its filters is pass. """ - event = args[0] - await self.trigger_middleware(step=MiddlewareStep.PRE_PROCESS, event=event, data=kwargs) + wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger) + return await wrapped_outer(event, kwargs) + + async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any: for handler in self.handlers: - result, data = await handler.check(*args, **kwargs) + result, data = await handler.check(event, **kwargs) if result: kwargs.update(data) - await self.trigger_middleware( - step=MiddlewareStep.PROCESS, event=event, data=kwargs - ) try: - response = await handler.call(*args, **kwargs) - await self.trigger_middleware( - step=MiddlewareStep.POST_PROCESS, - event=event, - data=kwargs, - result=response, - ) - yield response + wrapped_inner = self._wrap_middleware(self.middlewares, handler.call) + return await wrapped_inner(event, kwargs) except SkipHandler: continue - break + + return NOT_HANDLED def __call__( self, *args: FilterType, **bound_filters: BaseFilter @@ -229,3 +145,45 @@ class TelegramEventObserver(EventObserver): return callback return wrapper + + def middleware( + self, middleware: Optional[MiddlewareType] = None, + ) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]: + """ + Decorator for registering inner middlewares + + Usage: + >>> @.middleware() # via decorator (variant 1) + >>> @.middleware # via decorator (variant 2) + >>> async def my_middleware(handler, event, data): ... + >>> .middleware(middleware) # via method + """ + + def wrapper(m: MiddlewareType) -> MiddlewareType: + self.middlewares.append(m) + return m + + if middleware is None: + return wrapper + return wrapper(middleware) + + def outer_middleware( + self, middleware: Optional[MiddlewareType] = None, + ) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]: + """ + Decorator for registering outer middlewares + + Usage: + >>> @.outer_middleware() # via decorator (variant 1) + >>> @.outer_middleware # via decorator (variant 2) + >>> async def my_middleware(handler, event, data): ... + >>> .outer_middleware(my_middleware) # via method + """ + + def wrapper(m: MiddlewareType) -> MiddlewareType: + self.outer_middlewares.append(m) + return m + + if middleware is None: + return wrapper + return wrapper(middleware) diff --git a/aiogram/dispatcher/middlewares/abstract.py b/aiogram/dispatcher/middlewares/abstract.py deleted file mode 100644 index eac16534..00000000 --- a/aiogram/dispatcher/middlewares/abstract.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Optional - -from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType - -if TYPE_CHECKING: # pragma: no cover - from aiogram.dispatcher.middlewares.manager import MiddlewareManager - - -class AbstractMiddleware(ABC): - """ - Abstract class for middleware. - """ - - def __init__(self) -> None: - self._manager: Optional[MiddlewareManager] = None - - @property - def manager(self) -> MiddlewareManager: - """ - Instance of MiddlewareManager - """ - if self._manager is None: - raise RuntimeError("Middleware is not configured!") - return self._manager - - def setup(self, manager: MiddlewareManager, _stack_level: int = 1) -> AbstractMiddleware: - """ - Mark middleware as configured - - :param manager: - :param _stack_level: - :return: - """ - if self.configured: - return manager.setup(self, _stack_level=_stack_level + 1) - - self._manager = manager - return self - - @property - def configured(self) -> bool: - """ - Check middleware is configured - - :return: - """ - return bool(self._manager) - - @abstractmethod - async def trigger( - self, - step: MiddlewareStep, - event_name: str, - event: UpdateType, - data: Dict[str, Any], - result: Any = None, - ) -> Any: # pragma: no cover - pass diff --git a/aiogram/dispatcher/middlewares/base.py b/aiogram/dispatcher/middlewares/base.py index 8766f9dc..f0db86ec 100644 --- a/aiogram/dispatcher/middlewares/base.py +++ b/aiogram/dispatcher/middlewares/base.py @@ -1,317 +1,15 @@ -from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, Dict, Generic, TypeVar -from typing import TYPE_CHECKING, Any, Dict - -from aiogram.dispatcher.middlewares.abstract import AbstractMiddleware -from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType - -if TYPE_CHECKING: # pragma: no cover - from aiogram.api.types import ( - CallbackQuery, - ChosenInlineResult, - InlineQuery, - Message, - Poll, - PollAnswer, - PreCheckoutQuery, - ShippingQuery, - Update, - ) +T = TypeVar("T") -class BaseMiddleware(AbstractMiddleware): - """ - Base class for middleware. - - All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message". - """ - - async def trigger( +class BaseMiddleware(ABC, Generic[T]): + @abstractmethod + async def __call__( self, - step: MiddlewareStep, - event_name: str, - event: UpdateType, + handler: Callable[[T, Dict[str, Any]], Awaitable[Any]], + event: T, data: Dict[str, Any], - result: Any = None, - ) -> Any: - """ - Trigger action. - - :param step: - :param event_name: - :param event: - :param data: - :param result: - :return: - """ - handler_name = f"on_{step.value}_{event_name}" - handler = getattr(self, handler_name, None) - if not handler: - return None - args = (event, result, data) if step == MiddlewareStep.POST_PROCESS else (event, data) - return await handler(*args) - - if TYPE_CHECKING: # pragma: no cover - # ============================================================================================= - # Event that triggers before process - # ============================================================================================= - async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any: - """ - Event that triggers before process update - """ - - async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any: - """ - Event that triggers before process message - """ - - async def on_pre_process_edited_message( - self, edited_message: Message, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process edited_message - """ - - async def on_pre_process_channel_post( - self, channel_post: Message, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process channel_post - """ - - async def on_pre_process_edited_channel_post( - self, edited_channel_post: Message, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process edited_channel_post - """ - - async def on_pre_process_inline_query( - self, inline_query: InlineQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process inline_query - """ - - async def on_pre_process_chosen_inline_result( - self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process chosen_inline_result - """ - - async def on_pre_process_callback_query( - self, callback_query: CallbackQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process callback_query - """ - - async def on_pre_process_shipping_query( - self, shipping_query: ShippingQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process shipping_query - """ - - async def on_pre_process_pre_checkout_query( - self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process pre_checkout_query - """ - - async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: - """ - Event that triggers before process poll - """ - - async def on_pre_process_poll_answer( - self, poll_answer: PollAnswer, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers before process poll_answer - """ - - async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: - """ - Event that triggers before process error - """ - - # ============================================================================================= - # Event that triggers on process after filters. - # ============================================================================================= - async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any: - """ - Event that triggers on process update - """ - - async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any: - """ - Event that triggers on process message - """ - - async def on_process_edited_message( - self, edited_message: Message, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process edited_message - """ - - async def on_process_channel_post( - self, channel_post: Message, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process channel_post - """ - - async def on_process_edited_channel_post( - self, edited_channel_post: Message, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process edited_channel_post - """ - - async def on_process_inline_query( - self, inline_query: InlineQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process inline_query - """ - - async def on_process_chosen_inline_result( - self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process chosen_inline_result - """ - - async def on_process_callback_query( - self, callback_query: CallbackQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process callback_query - """ - - async def on_process_shipping_query( - self, shipping_query: ShippingQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process shipping_query - """ - - async def on_process_pre_checkout_query( - self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process pre_checkout_query - """ - - async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: - """ - Event that triggers on process poll - """ - - async def on_process_poll_answer( - self, poll_answer: PollAnswer, data: Dict[str, Any] - ) -> Any: - """ - Event that triggers on process poll_answer - """ - - async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: - """ - Event that triggers on process error - """ - - # ============================================================================================= - # Event that triggers after process . - # ============================================================================================= - async def on_post_process_update( - self, update: Update, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing update - """ - - async def on_post_process_message( - self, message: Message, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing message - """ - - async def on_post_process_edited_message( - self, edited_message: Message, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing edited_message - """ - - async def on_post_process_channel_post( - self, channel_post: Message, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing channel_post - """ - - async def on_post_process_edited_channel_post( - self, edited_channel_post: Message, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing edited_channel_post - """ - - async def on_post_process_inline_query( - self, inline_query: InlineQuery, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing inline_query - """ - - async def on_post_process_chosen_inline_result( - self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing chosen_inline_result - """ - - async def on_post_process_callback_query( - self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing callback_query - """ - - async def on_post_process_shipping_query( - self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing shipping_query - """ - - async def on_post_process_pre_checkout_query( - self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing pre_checkout_query - """ - - async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any: - """ - Event that triggers after processing poll - """ - - async def on_post_process_poll_answer( - self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing poll_answer - """ - - async def on_post_process_error( - self, exception: Exception, data: Dict[str, Any], result: Any - ) -> Any: - """ - Event that triggers after processing error - """ + ) -> Any: # pragma: no cover + pass diff --git a/aiogram/dispatcher/middlewares/error.py b/aiogram/dispatcher/middlewares/error.py new file mode 100644 index 00000000..438277b1 --- /dev/null +++ b/aiogram/dispatcher/middlewares/error.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict + +from ...api.types import Update +from ..event.bases import NOT_HANDLED, CancelHandler, SkipHandler +from .base import BaseMiddleware + +if TYPE_CHECKING: # pragma: no cover + from ..router import Router + + +class ErrorsMiddleware(BaseMiddleware[Update]): + def __init__(self, router: Router): + self.router = router + + async def __call__( + self, + handler: Callable[[Any, Dict[str, Any]], Awaitable[Any]], + event: Any, + data: Dict[str, Any], + ) -> Any: + try: + return await handler(event, data) + except (SkipHandler, CancelHandler): # pragma: no cover + raise + except Exception as e: + response = await self.router.errors.trigger(event, exception=e, **data) + if response is NOT_HANDLED: + raise + return response diff --git a/aiogram/dispatcher/middlewares/manager.py b/aiogram/dispatcher/middlewares/manager.py deleted file mode 100644 index 39a6230d..00000000 --- a/aiogram/dispatcher/middlewares/manager.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Dict, List -from warnings import warn - -from .abstract import AbstractMiddleware -from .types import MiddlewareStep, UpdateType - -if TYPE_CHECKING: # pragma: no cover - from aiogram.dispatcher.router import Router - - -class MiddlewareManager: - """ - Middleware manager. - """ - - def __init__(self, router: Router) -> None: - self.router = router - self.middlewares: List[AbstractMiddleware] = [] - - def setup(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware: - """ - Setup middleware - - :param middleware: - :param _stack_level: - :return: - """ - if not isinstance(middleware, AbstractMiddleware): - raise TypeError( - f"`middleware` should be instance of BaseMiddleware, not {type(middleware)}" - ) - if middleware.configured: - if middleware.manager is self: - warn( - f"Middleware {middleware} is already configured for this Router " - "That's mean re-installing of this middleware has no effect.", - category=RuntimeWarning, - stacklevel=_stack_level + 1, - ) - return middleware - raise ValueError( - f"Middleware is already configured for another manager {middleware.manager} " - f"in router {middleware.manager.router}!" - ) - - self.middlewares.append(middleware) - middleware.setup(self) - return middleware - - async def trigger( - self, - step: MiddlewareStep, - event_name: str, - event: UpdateType, - data: Dict[str, Any], - result: Any = None, - reverse: bool = False, - ) -> Any: - """ - Call action to middlewares with args lilt. - """ - middlewares = reversed(self.middlewares) if reverse else self.middlewares - for middleware in middlewares: - await middleware.trigger( - step=step, event_name=event_name, event=event, data=data, result=result - ) - - def __contains__(self, item: AbstractMiddleware) -> bool: - return item in self.middlewares diff --git a/aiogram/dispatcher/middlewares/types.py b/aiogram/dispatcher/middlewares/types.py deleted file mode 100644 index bc173025..00000000 --- a/aiogram/dispatcher/middlewares/types.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from typing import Union - -from aiogram.api.types import ( - CallbackQuery, - ChosenInlineResult, - InlineQuery, - Message, - Poll, - PollAnswer, - PreCheckoutQuery, - ShippingQuery, - Update, -) - -UpdateType = Union[ - CallbackQuery, - ChosenInlineResult, - InlineQuery, - Message, - Poll, - PollAnswer, - PreCheckoutQuery, - ShippingQuery, - Update, - BaseException, -] - - -class MiddlewareStep(Enum): - PRE_PROCESS = "pre_process" - PROCESS = "process" - POST_PROCESS = "post_process" diff --git a/aiogram/dispatcher/middlewares/user_context.py b/aiogram/dispatcher/middlewares/user_context.py new file mode 100644 index 00000000..09abd10f --- /dev/null +++ b/aiogram/dispatcher/middlewares/user_context.py @@ -0,0 +1,62 @@ +from contextlib import contextmanager +from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple + +from aiogram.api.types import Chat, Update, User +from aiogram.dispatcher.middlewares.base import BaseMiddleware + + +class UserContextMiddleware(BaseMiddleware[Update]): + async def __call__( + self, + handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]], + event: Update, + data: Dict[str, Any], + ) -> Any: + chat, user = self.resolve_event_context(event=event) + with self.context(chat=chat, user=user): + return await handler(event, data) + + @contextmanager + def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]: + chat_token = None + user_token = None + if chat: + chat_token = chat.set_current(chat) + if user: + user_token = user.set_current(user) + try: + yield + finally: + if chat and chat_token: + chat.reset_current(chat_token) + if user and user_token: + user.reset_current(user_token) + + @classmethod + def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]: + """ + Resolve chat and user instance from Update object + """ + if event.message: + return event.message.chat, event.message.from_user + if event.edited_message: + return event.edited_message.chat, event.edited_message.from_user + if event.channel_post: + return event.channel_post.chat, None + if event.edited_channel_post: + return event.edited_channel_post.chat, None + if event.inline_query: + return None, event.inline_query.from_user + if event.chosen_inline_result: + return None, event.chosen_inline_result.from_user + if event.callback_query: + if event.callback_query.message: + return event.callback_query.message.chat, event.callback_query.from_user + return None, event.callback_query.from_user + if event.shipping_query: + return None, event.shipping_query.from_user + if event.pre_checkout_query: + return None, event.pre_checkout_query.from_user + if event.poll_answer: + return None, event.poll_answer.user + return None, None diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index 371c490d..7680454e 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -3,13 +3,14 @@ from __future__ import annotations import warnings from typing import Any, Dict, Generator, List, Optional, Union -from ..api.types import Chat, TelegramObject, Update, User +from ..api.types import TelegramObject, Update from ..utils.imports import import_module from ..utils.warnings import CodeHasNoEffect -from .event.observer import EventObserver, SkipHandler, TelegramEventObserver +from .event.bases import NOT_HANDLED, SkipHandler +from .event.event import EventObserver +from .event.telegram import TelegramEventObserver from .filters import BUILTIN_FILTERS -from .middlewares.abstract import AbstractMiddleware -from .middlewares.manager import MiddlewareManager +from .middlewares.error import ErrorsMiddleware class Router: @@ -44,8 +45,6 @@ class Router: self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer") self.errors = TelegramEventObserver(router=self, event_name="error") - self.middleware = MiddlewareManager(router=self) - self.startup = EventObserver() self.shutdown = EventObserver() @@ -68,6 +67,8 @@ class Router: # Root handler self.update.register(self._listen_update) + self.update.outer_middleware(ErrorsMiddleware(self)) + # Builtin filters if use_builtin_filters: for name, observer in self.observers.items(): @@ -94,16 +95,6 @@ class Router: next(tail) # Skip self yield from tail - def use(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware: - """ - Use middleware - - :param middleware: - :param _stack_level: - :return: - """ - return self.middleware.setup(middleware, _stack_level=_stack_level + 1) - @property def parent_router(self) -> Optional[Router]: return self._parent_router @@ -176,53 +167,40 @@ class Router: :param kwargs: :return: """ - chat: Optional[Chat] = None - from_user: Optional[User] = None - event: TelegramObject if update.message: update_type = "message" - from_user = update.message.from_user - chat = update.message.chat event = update.message elif update.edited_message: update_type = "edited_message" - from_user = update.edited_message.from_user - chat = update.edited_message.chat event = update.edited_message elif update.channel_post: update_type = "channel_post" - chat = update.channel_post.chat event = update.channel_post elif update.edited_channel_post: update_type = "edited_channel_post" - chat = update.edited_channel_post.chat event = update.edited_channel_post elif update.inline_query: update_type = "inline_query" - from_user = update.inline_query.from_user event = update.inline_query elif update.chosen_inline_result: update_type = "chosen_inline_result" - from_user = update.chosen_inline_result.from_user event = update.chosen_inline_result elif update.callback_query: update_type = "callback_query" - if update.callback_query.message: - chat = update.callback_query.message.chat - from_user = update.callback_query.from_user event = update.callback_query elif update.shipping_query: update_type = "shipping_query" - from_user = update.shipping_query.from_user event = update.shipping_query elif update.pre_checkout_query: update_type = "pre_checkout_query" - from_user = update.pre_checkout_query.from_user event = update.pre_checkout_query elif update.poll: update_type = "poll" event = update.poll + elif update.poll_answer: + update_type = "poll_answer" + event = update.poll_answer else: warnings.warn( "Detected unknown update type.\n" @@ -232,76 +210,17 @@ class Router: ) raise SkipHandler - return await self.listen_update( - update_type=update_type, - update=update, - event=event, - from_user=from_user, - chat=chat, - **kwargs, - ) - - async def listen_update( - self, - update_type: str, - update: Update, - event: TelegramObject, - from_user: Optional[User] = None, - chat: Optional[Chat] = None, - **kwargs: Any, - ) -> Any: - """ - Listen update by current and child routers - - :param update_type: - :param update: - :param event: - :param from_user: - :param chat: - :param kwargs: - :return: - """ - user_token = None - if from_user: - user_token = User.set_current(from_user) - chat_token = None - if chat: - chat_token = Chat.set_current(chat) - kwargs.update(event_update=update, event_router=self) observer = self.observers[update_type] - try: - async for result in observer.trigger(event, update=update, **kwargs): - return result + response = await observer.trigger(event, update=update, **kwargs) + if response is NOT_HANDLED: # Resolve nested routers for router in self.sub_routers: - try: - return await router.listen_update( - update_type=update_type, - update=update, - event=event, - from_user=from_user, - chat=chat, - **kwargs, - ) - except SkipHandler: + response = await router.update.trigger(event=update, **kwargs) + if response is NOT_HANDLED: continue - raise SkipHandler - - except SkipHandler: - raise - - except Exception as e: - async for result in self.errors.trigger(e, **kwargs): - return result - raise - - finally: - if user_token: - User.reset_current(user_token) - if chat_token: - Chat.reset_current(chat_token) + return response async def emit_startup(self, *args: Any, **kwargs: Any) -> None: """ @@ -312,8 +231,7 @@ class Router: :return: """ kwargs.update(router=self) - async for _ in self.startup.trigger(*args, **kwargs): # pragma: no cover - pass + await self.startup.trigger(*args, **kwargs) for router in self.sub_routers: await router.emit_startup(*args, **kwargs) @@ -326,8 +244,7 @@ class Router: :return: """ kwargs.update(router=self) - async for _ in self.shutdown.trigger(*args, **kwargs): # pragma: no cover - pass + await self.shutdown.trigger(*args, **kwargs) for router in self.sub_routers: await router.emit_shutdown(*args, **kwargs) diff --git a/aiogram/utils/mixins.py b/aiogram/utils/mixins.py index 0c4834f4..156339d6 100644 --- a/aiogram/utils/mixins.py +++ b/aiogram/utils/mixins.py @@ -1,9 +1,10 @@ from __future__ import annotations import contextvars -from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload -from typing_extensions import Literal +if TYPE_CHECKING: # pragma: no cover + from typing_extensions import Literal __all__ = ("ContextInstanceMixin", "DataMixin") diff --git a/docs/assets/images/basics_middleware.png b/docs/assets/images/basics_middleware.png index a797fd38..b4165e2e 100644 Binary files a/docs/assets/images/basics_middleware.png and b/docs/assets/images/basics_middleware.png differ diff --git a/docs/dispatcher/dispatcher.md b/docs/dispatcher/dispatcher.md index 462d0748..74f018e4 100644 --- a/docs/dispatcher/dispatcher.md +++ b/docs/dispatcher/dispatcher.md @@ -39,7 +39,7 @@ dp.include_router(router1) ## Handling updates All updates can be propagated to the dispatcher by `feed_update` method: -``` +```python3 bot = Bot(...) dp = Dispathcher() diff --git a/docs/dispatcher/middlewares.md b/docs/dispatcher/middlewares.md new file mode 100644 index 00000000..a0649ce5 --- /dev/null +++ b/docs/dispatcher/middlewares.md @@ -0,0 +1,95 @@ +# Middlewares + +**aiogram** provides powerful mechanism for customizing event handlers via middlewares. + +Middlewares in bot framework seems like Middlewares mechanism in web-frameworks +(like [aiohttp](https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-middlewares), +[fastapi](https://fastapi.tiangolo.com/tutorial/middleware/), +[Django](https://docs.djangoproject.com/en/3.0/topics/http/middleware/) or etc.) +with small difference - here is implemented two layers of middlewares (before and after filters). + +!!! info + Middleware is function that triggered on every event received from + Telegram Bot API in many points on processing pipeline. + +## Base theory + +As many books and other literature in internet says: +> Middleware is reusable software that leverages patterns and frameworks to bridge +> the gap between the functional requirements of applications and the underlying operating systems, +> network protocol stacks, and databases. + +Middleware can modify, extend or reject processing event in many places of pipeline. + +## Basics + +Middleware instance can be applied for every type of Telegram Event (Update, Message, etc.) in two places + +1. Outer scope - before processing filters (`#!python ..outer_middleware(...)`) +2. Inner scope - after processing filters but before handler (`#!python ..middleware(...)`) + +[![middlewares](../assets/images/basics_middleware.png)](../assets/images/basics_middleware.png) + +_(Click on image to zoom it)_ + +!!! warning + + Middleware should be subclass of `BaseMiddleware` (`#!python3 from aiogram import BaseMiddleware`) or any async callable + +## Arguments specification +| Argument | Type | Description | +| - | - | - | +| `handler` | `#!python Callable[[T, Dict[str, Any]], Awaitable[Any]]` | Wrapped handler in middlewares chain | +| `event` | `#!python T` | Incoming event (Subclass of `TelegramObject`) | +| `data` | `#!python Dict[str, Any]` | Contextual data. Will be mapped to handler arguments | + +## Examples + +!!! danger + + Middleware should always call `#!python await handler(event, data)` to propagate event for next middleware/handler + +### Class-based +```python3 +from aiogram import BaseMiddleware +from aiogram.api.types import Message + + +class CounterMiddleware(BaseMiddleware[Message]): + def __init__(self) -> None: + self.counter = 0 + + async def __call__( + self, + handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]], + event: Message, + data: Dict[str, Any] + ) -> Any: + self.counter += 1 + data['counter'] = self.counter + return await handler(event, data) +``` +and then +```python3 +router = Router() +router.message.middleware(CounterMiddleware()) +``` + +### Function-based +```python3 +@dispatcher.update.outer_middleware() +async def database_transaction_middleware( + handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]], + event: Update, + data: Dict[str, Any] +) -> Any: + async with database.transaction(): + return await handler(event, data) +``` + +## Facts + +1. Middlewares from outer scope will be called on every incoming event +1. Middlewares from inner scope will be called only when filters pass +1. Inner middlewares is always calls for `Update` event type in due to all incoming updates going to specific event type handler through built in update handler + diff --git a/docs/dispatcher/middlewares/basics.md b/docs/dispatcher/middlewares/basics.md deleted file mode 100644 index 83b58f07..00000000 --- a/docs/dispatcher/middlewares/basics.md +++ /dev/null @@ -1,115 +0,0 @@ -# Basics - -All middlewares should be made with `BaseMiddleware` (`#!python3 from aiogram import BaseMiddleware`) as base class. - -For example: - -```python3 -class MyMiddleware(BaseMiddleware): ... -``` - -And then use next pattern in naming callback functions in middleware: `on_{step}_{event}` - -Where is: - -- `#!python3 step`: - - `#!python3 pre_process` - - `#!python3 process` - - `#!python3 post_process` -- `#!python3 event`: - - `#!python3 update` - - `#!python3 message` - - `#!python3 edited_message` - - `#!python3 channel_post` - - `#!python3 edited_channel_post` - - `#!python3 inline_query` - - `#!python3 chosen_inline_result` - - `#!python3 callback_query` - - `#!python3 shipping_query` - - `#!python3 pre_checkout_query` - - `#!python3 poll` - - `#!python3 poll_answer` - - `#!python3 error` - -## Connecting middleware with router - -Middlewares can be connected with router by next ways: - -1. `#!python3 router.use(MyMiddleware())` (**recommended**) -1. `#!python3 router.middleware.setup(MyMiddleware())` -1. `#!python3 MyMiddleware().setup(router.middleware)` (**not recommended**) - -!!! warning - One instance of middleware **can't** be registered twice in single or many middleware managers - -## The specification of step callbacks - -### Pre-process step - -| Argument | Type | Description | -| --- | --- | --- | -| event name | Any of event type (Update, Message and etc.) | Event | -| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) | - -Returns `#!python3 Any` - -### Process step - -| Argument | Type | Description | -| --- | --- | --- | -| event name | Any of event type (Update, Message and etc.) | Event | -| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) | - -Returns `#!python3 Any` - -### Post-Process step - -| Argument | Type | Description | -| --- | --- | --- | -| event name | Any of event type (Update, Message and etc.) | Event | -| `#!python3 data` | `#!python3 Dict[str, Any]` | Contextual data (Will be mapped to handler arguments) | -| `#!python3 result` | `#!python3 Dict[str, Any]` | Response from handlers | - -Returns `#!python3 Any` - -## Full list of available callbacks - -- `#!python3 on_pre_process_update` - will be triggered on **pre process** `#!python3 update` event -- `#!python3 on_process_update` - will be triggered on **process** `#!python3 update` event -- `#!python3 on_post_process_update` - will be triggered on **post process** `#!python3 update` event -- `#!python3 on_pre_process_message` - will be triggered on **pre process** `#!python3 message` event -- `#!python3 on_process_message` - will be triggered on **process** `#!python3 message` event -- `#!python3 on_post_process_message` - will be triggered on **post process** `#!python3 message` event -- `#!python3 on_pre_process_edited_message` - will be triggered on **pre process** `#!python3 edited_message` event -- `#!python3 on_process_edited_message` - will be triggered on **process** `#!python3 edited_message` event -- `#!python3 on_post_process_edited_message` - will be triggered on **post process** `#!python3 edited_message` event -- `#!python3 on_pre_process_channel_post` - will be triggered on **pre process** `#!python3 channel_post` event -- `#!python3 on_process_channel_post` - will be triggered on **process** `#!python3 channel_post` event -- `#!python3 on_post_process_channel_post` - will be triggered on **post process** `#!python3 channel_post` event -- `#!python3 on_pre_process_edited_channel_post` - will be triggered on **pre process** `#!python3 edited_channel_post` event -- `#!python3 on_process_edited_channel_post` - will be triggered on **process** `#!python3 edited_channel_post` event -- `#!python3 on_post_process_edited_channel_post` - will be triggered on **post process** `#!python3 edited_channel_post` event -- `#!python3 on_pre_process_inline_query` - will be triggered on **pre process** `#!python3 inline_query` event -- `#!python3 on_process_inline_query` - will be triggered on **process** `#!python3 inline_query` event -- `#!python3 on_post_process_inline_query` - will be triggered on **post process** `#!python3 inline_query` event -- `#!python3 on_pre_process_chosen_inline_result` - will be triggered on **pre process** `#!python3 chosen_inline_result` event -- `#!python3 on_process_chosen_inline_result` - will be triggered on **process** `#!python3 chosen_inline_result` event -- `#!python3 on_post_process_chosen_inline_result` - will be triggered on **post process** `#!python3 chosen_inline_result` event -- `#!python3 on_pre_process_callback_query` - will be triggered on **pre process** `#!python3 callback_query` event -- `#!python3 on_process_callback_query` - will be triggered on **process** `#!python3 callback_query` event -- `#!python3 on_post_process_callback_query` - will be triggered on **post process** `#!python3 callback_query` event -- `#!python3 on_pre_process_shipping_query` - will be triggered on **pre process** `#!python3 shipping_query` event -- `#!python3 on_process_shipping_query` - will be triggered on **process** `#!python3 shipping_query` event -- `#!python3 on_post_process_shipping_query` - will be triggered on **post process** `#!python3 shipping_query` event -- `#!python3 on_pre_process_pre_checkout_query` - will be triggered on **pre process** `#!python3 pre_checkout_query` event -- `#!python3 on_process_pre_checkout_query` - will be triggered on **process** `#!python3 pre_checkout_query` event -- `#!python3 on_post_process_pre_checkout_query` - will be triggered on **post process** `#!python3 pre_checkout_query` event -- `#!python3 on_pre_process_poll` - will be triggered on **pre process** `#!python3 poll` event -- `#!python3 on_process_poll` - will be triggered on **process** `#!python3 poll` event -- `#!python3 on_post_process_poll` - will be triggered on **post process** `#!python3 poll` event -- `#!python3 on_pre_process_poll_answer` - will be triggered on **pre process** `#!python3 poll_answer` event -- `#!python3 on_process_poll_answer` - will be triggered on **process** `#!python3 poll_answer` event -- `#!python3 on_post_process_poll_answer` - will be triggered on **post process** `#!python3 poll_answer` event -- `#!python3 on_pre_process_error` - will be triggered on **pre process** `#!python3 error` event -- `#!python3 on_process_error` - will be triggered on **process** `#!python3 error` event -- `#!python3 on_post_process_error` - will be triggered on **post process** `#!python3 error` event diff --git a/docs/dispatcher/middlewares/index.md b/docs/dispatcher/middlewares/index.md deleted file mode 100644 index 6815a565..00000000 --- a/docs/dispatcher/middlewares/index.md +++ /dev/null @@ -1,77 +0,0 @@ -# Overview - -**aiogram** provides powerful mechanism for customizing event handlers via middlewares. - -Middlewares in bot framework seems like Middlewares mechanism in web-frameworks -(like [aiohttp](https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-middlewares), -[fastapi](https://fastapi.tiangolo.com/tutorial/middleware/), -[Django](https://docs.djangoproject.com/en/3.0/topics/http/middleware/) or etc.) -with small difference - here is implemented many layers of processing -(named as [pipeline](#event-pipeline)). - -!!! info - Middleware is function that triggered on every event received from - Telegram Bot API in many points on processing pipeline. - -## Base theory - -As many books and other literature in internet says: -> Middleware is reusable software that leverages patterns and frameworks to bridge ->the gap between the functional requirements of applications and the underlying operating systems, -> network protocol stacks, and databases. - -Middleware can modify, extend or reject processing event before-, -on- or after- processing of that event. - -[![middlewares](../../assets/images/basics_middleware.png)](../../assets/images/basics_middleware.png) - -_(Click on image to zoom it)_ - -## Event pipeline - -As described below middleware an interact with event in many stages of pipeline. - -Simple workflow: - -1. Dispatcher receive an [Update](../../api/types/update.md) -1. Call **pre-process** update middleware in all routers tree -1. Filter Update over handlers -1. Call **process** update middleware in all routers tree -1. Router detects event type (Message, Callback query, etc.) -1. Router triggers **pre-process** middleware of specific type -1. Pass event over [filters](../filters/index.md) to detect specific handler -1. Call **process** middleware for specific type (only when handler for this event exists) -1. *Do magick*. Call handler (Read more [Event observers](../router.md#event-observers)) -1. Call **post-process** middleware -1. Call **post-process** update middleware in all routers tree -1. Emit response into webhook (when it needed) - -!!! warning - When filters does not match any handler with this event the `#!python3 process` - step will not be called. - -!!! warning - When exception will be caused in handlers pipeline will be stopped immediately - and then start processing error via errors handler and it own middleware callbacks. - -!!! warning - Middlewares for updates will be called for all routers in tree but callbacks for events - will be called only for specific branch of routers. - -### Pipeline in pictures: - -#### Simple pipeline - -[![middlewares](../../assets/images/middleware_pipeline.png)](../../assets/images/middleware_pipeline.png) - -_(Click on image to zoom it)_ - -#### Nested routers pipeline - -[![middlewares](../../assets/images/middleware_pipeline_nested.png)](../../assets/images/middleware_pipeline_nested.png) - -_(Click on image to zoom it)_ - -## Read more - -- [Middleware Basics](basics.md) diff --git a/docs/index.md b/docs/index.md index 9529d0ac..5f6b5bce 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ Documentation for version 3.0 [WIP] [^1] - [Supports Telegram Bot API v{!_api_version.md!}](api/index.md) - [Updates router](dispatcher/index.md) (Blueprints) - Finite State Machine -- [Middlewares](dispatcher/middlewares/index.md) +- [Middlewares](dispatcher/middlewares.md) - [Replies into Webhook](https://core.telegram.org/bots/faq#how-can-i-make-requests-in-response-to-updates) diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css new file mode 100644 index 00000000..b35d291c --- /dev/null +++ b/docs/stylesheets/extra.css @@ -0,0 +1,12 @@ +@font-face { + font-family: 'JetBrainsMono'; + src: url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/web/woff2/JetBrainsMono-Regular.woff2') format('woff2'), + url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/web/woff/JetBrainsMono-Regular.woff') format('woff'), + url('https://cdn.jsdelivr.net/gh/JetBrains/JetBrainsMono/ttf/JetBrainsMono-Regular.ttf') format('truetype'); + font-weight: 400; + font-style: normal; +} + +code, kbd, pre { + font-family: "JetBrainsMono", "Roboto Mono", "Courier New", Courier, monospace; +} diff --git a/mkdocs.yml b/mkdocs.yml index 527d0561..2c50c6db 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -16,6 +16,9 @@ theme: favicon: 'assets/images/logo.png' logo: 'assets/images/logo.png' +extra_css: + - stylesheets/extra.css + extra: version: 3.0.0a3 @@ -255,9 +258,8 @@ nav: - dispatcher/class_based_handlers/pre_checkout_query.md - dispatcher/class_based_handlers/shipping_query.md - dispatcher/class_based_handlers/error.md - - Middlewares: - - dispatcher/middlewares/index.md - - dispatcher/middlewares/basics.md + - dispatcher/middlewares.md + - todo.md - Build reports: - reports.md diff --git a/poetry.lock b/poetry.lock index ba3c3590..cdd78f22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -884,7 +884,7 @@ category = "dev" description = "YAML parser and emitter for Python" name = "pyyaml" optional = false -python-versions = "*" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" version = "5.3.1" [[package]] @@ -955,7 +955,7 @@ python-versions = "*" version = "1.4.1" [[package]] -category = "main" +category = "dev" description = "Backported and Experimental Type Hints for Python 3.5+" name = "typing-extensions" optional = false @@ -1031,8 +1031,7 @@ fast = ["uvloop"] proxy = ["aiohttp-socks"] [metadata] -content-hash = "57137b60a539ba01e8df533db976e2f3eadec37e717cbefbe775dc021a8c2714" - +content-hash = "768759359beca8b84811bfc21adac9649925cd22b87427a10608c9d1e16a0923" python-versions = "^3.7" [metadata.files] @@ -1284,11 +1283,6 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] mccabe = [ diff --git a/pyproject.toml b/pyproject.toml index 60a9c0e3..807bab18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ aiofiles = "^0.4.0" uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'", optional = true} async_lru = "^1.0" aiohttp-socks = {version = "^0.3.8", optional = true} -typing-extensions = "^3.7.4" [tool.poetry.dev-dependencies] uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'"} @@ -68,6 +67,7 @@ markdown-include = "^0.5.1" aiohttp-socks = "^0.3.4" pre-commit = "^2.3.0" packaging = "^20.3" +typing-extensions = "^3.7.4" [tool.poetry.extras] fast = ["uvloop"] @@ -94,7 +94,7 @@ include_trailing_comma = true force_grid_wrap = 0 use_parentheses = true line_length = 99 -known_third_party = ["aiofiles", "aiohttp", "aiohttp_socks", "aresponses", "async_lru", "packaging", "pydantic", "pytest", "typing_extensions"] +known_third_party = ["aiofiles", "aiohttp", "aiohttp_socks", "aresponses", "async_lru", "packaging", "pkg_resources", "pydantic", "pytest"] [build-system] requires = ["poetry>=0.12"] diff --git a/tests/test_api/test_types/test_inline_query.py b/tests/test_api/test_types/test_inline_query.py index 2ef76685..c828c17a 100644 --- a/tests/test_api/test_types/test_inline_query.py +++ b/tests/test_api/test_types/test_inline_query.py @@ -11,7 +11,13 @@ class TestInlineQuery: offset="", ) - kwargs = dict(results=[], cache_time=123, next_offset="123", switch_pm_text="foo", switch_pm_parameter="foo") + kwargs = dict( + results=[], + cache_time=123, + next_offset="123", + switch_pm_text="foo", + switch_pm_parameter="foo", + ) api_method = inline_query.answer(**kwargs) diff --git a/tests/test_api/test_types/test_shipping_query.py b/tests/test_api/test_types/test_shipping_query.py index 94e60640..939bb6c5 100644 --- a/tests/test_api/test_types/test_shipping_query.py +++ b/tests/test_api/test_types/test_shipping_query.py @@ -1,5 +1,5 @@ from aiogram.api.methods import AnswerShippingQuery -from aiogram.api.types import ShippingAddress, ShippingQuery, User, ShippingOption, LabeledPrice +from aiogram.api.types import LabeledPrice, ShippingAddress, ShippingOption, ShippingQuery, User class TestInlineQuery: @@ -19,7 +19,8 @@ class TestInlineQuery: ) shipping_options = [ - ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)])] + ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)]) + ] kwargs = dict(ok=True, shipping_options=shipping_options, error_message="foo") diff --git a/tests/test_dispatcher/test_deprecated.py b/tests/test_dispatcher/test_deprecated.py index efe6dfda..edea4aac 100644 --- a/tests/test_dispatcher/test_deprecated.py +++ b/tests/test_dispatcher/test_deprecated.py @@ -1,6 +1,6 @@ import pytest -from aiogram.dispatcher.event.observer import TelegramEventObserver +from aiogram.dispatcher.event.telegram import TelegramEventObserver from aiogram.dispatcher.router import Router from tests.deprecated import check_deprecated diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index 46a53db7..e5b9b50f 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -9,6 +9,7 @@ from aiogram import Bot from aiogram.api.methods import GetMe, GetUpdates, SendMessage from aiogram.api.types import Chat, Message, Update, User from aiogram.dispatcher.dispatcher import Dispatcher +from aiogram.dispatcher.event.bases import NOT_HANDLED from aiogram.dispatcher.router import Router from tests.mocked_bot import MockedBot @@ -63,7 +64,7 @@ class TestDispatcher: return message.text results_count = 0 - async for result in dp.feed_update( + result = await dp.feed_update( bot=bot, update=Update( update_id=42, @@ -75,11 +76,9 @@ class TestDispatcher: from_user=User(id=42, is_bot=False, first_name="Test"), ), ), - ): - results_count += 1 - assert result == "test" - - assert results_count == 1 + ) + results_count += 1 + assert result == "test" @pytest.mark.asyncio async def test_feed_raw_update(self): @@ -91,8 +90,7 @@ class TestDispatcher: assert message.text == "test" return message.text - handled = False - async for result in dp.feed_raw_update( + result = await dp.feed_raw_update( bot=bot, update={ "update_id": 42, @@ -101,13 +99,11 @@ class TestDispatcher: "date": int(time.time()), "text": "test", "chat": {"id": 42, "type": "private"}, - "user": {"id": 42, "is_bot": False, "first_name": "Test"}, + "from": {"id": 42, "is_bot": False, "first_name": "Test"}, }, }, - ): - handled = True - assert result == "test" - assert handled + ) + assert result == "test" @pytest.mark.asyncio async def test_listen_updates(self, bot: MockedBot): @@ -136,7 +132,8 @@ class TestDispatcher: async def test_process_update_empty(self, bot: MockedBot): dispatcher = Dispatcher() - assert not await dispatcher.process_update(bot=bot, update=Update(update_id=42)) + result = await dispatcher._process_update(bot=bot, update=Update(update_id=42)) + assert result @pytest.mark.asyncio async def test_process_update_handled(self, bot: MockedBot): @@ -146,22 +143,25 @@ class TestDispatcher: async def update_handler(update: Update): pass - assert await dispatcher.process_update(bot=bot, update=Update(update_id=42)) + assert await dispatcher._process_update(bot=bot, update=Update(update_id=42)) @pytest.mark.asyncio async def test_process_update_call_request(self, bot: MockedBot): dispatcher = Dispatcher() @dispatcher.update() - async def update_handler(update: Update): + async def message_handler(update: Update): return GetMe() + dispatcher.update.handlers.reverse() + with patch( "aiogram.dispatcher.dispatcher.Dispatcher._silent_call_request", new_callable=CoroutineMock, ) as mocked_silent_call_request: - assert await dispatcher.process_update(bot=bot, update=Update(update_id=42)) - mocked_silent_call_request.assert_awaited_once() + result = await dispatcher._process_update(bot=bot, update=Update(update_id=42)) + print(result) + mocked_silent_call_request.assert_awaited() @pytest.mark.asyncio async def test_process_update_exception(self, bot: MockedBot, caplog): @@ -171,7 +171,7 @@ class TestDispatcher: async def update_handler(update: Update): raise Exception("Kaboom!") - assert await dispatcher.process_update(bot=bot, update=Update(update_id=42)) + assert await dispatcher._process_update(bot=bot, update=Update(update_id=42)) log_records = [rec.message for rec in caplog.records] assert len(log_records) == 1 assert "Cause exception while process update" in log_records[0] @@ -184,7 +184,7 @@ class TestDispatcher: yield Update(update_id=42) with patch( - "aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock + "aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock ) as mocked_process_update, patch( "aiogram.dispatcher.dispatcher.Dispatcher._listen_updates" ) as patched_listen_updates: @@ -203,7 +203,7 @@ class TestDispatcher: yield Update(update_id=42) with patch( - "aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock + "aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock ) as mocked_process_update, patch( "aiogram.dispatcher.router.Router.emit_startup", new_callable=CoroutineMock ) as mocked_emit_startup, patch( diff --git a/tests/test_dispatcher/test_event/test_event.py b/tests/test_dispatcher/test_event/test_event.py new file mode 100644 index 00000000..335b864c --- /dev/null +++ b/tests/test_dispatcher/test_event/test_event.py @@ -0,0 +1,59 @@ +import functools +from typing import Any + +import pytest + +from aiogram.dispatcher.event.event import EventObserver +from aiogram.dispatcher.event.handler import HandlerObject + +try: + from asynctest import CoroutineMock, patch +except ImportError: + from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore + + +async def my_handler(value: str, index: int = 0) -> Any: + return value + + +class TestEventObserver: + @pytest.mark.parametrize("via_decorator", [True, False]) + @pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler])) + def test_register_filters(self, via_decorator, count, handler): + observer = EventObserver() + + for index in range(count): + wrapped_handler = functools.partial(handler, index=index) + if via_decorator: + register_result = observer()(wrapped_handler) + assert register_result == wrapped_handler + else: + register_result = observer.register(wrapped_handler) + assert register_result is None + + registered_handler = observer.handlers[index] + + assert len(observer.handlers) == index + 1 + assert isinstance(registered_handler, HandlerObject) + assert registered_handler.callback == wrapped_handler + assert not registered_handler.filters + + @pytest.mark.asyncio + async def test_trigger(self): + observer = EventObserver() + + observer.register(my_handler) + observer.register(lambda e: True) + observer.register(my_handler) + + assert observer.handlers[0].awaitable + assert not observer.handlers[1].awaitable + assert observer.handlers[2].awaitable + + with patch( + "aiogram.dispatcher.event.handler.CallableMixin.call", new_callable=CoroutineMock, + ) as mocked_my_handler: + results = await observer.trigger("test") + assert results is None + mocked_my_handler.assert_awaited_with("test") + assert mocked_my_handler.call_count == 3 diff --git a/tests/test_dispatcher/test_event/test_observer.py b/tests/test_dispatcher/test_event/test_telegram.py similarity index 72% rename from tests/test_dispatcher/test_event/test_observer.py rename to tests/test_dispatcher/test_event/test_telegram.py index c1364676..5d4c6607 100644 --- a/tests/test_dispatcher/test_event/test_observer.py +++ b/tests/test_dispatcher/test_event/test_telegram.py @@ -5,11 +5,14 @@ from typing import Any, Awaitable, Callable, Dict, NoReturn, Union import pytest from aiogram.api.types import Chat, Message, User +from aiogram.dispatcher.event.bases import SkipHandler from aiogram.dispatcher.event.handler import HandlerObject -from aiogram.dispatcher.event.observer import EventObserver, SkipHandler, TelegramEventObserver +from aiogram.dispatcher.event.telegram import TelegramEventObserver from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.router import Router +# TODO: Test middlewares in routers tree + async def my_handler(event: Any, index: int = 0) -> Any: return event @@ -38,54 +41,6 @@ class MyFilter3(MyFilter1): pass -class TestEventObserver: - @pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler])) - def test_register_filters(self, count, handler): - observer = EventObserver() - - for index in range(count): - wrapped_handler = functools.partial(handler, index=index) - observer.register(wrapped_handler) - registered_handler = observer.handlers[index] - - assert len(observer.handlers) == index + 1 - assert isinstance(registered_handler, HandlerObject) - assert registered_handler.callback == wrapped_handler - assert not registered_handler.filters - - @pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler])) - def test_register_filters_via_decorator(self, count, handler): - observer = EventObserver() - - for index in range(count): - wrapped_handler = functools.partial(handler, index=index) - observer()(wrapped_handler) - registered_handler = observer.handlers[index] - - assert len(observer.handlers) == index + 1 - assert isinstance(registered_handler, HandlerObject) - assert registered_handler.callback == wrapped_handler - assert not registered_handler.filters - - @pytest.mark.asyncio - async def test_trigger_accepted_bool(self): - observer = EventObserver() - observer.register(my_handler) - - results = [result async for result in observer.trigger(42)] - assert results == [42] - - @pytest.mark.asyncio - async def test_trigger_with_skip(self): - observer = EventObserver() - observer.register(skip_my_handler) - observer.register(my_handler) - observer.register(my_handler) - - results = [result async for result in observer.trigger(42)] - assert results == [42, 42] - - class TestTelegramEventObserver: def test_bind_filter(self): event_observer = TelegramEventObserver(Router(), "test") @@ -198,8 +153,8 @@ class TestTelegramEventObserver: from_user=User(id=42, is_bot=False, first_name="Test"), ) - results = [result async for result in observer.trigger(message)] - assert results == [message] + results = await observer.trigger(message) + assert results is message @pytest.mark.parametrize( "count,handler,filters", @@ -223,15 +178,58 @@ class TestTelegramEventObserver: assert registered_handler.callback == wrapped_handler assert len(registered_handler.filters) == len(filters) - # @pytest.mark.asyncio async def test_trigger_right_context_in_handlers(self): router = Router(use_builtin_filters=False) observer = router.message - observer.register( - pipe_handler, lambda event: {"a": 1}, lambda event: False - ) # {"a": 1} should not be in result - observer.register(pipe_handler, lambda event: {"b": 2}) - results = [result async for result in observer.trigger(42)] - assert results == [((42,), {"b": 2})] + async def mix_unnecessary_data(event): + return {"a": 1} + + async def mix_data(event): + return {"b": 2} + + async def handler(event, **kwargs): + return False + + observer.register( + pipe_handler, mix_unnecessary_data, handler + ) # {"a": 1} should not be in result + observer.register(pipe_handler, mix_data) + + results = await observer.trigger(42) + assert results == ((42,), {"b": 2}) + + @pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware")) + def test_register_middleware(self, middleware_type): + event_observer = TelegramEventObserver(Router(), "test") + + middlewares = getattr(event_observer, f"{middleware_type}s") + decorator = getattr(event_observer, middleware_type) + + @decorator + async def my_middleware1(handler, event, data): + pass + + assert my_middleware1 is not None + assert my_middleware1.__name__ == "my_middleware1" + assert my_middleware1 in middlewares + + @decorator() + async def my_middleware2(handler, event, data): + pass + + assert my_middleware2 is not None + assert my_middleware2.__name__ == "my_middleware2" + assert my_middleware2 in middlewares + + async def my_middleware3(handler, event, data): + pass + + decorator(my_middleware3) + + assert my_middleware3 is not None + assert my_middleware3.__name__ == "my_middleware3" + assert my_middleware3 in middlewares + + assert middlewares == [my_middleware1, my_middleware2, my_middleware3] diff --git a/tests/test_dispatcher/test_middlewares/__init__.py b/tests/test_dispatcher/test_middlewares/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/test_dispatcher/test_middlewares/test_base.py b/tests/test_dispatcher/test_middlewares/test_base.py deleted file mode 100644 index 7899324d..00000000 --- a/tests/test_dispatcher/test_middlewares/test_base.py +++ /dev/null @@ -1,257 +0,0 @@ -import datetime -from typing import Any, Dict, Type - -import pytest - -from aiogram.api.types import ( - CallbackQuery, - Chat, - ChosenInlineResult, - InlineQuery, - Message, - Poll, - PollAnswer, - PreCheckoutQuery, - ShippingQuery, - Update, - User, -) -from aiogram.dispatcher.middlewares.base import BaseMiddleware -from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType - -try: - from asynctest import CoroutineMock, patch -except ImportError: - from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore - - -class MyMiddleware(BaseMiddleware): - async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any: - return "update" - - async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any: - return "message" - - async def on_pre_process_edited_message( - self, edited_message: Message, data: Dict[str, Any] - ) -> Any: - return "edited_message" - - async def on_pre_process_channel_post( - self, channel_post: Message, data: Dict[str, Any] - ) -> Any: - return "channel_post" - - async def on_pre_process_edited_channel_post( - self, edited_channel_post: Message, data: Dict[str, Any] - ) -> Any: - return "edited_channel_post" - - async def on_pre_process_inline_query( - self, inline_query: InlineQuery, data: Dict[str, Any] - ) -> Any: - return "inline_query" - - async def on_pre_process_chosen_inline_result( - self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] - ) -> Any: - return "chosen_inline_result" - - async def on_pre_process_callback_query( - self, callback_query: CallbackQuery, data: Dict[str, Any] - ) -> Any: - return "callback_query" - - async def on_pre_process_shipping_query( - self, shipping_query: ShippingQuery, data: Dict[str, Any] - ) -> Any: - return "shipping_query" - - async def on_pre_process_pre_checkout_query( - self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] - ) -> Any: - return "pre_checkout_query" - - async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: - return "poll" - - async def on_pre_process_poll_answer( - self, poll_answer: PollAnswer, data: Dict[str, Any] - ) -> Any: - return "poll_answer" - - async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: - return "error" - - async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any: - return "update" - - async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any: - return "message" - - async def on_process_edited_message( - self, edited_message: Message, data: Dict[str, Any] - ) -> Any: - return "edited_message" - - async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any: - return "channel_post" - - async def on_process_edited_channel_post( - self, edited_channel_post: Message, data: Dict[str, Any] - ) -> Any: - return "edited_channel_post" - - async def on_process_inline_query( - self, inline_query: InlineQuery, data: Dict[str, Any] - ) -> Any: - return "inline_query" - - async def on_process_chosen_inline_result( - self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any] - ) -> Any: - return "chosen_inline_result" - - async def on_process_callback_query( - self, callback_query: CallbackQuery, data: Dict[str, Any] - ) -> Any: - return "callback_query" - - async def on_process_shipping_query( - self, shipping_query: ShippingQuery, data: Dict[str, Any] - ) -> Any: - return "shipping_query" - - async def on_process_pre_checkout_query( - self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any] - ) -> Any: - return "pre_checkout_query" - - async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any: - return "poll" - - async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any: - return "poll_answer" - - async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: - return "error" - - async def on_post_process_update( - self, update: Update, data: Dict[str, Any], result: Any - ) -> Any: - return "update" - - async def on_post_process_message( - self, message: Message, data: Dict[str, Any], result: Any - ) -> Any: - return "message" - - async def on_post_process_edited_message( - self, edited_message: Message, data: Dict[str, Any], result: Any - ) -> Any: - return "edited_message" - - async def on_post_process_channel_post( - self, channel_post: Message, data: Dict[str, Any], result: Any - ) -> Any: - return "channel_post" - - async def on_post_process_edited_channel_post( - self, edited_channel_post: Message, data: Dict[str, Any], result: Any - ) -> Any: - return "edited_channel_post" - - async def on_post_process_inline_query( - self, inline_query: InlineQuery, data: Dict[str, Any], result: Any - ) -> Any: - return "inline_query" - - async def on_post_process_chosen_inline_result( - self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any - ) -> Any: - return "chosen_inline_result" - - async def on_post_process_callback_query( - self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any - ) -> Any: - return "callback_query" - - async def on_post_process_shipping_query( - self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any - ) -> Any: - return "shipping_query" - - async def on_post_process_pre_checkout_query( - self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any - ) -> Any: - return "pre_checkout_query" - - async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any: - return "poll" - - async def on_post_process_poll_answer( - self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any - ) -> Any: - return "poll_answer" - - async def on_post_process_error( - self, exception: Exception, data: Dict[str, Any], result: Any - ) -> Any: - return "error" - - -UPDATE = Update(update_id=42) -MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private")) -POLL_ANSWER = PollAnswer( - poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0] -) - - -class TestBaseMiddleware: - @pytest.mark.asyncio - @pytest.mark.parametrize( - "middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]] - ) - @pytest.mark.parametrize( - "step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS] - ) - @pytest.mark.parametrize( - "event_name,event", - [ - ["update", UPDATE], - ["message", MESSAGE], - ["poll_answer", POLL_ANSWER], - ["error", Exception("KABOOM")], - ], - ) - async def test_trigger( - self, - step: MiddlewareStep, - event_name: str, - event: UpdateType, - middleware_cls: Type[BaseMiddleware], - should_be_awaited: bool, - ): - middleware = middleware_cls() - - with patch( - f"tests.test_dispatcher.test_middlewares.test_base." - f"MyMiddleware.on_{step.value}_{event_name}", - new_callable=CoroutineMock, - ) as mocked_call: - response = await middleware.trigger( - step=step, event_name=event_name, event=event, data={} - ) - if should_be_awaited: - mocked_call.assert_awaited() - assert response is not None - else: - mocked_call.assert_not_awaited() - assert response is None - - def test_not_configured(self): - middleware = BaseMiddleware() - assert not middleware.configured - - with pytest.raises(RuntimeError): - manager = middleware.manager diff --git a/tests/test_dispatcher/test_middlewares/test_manager.py b/tests/test_dispatcher/test_middlewares/test_manager.py deleted file mode 100644 index 0e23f1b2..00000000 --- a/tests/test_dispatcher/test_middlewares/test_manager.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytest - -from aiogram import Router -from aiogram.api.types import Update -from aiogram.dispatcher.middlewares.base import BaseMiddleware -from aiogram.dispatcher.middlewares.manager import MiddlewareManager -from aiogram.dispatcher.middlewares.types import MiddlewareStep - -try: - from asynctest import CoroutineMock, patch -except ImportError: - from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore - - -@pytest.fixture("function") -def router(): - return Router() - - -@pytest.fixture("function") -def manager(router: Router): - return MiddlewareManager(router) - - -class TestManager: - def test_setup(self, manager: MiddlewareManager): - middleware = BaseMiddleware() - returned = manager.setup(middleware) - assert returned is middleware - assert middleware.configured - assert middleware.manager is manager - assert middleware in manager - - @pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware]) - def test_setup_invalid_type(self, manager: MiddlewareManager, obj): - with pytest.raises(TypeError): - assert manager.setup(obj) - - def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router): - middleware = BaseMiddleware() - manager.setup(middleware) - - assert middleware.configured - - new_manager = MiddlewareManager(router) - with pytest.raises(ValueError): - new_manager.setup(middleware) - with pytest.raises(ValueError): - middleware.setup(new_manager) - - def test_configure_twice(self, manager: MiddlewareManager): - middleware = BaseMiddleware() - manager.setup(middleware) - - assert middleware.configured - - with pytest.warns(RuntimeWarning, match="is already configured for this Router"): - manager.setup(middleware) - - with pytest.warns(RuntimeWarning, match="is already configured for this Router"): - middleware.setup(manager) - - @pytest.mark.asyncio - @pytest.mark.parametrize("count", range(5)) - async def test_trigger(self, manager: MiddlewareManager, count: int): - for _ in range(count): - manager.setup(BaseMiddleware()) - - with patch( - "aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger", - new_callable=CoroutineMock, - ) as mocked_call: - await manager.trigger( - step=MiddlewareStep.PROCESS, - event_name="update", - event=Update(update_id=42), - data={}, - result=None, - reverse=True, - ) - - assert mocked_call.await_count == count diff --git a/tests/test_dispatcher/test_router.py b/tests/test_dispatcher/test_router.py index 303efbb3..9d425388 100644 --- a/tests/test_dispatcher/test_router.py +++ b/tests/test_dispatcher/test_router.py @@ -10,6 +10,7 @@ from aiogram.api.types import ( InlineQuery, Message, Poll, + PollAnswer, PollOption, PreCheckoutQuery, ShippingAddress, @@ -17,8 +18,8 @@ from aiogram.api.types import ( Update, User, ) -from aiogram.dispatcher.event.observer import SkipHandler, skip -from aiogram.dispatcher.middlewares.base import BaseMiddleware +from aiogram.dispatcher.event.bases import NOT_HANDLED, SkipHandler, skip +from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware from aiogram.dispatcher.router import Router from aiogram.utils.warnings import CodeHasNoEffect @@ -274,12 +275,26 @@ class TestRouter: False, False, ), + pytest.param( + "poll_answer", + Update( + update_id=42, + poll_answer=PollAnswer( + poll_id="poll id", + user=User(id=42, is_bot=False, first_name="Test"), + option_ids=[42], + ), + ), + False, + True, + ), ], ) async def test_listen_update( self, event_type: str, update: Update, has_chat: bool, has_user: bool ): router = Router() + router.update.outer_middleware(UserContextMiddleware()) observer = router.observers[event_type] @observer() @@ -291,7 +306,7 @@ class TestRouter: assert User.get_current(False) return kwargs - result = await router._listen_update(update, test="PASS") + result = await router.update.trigger(update, test="PASS") assert isinstance(result, dict) assert result["event_update"] == update assert result["event_router"] == router @@ -313,26 +328,26 @@ class TestRouter: async def handler(event: Any): pass - with pytest.raises(SkipHandler): - await router._listen_update( - Update( - update_id=42, - poll=Poll( - id="poll id", - question="Q?", - options=[ - PollOption(text="A1", voter_count=2), - PollOption(text="A2", voter_count=3), - ], - is_closed=False, - is_anonymous=False, - type="quiz", - allows_multiple_answers=False, - total_voter_count=0, - correct_option_id=0, - ), - ) + response = await router._listen_update( + Update( + update_id=42, + poll=Poll( + id="poll id", + question="Q?", + options=[ + PollOption(text="A1", voter_count=2), + PollOption(text="A2", voter_count=3), + ], + is_closed=False, + is_anonymous=False, + type="quiz", + allows_multiple_answers=False, + total_voter_count=0, + correct_option_id=0, + ), ) + ) + assert response is NOT_HANDLED @pytest.mark.asyncio async def test_nested_router_listen_update(self): @@ -345,8 +360,6 @@ class TestRouter: @observer() async def my_handler(event: Message, **kwargs: Any): - assert Chat.get_current(False) - assert User.get_current(False) return kwargs update = Update( @@ -409,14 +422,6 @@ class TestRouter: await router1.emit_shutdown() assert results == [2, 1, 2] - def test_use(self): - router = Router() - - middleware = router.use(BaseMiddleware()) - assert isinstance(middleware, BaseMiddleware) - assert middleware.configured - assert middleware.manager == router.middleware - def test_skip(self): with pytest.raises(SkipHandler): skip() @@ -444,37 +449,20 @@ class TestRouter: ), ) with pytest.raises(Exception, match="KABOOM"): - await root_router.listen_update( - update_type="message", - update=update, - event=update.message, - from_user=update.message.from_user, - chat=update.message.chat, - ) + await root_router.update.trigger(update) @root_router.errors() - async def root_error_handler(exception: Exception): + async def root_error_handler(event: Update, exception: Exception): return exception - response = await root_router.listen_update( - update_type="message", - update=update, - event=update.message, - from_user=update.message.from_user, - chat=update.message.chat, - ) + response = await root_router.update.trigger(update) + assert isinstance(response, Exception) assert str(response) == "KABOOM" @router.errors() - async def error_handler(exception: Exception): + async def error_handler(event: Update, exception: Exception): return "KABOOM" - response = await root_router.listen_update( - update_type="message", - update=update, - event=update.message, - from_user=update.message.from_user, - chat=update.message.chat, - ) + response = await root_router.update.trigger(update) assert response == "KABOOM"