Feature/rework middlewares chain (#664)

* Reworked middlewares chain

* Added description for router name

* Added patch-notes

* Fixed type hints
This commit is contained in:
Alex Root Junior 2021-08-17 00:43:27 +03:00 committed by GitHub
parent c1f605c6f5
commit 9238533e93
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 94 additions and 29 deletions

1
CHANGES/664.feature Normal file
View file

@ -0,0 +1 @@
Reworked outer middleware chain. Prevent to call many times the outer middleware for each nested router

View file

@ -98,7 +98,8 @@ class Dispatcher(Router):
token = Bot.set_current(bot)
try:
response = await self.update.trigger(update, bot=bot, **kwargs)
kwargs.update(bot=bot)
response = await self.update.wrap_outer_middleware(self.update.trigger, update, kwargs)
handled = response is not UNHANDLED
return response
finally:

View file

@ -84,13 +84,13 @@ class TelegramEventObserver:
:param *:
"""
middlewares = []
for router in reversed(list(self.router.chain_head)):
observer = router.observers[self.event_name]
if outer:
middlewares.extend(observer.outer_middlewares)
else:
if outer:
middlewares.extend(self.outer_middlewares)
else:
for router in reversed(list(self.router.chain_head)):
observer = router.observers[self.event_name]
middlewares.extend(observer.middlewares)
return middlewares
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
@ -148,15 +148,17 @@ class TelegramEventObserver:
middleware = functools.partial(m, middleware)
return middleware
def wrap_outer_middleware(
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
) -> Any:
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback)
return wrapped_outer(event, data)
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
"""
Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass.
"""
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), self._trigger)
return await wrapped_outer(event, kwargs)
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
# Check globally defined filters before any other handler will be checked
result, data = await self._handler.check(event, **kwargs)
if not result:

View file

@ -21,16 +21,17 @@ class Router:
- By observer method - :obj:`router.<event_type>.register(handler, <filters, ...>)`
- By decorator - :obj:`@router.<event_type>(<filters, ...>)`
"""
def __init__(self, use_builtin_filters: bool = True) -> None:
def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None:
"""
:param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory
:param name: Optional router name, can be useful for debugging
"""
self.use_builtin_filters = use_builtin_filters
self.name = name or hex(id(self))
self._parent_router: Optional[Router] = None
self.sub_routers: List[Router] = []
@ -84,9 +85,30 @@ class Router:
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
observer.bind_filter(builtin_filter)
def __str__(self) -> str:
return f"{type(self).__name__} {self.name!r}"
def __repr__(self) -> str:
return f"<{self}>"
async def propagate_event(self, update_type: str, event: TelegramObject, **kwargs: Any) -> Any:
kwargs.update(event_router=self)
observer = self.observers[update_type]
async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any:
return await self._propagate_event(
observer=observer, update_type=update_type, event=telegram_event, **data
)
return await observer.wrap_outer_middleware(_wrapped, event=event, data=kwargs)
async def _propagate_event(
self,
observer: TelegramEventObserver,
update_type: str,
event: TelegramObject,
**kwargs: Any,
) -> Any:
response = await observer.trigger(event, **kwargs)
if response is REJECTED:
return UNHANDLED

47
poetry.lock generated
View file

@ -256,6 +256,18 @@ category = "dev"
optional = false
python-versions = ">=3.5"
[[package]]
name = "diagrams"
version = "0.20.0"
description = "Diagram as Code"
category = "dev"
optional = false
python-versions = ">=3.6,<4.0"
[package.dependencies]
graphviz = ">=0.13.2,<0.17.0"
jinja2 = ">=2.10,<3.0"
[[package]]
name = "distlib"
version = "0.3.2"
@ -323,6 +335,19 @@ sphinx = ">=3.0,<5.0"
doc = ["myst-parser", "sphinx-copybutton", "sphinx-inline-tabs", "docutils (!=0.17)"]
test = ["pytest", "pytest-cov", "pytest-xdist"]
[[package]]
name = "graphviz"
version = "0.16"
description = "Simple Python interface for Graphviz"
category = "dev"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
[package.extras]
dev = ["tox (>=3)", "flake8", "pep8-naming", "wheel", "twine"]
docs = ["sphinx (>=1.8)", "sphinx-rtd-theme"]
test = ["mock (>=3)", "pytest (>=4)", "pytest-mock (>=2)", "pytest-cov"]
[[package]]
name = "identify"
version = "2.2.10"
@ -454,17 +479,17 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<6.0.0)"]
[[package]]
name = "jinja2"
version = "3.0.1"
version = "2.11.3"
description = "A very fast and expressive template engine."
category = "main"
optional = false
python-versions = ">=3.6"
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[package.dependencies]
MarkupSafe = ">=2.0"
MarkupSafe = ">=0.23"
[package.extras]
i18n = ["Babel (>=2.7)"]
i18n = ["Babel (>=0.8)"]
[[package]]
name = "livereload"
@ -1241,7 +1266,7 @@ redis = ["aioredis"]
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "f6ac17a44b1eec95b101daab369097785a093d9263d0c6cf6c9ef8d363d8962d"
content-hash = "e8bc158e14347b3766672505f38ad9d76b1cbf6f9557565e4e56664b0f663717"
[metadata.files]
aiofiles = [
@ -1424,6 +1449,10 @@ decorator = [
{file = "decorator-5.0.9-py3-none-any.whl", hash = "sha256:6e5c199c16f7a9f0e3a61a4a54b3d27e7dad0dbdde92b944426cb20914376323"},
{file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"},
]
diagrams = [
{file = "diagrams-0.20.0-py3-none-any.whl", hash = "sha256:395391663b4d3f2d3e3614797402ca99494e00baf3926f5c9e72856d34cafedd"},
{file = "diagrams-0.20.0.tar.gz", hash = "sha256:a50743ed9274e194e7898820f69aa12868ae217003580ef9e7d0285132c9674a"},
]
distlib = [
{file = "distlib-0.3.2-py2.py3-none-any.whl", hash = "sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c"},
{file = "distlib-0.3.2.zip", hash = "sha256:106fef6dc37dd8c0e2c0a60d3fca3e77460a48907f335fa28420463a6f799736"},
@ -1448,6 +1477,10 @@ furo = [
{file = "furo-2021.6.18b36-py3-none-any.whl", hash = "sha256:a4c00634afeb5896a34d141a5dffb62f20c5eca7831b78269823a8cd8b09a5e4"},
{file = "furo-2021.6.18b36.tar.gz", hash = "sha256:46a30bc597a9067088d39d730e7d9bf6c1a1d71967e4af062f796769f66b3bdb"},
]
graphviz = [
{file = "graphviz-0.16-py2.py3-none-any.whl", hash = "sha256:3cad5517c961090dfc679df6402a57de62d97703e2880a1a46147bb0dc1639eb"},
{file = "graphviz-0.16.zip", hash = "sha256:d2d25af1c199cad567ce4806f0449cb74eb30cf451fd7597251e1da099ac6e57"},
]
identify = [
{file = "identify-2.2.10-py2.py3-none-any.whl", hash = "sha256:18d0c531ee3dbc112fa6181f34faa179de3f57ea57ae2899754f16a7e0ff6421"},
{file = "identify-2.2.10.tar.gz", hash = "sha256:5b41f71471bc738e7b586308c3fca172f78940195cb3bf6734c1e66fdac49306"},
@ -1489,8 +1522,8 @@ jedi = [
{file = "jedi-0.18.0.tar.gz", hash = "sha256:92550a404bad8afed881a137ec9a461fed49eca661414be45059329614ed0707"},
]
jinja2 = [
{file = "Jinja2-3.0.1-py3-none-any.whl", hash = "sha256:1f06f2da51e7b56b8f238affdd6b4e2c61e39598a378cc49345bc1bd42a978a4"},
{file = "Jinja2-3.0.1.tar.gz", hash = "sha256:703f484b47a6af502e743c9122595cc812b0271f661722403114f71a79d0f5a4"},
{file = "Jinja2-2.11.3-py2.py3-none-any.whl", hash = "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419"},
{file = "Jinja2-2.11.3.tar.gz", hash = "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6"},
]
livereload = [
{file = "livereload-2.6.3.tar.gz", hash = "sha256:776f2f865e59fde56490a56bcc6773b6917366bce0c267c60ee8aaf1a0959869"},

View file

@ -82,6 +82,7 @@ furo = "^2021.6.18-beta.36"
sphinx-prompt = "^1.3.0"
Sphinx-Substitution-Extensions = "^2020.9.30"
towncrier = "^21.3.0"
diagrams = "^0.20.0"
[tool.poetry.extras]

View file

@ -422,7 +422,12 @@ class TestDispatcher:
],
)
async def test_listen_update(
self, event_type: str, update: Update, has_chat: bool, has_user: bool
self,
event_type: str,
update: Update,
has_chat: bool,
has_user: bool,
bot: MockedBot,
):
router = Dispatcher()
observer = router.observers[event_type]
@ -436,7 +441,7 @@ class TestDispatcher:
assert User.get_current(False)
return kwargs
result = await router.update.trigger(update, test="PASS", bot=None)
result = await router.feed_update(bot, update, test="PASS")
assert isinstance(result, dict)
assert result["event_update"] == update
assert result["event_router"] == router
@ -477,7 +482,7 @@ class TestDispatcher:
)
assert response is UNHANDLED
async def test_nested_router_listen_update(self):
async def test_nested_router_listen_update(self, bot: MockedBot):
dp = Dispatcher()
router0 = Router()
router1 = Router()
@ -499,7 +504,7 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"),
),
)
result = await dp._listen_update(update, test="PASS")
result = await dp.feed_update(bot, update, test="PASS")
assert isinstance(result, dict)
assert result["event_update"] == update
assert result["event_router"] == router1
@ -542,7 +547,7 @@ class TestDispatcher:
baz=...,
)
assert counter["root.outer_middleware"] == 2
assert counter["root.outer_middleware"] == 1
assert counter["root.middleware"] == 1
assert counter["child.outer_middleware"] == 1
assert counter["child.middleware"] == 1
@ -596,7 +601,7 @@ class TestDispatcher:
else:
mocked_process_update.assert_awaited()
async def test_exception_handler_catch_exceptions(self):
async def test_exception_handler_catch_exceptions(self, bot: MockedBot):
dp = Dispatcher()
router = Router()
dp.include_router(router)
@ -619,20 +624,20 @@ class TestDispatcher:
),
)
with pytest.raises(CustomException, match="KABOOM"):
await dp.update.trigger(update, bot=None)
await dp.feed_update(bot, update)
@router.errors()
async def error_handler(event: Update, exception: Exception):
return "KABOOM"
response = await dp.update.trigger(update, bot=None)
response = await dp.feed_update(bot, update)
assert response == "KABOOM"
@dp.errors()
async def root_error_handler(event: Update, exception: Exception):
return exception
response = await dp.update.trigger(update, bot=None)
response = await dp.feed_update(bot, update)
assert isinstance(response, CustomException)
assert str(response) == "KABOOM"