From 9238533e93404080db3b3d1dd019f51f5023317d Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Tue, 17 Aug 2021 00:43:27 +0300 Subject: [PATCH] Feature/rework middlewares chain (#664) * Reworked middlewares chain * Added description for router name * Added patch-notes * Fixed type hints --- CHANGES/664.feature | 1 + aiogram/dispatcher/dispatcher.py | 3 +- aiogram/dispatcher/event/telegram.py | 22 ++++++----- aiogram/dispatcher/router.py | 26 ++++++++++++- poetry.lock | 47 ++++++++++++++++++++---- pyproject.toml | 1 + tests/test_dispatcher/test_dispatcher.py | 23 +++++++----- 7 files changed, 94 insertions(+), 29 deletions(-) create mode 100644 CHANGES/664.feature diff --git a/CHANGES/664.feature b/CHANGES/664.feature new file mode 100644 index 00000000..2db72144 --- /dev/null +++ b/CHANGES/664.feature @@ -0,0 +1 @@ +Reworked outer middleware chain. Prevent to call many times the outer middleware for each nested router diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index ee61d1d4..2f4bb1ba 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -98,7 +98,8 @@ class Dispatcher(Router): token = Bot.set_current(bot) try: - response = await self.update.trigger(update, bot=bot, **kwargs) + kwargs.update(bot=bot) + response = await self.update.wrap_outer_middleware(self.update.trigger, update, kwargs) handled = response is not UNHANDLED return response finally: diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index ad03c06c..424ffeb3 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -84,13 +84,13 @@ class TelegramEventObserver: :param *: """ middlewares = [] - - for router in reversed(list(self.router.chain_head)): - observer = router.observers[self.event_name] - if outer: - middlewares.extend(observer.outer_middlewares) - else: + if outer: + middlewares.extend(self.outer_middlewares) + else: + for router in reversed(list(self.router.chain_head)): + observer = router.observers[self.event_name] middlewares.extend(observer.middlewares) + return middlewares def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: @@ -148,15 +148,17 @@ class TelegramEventObserver: middleware = functools.partial(m, middleware) return middleware + def wrap_outer_middleware( + self, callback: Any, event: TelegramObject, data: Dict[str, Any] + ) -> Any: + wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback) + return wrapped_outer(event, data) + async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any: """ Propagate event to handlers and stops propagation on first match. Handler will be called when all its filters is pass. """ - wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), self._trigger) - return await wrapped_outer(event, kwargs) - - async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any: # Check globally defined filters before any other handler will be checked result, data = await self._handler.check(event, **kwargs) if not result: diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index bc66d0de..b776bcdf 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -21,16 +21,17 @@ class Router: - By observer method - :obj:`router..register(handler, )` - By decorator - :obj:`@router.()` - """ - def __init__(self, use_builtin_filters: bool = True) -> None: + def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None: """ :param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory + :param name: Optional router name, can be useful for debugging """ self.use_builtin_filters = use_builtin_filters + self.name = name or hex(id(self)) self._parent_router: Optional[Router] = None self.sub_routers: List[Router] = [] @@ -84,9 +85,30 @@ class Router: for builtin_filter in BUILTIN_FILTERS.get(name, ()): observer.bind_filter(builtin_filter) + def __str__(self) -> str: + return f"{type(self).__name__} {self.name!r}" + + def __repr__(self) -> str: + return f"<{self}>" + async def propagate_event(self, update_type: str, event: TelegramObject, **kwargs: Any) -> Any: kwargs.update(event_router=self) observer = self.observers[update_type] + + async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any: + return await self._propagate_event( + observer=observer, update_type=update_type, event=telegram_event, **data + ) + + return await observer.wrap_outer_middleware(_wrapped, event=event, data=kwargs) + + async def _propagate_event( + self, + observer: TelegramEventObserver, + update_type: str, + event: TelegramObject, + **kwargs: Any, + ) -> Any: response = await observer.trigger(event, **kwargs) if response is REJECTED: return UNHANDLED diff --git a/poetry.lock b/poetry.lock index 9b60c280..7f29ff0d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -256,6 +256,18 @@ category = "dev" optional = false python-versions = ">=3.5" +[[package]] +name = "diagrams" +version = "0.20.0" +description = "Diagram as Code" +category = "dev" +optional = false +python-versions = ">=3.6,<4.0" + +[package.dependencies] +graphviz = ">=0.13.2,<0.17.0" +jinja2 = ">=2.10,<3.0" + [[package]] name = "distlib" version = "0.3.2" @@ -323,6 +335,19 @@ sphinx = ">=3.0,<5.0" doc = ["myst-parser", "sphinx-copybutton", "sphinx-inline-tabs", "docutils (!=0.17)"] test = ["pytest", "pytest-cov", "pytest-xdist"] +[[package]] +name = "graphviz" +version = "0.16" +description = "Simple Python interface for Graphviz" +category = "dev" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" + +[package.extras] +dev = ["tox (>=3)", "flake8", "pep8-naming", "wheel", "twine"] +docs = ["sphinx (>=1.8)", "sphinx-rtd-theme"] +test = ["mock (>=3)", "pytest (>=4)", "pytest-mock (>=2)", "pytest-cov"] + [[package]] name = "identify" version = "2.2.10" @@ -454,17 +479,17 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<6.0.0)"] [[package]] name = "jinja2" -version = "3.0.1" +version = "2.11.3" description = "A very fast and expressive template engine." category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.dependencies] -MarkupSafe = ">=2.0" +MarkupSafe = ">=0.23" [package.extras] -i18n = ["Babel (>=2.7)"] +i18n = ["Babel (>=0.8)"] [[package]] name = "livereload" @@ -1241,7 +1266,7 @@ redis = ["aioredis"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "f6ac17a44b1eec95b101daab369097785a093d9263d0c6cf6c9ef8d363d8962d" +content-hash = "e8bc158e14347b3766672505f38ad9d76b1cbf6f9557565e4e56664b0f663717" [metadata.files] aiofiles = [ @@ -1424,6 +1449,10 @@ decorator = [ {file = "decorator-5.0.9-py3-none-any.whl", hash = "sha256:6e5c199c16f7a9f0e3a61a4a54b3d27e7dad0dbdde92b944426cb20914376323"}, {file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"}, ] +diagrams = [ + {file = "diagrams-0.20.0-py3-none-any.whl", hash = "sha256:395391663b4d3f2d3e3614797402ca99494e00baf3926f5c9e72856d34cafedd"}, + {file = "diagrams-0.20.0.tar.gz", hash = "sha256:a50743ed9274e194e7898820f69aa12868ae217003580ef9e7d0285132c9674a"}, +] distlib = [ {file = "distlib-0.3.2-py2.py3-none-any.whl", hash = "sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c"}, {file = "distlib-0.3.2.zip", hash = "sha256:106fef6dc37dd8c0e2c0a60d3fca3e77460a48907f335fa28420463a6f799736"}, @@ -1448,6 +1477,10 @@ furo = [ {file = "furo-2021.6.18b36-py3-none-any.whl", hash = "sha256:a4c00634afeb5896a34d141a5dffb62f20c5eca7831b78269823a8cd8b09a5e4"}, {file = "furo-2021.6.18b36.tar.gz", hash = "sha256:46a30bc597a9067088d39d730e7d9bf6c1a1d71967e4af062f796769f66b3bdb"}, ] +graphviz = [ + {file = "graphviz-0.16-py2.py3-none-any.whl", hash = "sha256:3cad5517c961090dfc679df6402a57de62d97703e2880a1a46147bb0dc1639eb"}, + {file = "graphviz-0.16.zip", hash = "sha256:d2d25af1c199cad567ce4806f0449cb74eb30cf451fd7597251e1da099ac6e57"}, +] identify = [ {file = "identify-2.2.10-py2.py3-none-any.whl", hash = "sha256:18d0c531ee3dbc112fa6181f34faa179de3f57ea57ae2899754f16a7e0ff6421"}, {file = "identify-2.2.10.tar.gz", hash = "sha256:5b41f71471bc738e7b586308c3fca172f78940195cb3bf6734c1e66fdac49306"}, @@ -1489,8 +1522,8 @@ jedi = [ {file = "jedi-0.18.0.tar.gz", hash = "sha256:92550a404bad8afed881a137ec9a461fed49eca661414be45059329614ed0707"}, ] jinja2 = [ - {file = "Jinja2-3.0.1-py3-none-any.whl", hash = "sha256:1f06f2da51e7b56b8f238affdd6b4e2c61e39598a378cc49345bc1bd42a978a4"}, - {file = "Jinja2-3.0.1.tar.gz", hash = "sha256:703f484b47a6af502e743c9122595cc812b0271f661722403114f71a79d0f5a4"}, + {file = "Jinja2-2.11.3-py2.py3-none-any.whl", hash = "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419"}, + {file = "Jinja2-2.11.3.tar.gz", hash = "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6"}, ] livereload = [ {file = "livereload-2.6.3.tar.gz", hash = "sha256:776f2f865e59fde56490a56bcc6773b6917366bce0c267c60ee8aaf1a0959869"}, diff --git a/pyproject.toml b/pyproject.toml index 47fb8bf3..6b659613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ furo = "^2021.6.18-beta.36" sphinx-prompt = "^1.3.0" Sphinx-Substitution-Extensions = "^2020.9.30" towncrier = "^21.3.0" +diagrams = "^0.20.0" [tool.poetry.extras] diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index 1c97cf3b..520b190c 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -422,7 +422,12 @@ class TestDispatcher: ], ) async def test_listen_update( - self, event_type: str, update: Update, has_chat: bool, has_user: bool + self, + event_type: str, + update: Update, + has_chat: bool, + has_user: bool, + bot: MockedBot, ): router = Dispatcher() observer = router.observers[event_type] @@ -436,7 +441,7 @@ class TestDispatcher: assert User.get_current(False) return kwargs - result = await router.update.trigger(update, test="PASS", bot=None) + result = await router.feed_update(bot, update, test="PASS") assert isinstance(result, dict) assert result["event_update"] == update assert result["event_router"] == router @@ -477,7 +482,7 @@ class TestDispatcher: ) assert response is UNHANDLED - async def test_nested_router_listen_update(self): + async def test_nested_router_listen_update(self, bot: MockedBot): dp = Dispatcher() router0 = Router() router1 = Router() @@ -499,7 +504,7 @@ class TestDispatcher: from_user=User(id=42, is_bot=False, first_name="Test"), ), ) - result = await dp._listen_update(update, test="PASS") + result = await dp.feed_update(bot, update, test="PASS") assert isinstance(result, dict) assert result["event_update"] == update assert result["event_router"] == router1 @@ -542,7 +547,7 @@ class TestDispatcher: baz=..., ) - assert counter["root.outer_middleware"] == 2 + assert counter["root.outer_middleware"] == 1 assert counter["root.middleware"] == 1 assert counter["child.outer_middleware"] == 1 assert counter["child.middleware"] == 1 @@ -596,7 +601,7 @@ class TestDispatcher: else: mocked_process_update.assert_awaited() - async def test_exception_handler_catch_exceptions(self): + async def test_exception_handler_catch_exceptions(self, bot: MockedBot): dp = Dispatcher() router = Router() dp.include_router(router) @@ -619,20 +624,20 @@ class TestDispatcher: ), ) with pytest.raises(CustomException, match="KABOOM"): - await dp.update.trigger(update, bot=None) + await dp.feed_update(bot, update) @router.errors() async def error_handler(event: Update, exception: Exception): return "KABOOM" - response = await dp.update.trigger(update, bot=None) + response = await dp.feed_update(bot, update) assert response == "KABOOM" @dp.errors() async def root_error_handler(event: Update, exception: Exception): return exception - response = await dp.update.trigger(update, bot=None) + response = await dp.feed_update(bot, update) assert isinstance(response, CustomException) assert str(response) == "KABOOM"