mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-06 07:50:32 +00:00
Feature/rework middlewares chain (#664)
* Reworked middlewares chain * Added description for router name * Added patch-notes * Fixed type hints
This commit is contained in:
parent
c1f605c6f5
commit
9238533e93
7 changed files with 94 additions and 29 deletions
1
CHANGES/664.feature
Normal file
1
CHANGES/664.feature
Normal file
|
|
@ -0,0 +1 @@
|
|||
Reworked outer middleware chain. Prevent to call many times the outer middleware for each nested router
|
||||
|
|
@ -98,7 +98,8 @@ class Dispatcher(Router):
|
|||
|
||||
token = Bot.set_current(bot)
|
||||
try:
|
||||
response = await self.update.trigger(update, bot=bot, **kwargs)
|
||||
kwargs.update(bot=bot)
|
||||
response = await self.update.wrap_outer_middleware(self.update.trigger, update, kwargs)
|
||||
handled = response is not UNHANDLED
|
||||
return response
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -84,13 +84,13 @@ class TelegramEventObserver:
|
|||
:param *:
|
||||
"""
|
||||
middlewares = []
|
||||
|
||||
for router in reversed(list(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
if outer:
|
||||
middlewares.extend(observer.outer_middlewares)
|
||||
else:
|
||||
if outer:
|
||||
middlewares.extend(self.outer_middlewares)
|
||||
else:
|
||||
for router in reversed(list(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middlewares)
|
||||
|
||||
return middlewares
|
||||
|
||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||
|
|
@ -148,15 +148,17 @@ class TelegramEventObserver:
|
|||
middleware = functools.partial(m, middleware)
|
||||
return middleware
|
||||
|
||||
def wrap_outer_middleware(
|
||||
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback)
|
||||
return wrapped_outer(event, data)
|
||||
|
||||
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Propagate event to handlers and stops propagation on first match.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), self._trigger)
|
||||
return await wrapped_outer(event, kwargs)
|
||||
|
||||
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
# Check globally defined filters before any other handler will be checked
|
||||
result, data = await self._handler.check(event, **kwargs)
|
||||
if not result:
|
||||
|
|
|
|||
|
|
@ -21,16 +21,17 @@ class Router:
|
|||
|
||||
- By observer method - :obj:`router.<event_type>.register(handler, <filters, ...>)`
|
||||
- By decorator - :obj:`@router.<event_type>(<filters, ...>)`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, use_builtin_filters: bool = True) -> None:
|
||||
def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None:
|
||||
"""
|
||||
|
||||
:param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory
|
||||
:param name: Optional router name, can be useful for debugging
|
||||
"""
|
||||
|
||||
self.use_builtin_filters = use_builtin_filters
|
||||
self.name = name or hex(id(self))
|
||||
|
||||
self._parent_router: Optional[Router] = None
|
||||
self.sub_routers: List[Router] = []
|
||||
|
|
@ -84,9 +85,30 @@ class Router:
|
|||
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
|
||||
observer.bind_filter(builtin_filter)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{type(self).__name__} {self.name!r}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self}>"
|
||||
|
||||
async def propagate_event(self, update_type: str, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
kwargs.update(event_router=self)
|
||||
observer = self.observers[update_type]
|
||||
|
||||
async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any:
|
||||
return await self._propagate_event(
|
||||
observer=observer, update_type=update_type, event=telegram_event, **data
|
||||
)
|
||||
|
||||
return await observer.wrap_outer_middleware(_wrapped, event=event, data=kwargs)
|
||||
|
||||
async def _propagate_event(
|
||||
self,
|
||||
observer: TelegramEventObserver,
|
||||
update_type: str,
|
||||
event: TelegramObject,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
response = await observer.trigger(event, **kwargs)
|
||||
if response is REJECTED:
|
||||
return UNHANDLED
|
||||
|
|
|
|||
47
poetry.lock
generated
47
poetry.lock
generated
|
|
@ -256,6 +256,18 @@ category = "dev"
|
|||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
|
||||
[[package]]
|
||||
name = "diagrams"
|
||||
version = "0.20.0"
|
||||
description = "Diagram as Code"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6,<4.0"
|
||||
|
||||
[package.dependencies]
|
||||
graphviz = ">=0.13.2,<0.17.0"
|
||||
jinja2 = ">=2.10,<3.0"
|
||||
|
||||
[[package]]
|
||||
name = "distlib"
|
||||
version = "0.3.2"
|
||||
|
|
@ -323,6 +335,19 @@ sphinx = ">=3.0,<5.0"
|
|||
doc = ["myst-parser", "sphinx-copybutton", "sphinx-inline-tabs", "docutils (!=0.17)"]
|
||||
test = ["pytest", "pytest-cov", "pytest-xdist"]
|
||||
|
||||
[[package]]
|
||||
name = "graphviz"
|
||||
version = "0.16"
|
||||
description = "Simple Python interface for Graphviz"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
|
||||
|
||||
[package.extras]
|
||||
dev = ["tox (>=3)", "flake8", "pep8-naming", "wheel", "twine"]
|
||||
docs = ["sphinx (>=1.8)", "sphinx-rtd-theme"]
|
||||
test = ["mock (>=3)", "pytest (>=4)", "pytest-mock (>=2)", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "identify"
|
||||
version = "2.2.10"
|
||||
|
|
@ -454,17 +479,17 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<6.0.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "jinja2"
|
||||
version = "3.0.1"
|
||||
version = "2.11.3"
|
||||
description = "A very fast and expressive template engine."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
|
||||
[package.dependencies]
|
||||
MarkupSafe = ">=2.0"
|
||||
MarkupSafe = ">=0.23"
|
||||
|
||||
[package.extras]
|
||||
i18n = ["Babel (>=2.7)"]
|
||||
i18n = ["Babel (>=0.8)"]
|
||||
|
||||
[[package]]
|
||||
name = "livereload"
|
||||
|
|
@ -1241,7 +1266,7 @@ redis = ["aioredis"]
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "f6ac17a44b1eec95b101daab369097785a093d9263d0c6cf6c9ef8d363d8962d"
|
||||
content-hash = "e8bc158e14347b3766672505f38ad9d76b1cbf6f9557565e4e56664b0f663717"
|
||||
|
||||
[metadata.files]
|
||||
aiofiles = [
|
||||
|
|
@ -1424,6 +1449,10 @@ decorator = [
|
|||
{file = "decorator-5.0.9-py3-none-any.whl", hash = "sha256:6e5c199c16f7a9f0e3a61a4a54b3d27e7dad0dbdde92b944426cb20914376323"},
|
||||
{file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"},
|
||||
]
|
||||
diagrams = [
|
||||
{file = "diagrams-0.20.0-py3-none-any.whl", hash = "sha256:395391663b4d3f2d3e3614797402ca99494e00baf3926f5c9e72856d34cafedd"},
|
||||
{file = "diagrams-0.20.0.tar.gz", hash = "sha256:a50743ed9274e194e7898820f69aa12868ae217003580ef9e7d0285132c9674a"},
|
||||
]
|
||||
distlib = [
|
||||
{file = "distlib-0.3.2-py2.py3-none-any.whl", hash = "sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c"},
|
||||
{file = "distlib-0.3.2.zip", hash = "sha256:106fef6dc37dd8c0e2c0a60d3fca3e77460a48907f335fa28420463a6f799736"},
|
||||
|
|
@ -1448,6 +1477,10 @@ furo = [
|
|||
{file = "furo-2021.6.18b36-py3-none-any.whl", hash = "sha256:a4c00634afeb5896a34d141a5dffb62f20c5eca7831b78269823a8cd8b09a5e4"},
|
||||
{file = "furo-2021.6.18b36.tar.gz", hash = "sha256:46a30bc597a9067088d39d730e7d9bf6c1a1d71967e4af062f796769f66b3bdb"},
|
||||
]
|
||||
graphviz = [
|
||||
{file = "graphviz-0.16-py2.py3-none-any.whl", hash = "sha256:3cad5517c961090dfc679df6402a57de62d97703e2880a1a46147bb0dc1639eb"},
|
||||
{file = "graphviz-0.16.zip", hash = "sha256:d2d25af1c199cad567ce4806f0449cb74eb30cf451fd7597251e1da099ac6e57"},
|
||||
]
|
||||
identify = [
|
||||
{file = "identify-2.2.10-py2.py3-none-any.whl", hash = "sha256:18d0c531ee3dbc112fa6181f34faa179de3f57ea57ae2899754f16a7e0ff6421"},
|
||||
{file = "identify-2.2.10.tar.gz", hash = "sha256:5b41f71471bc738e7b586308c3fca172f78940195cb3bf6734c1e66fdac49306"},
|
||||
|
|
@ -1489,8 +1522,8 @@ jedi = [
|
|||
{file = "jedi-0.18.0.tar.gz", hash = "sha256:92550a404bad8afed881a137ec9a461fed49eca661414be45059329614ed0707"},
|
||||
]
|
||||
jinja2 = [
|
||||
{file = "Jinja2-3.0.1-py3-none-any.whl", hash = "sha256:1f06f2da51e7b56b8f238affdd6b4e2c61e39598a378cc49345bc1bd42a978a4"},
|
||||
{file = "Jinja2-3.0.1.tar.gz", hash = "sha256:703f484b47a6af502e743c9122595cc812b0271f661722403114f71a79d0f5a4"},
|
||||
{file = "Jinja2-2.11.3-py2.py3-none-any.whl", hash = "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419"},
|
||||
{file = "Jinja2-2.11.3.tar.gz", hash = "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6"},
|
||||
]
|
||||
livereload = [
|
||||
{file = "livereload-2.6.3.tar.gz", hash = "sha256:776f2f865e59fde56490a56bcc6773b6917366bce0c267c60ee8aaf1a0959869"},
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ furo = "^2021.6.18-beta.36"
|
|||
sphinx-prompt = "^1.3.0"
|
||||
Sphinx-Substitution-Extensions = "^2020.9.30"
|
||||
towncrier = "^21.3.0"
|
||||
diagrams = "^0.20.0"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|
|
|||
|
|
@ -422,7 +422,12 @@ class TestDispatcher:
|
|||
],
|
||||
)
|
||||
async def test_listen_update(
|
||||
self, event_type: str, update: Update, has_chat: bool, has_user: bool
|
||||
self,
|
||||
event_type: str,
|
||||
update: Update,
|
||||
has_chat: bool,
|
||||
has_user: bool,
|
||||
bot: MockedBot,
|
||||
):
|
||||
router = Dispatcher()
|
||||
observer = router.observers[event_type]
|
||||
|
|
@ -436,7 +441,7 @@ class TestDispatcher:
|
|||
assert User.get_current(False)
|
||||
return kwargs
|
||||
|
||||
result = await router.update.trigger(update, test="PASS", bot=None)
|
||||
result = await router.feed_update(bot, update, test="PASS")
|
||||
assert isinstance(result, dict)
|
||||
assert result["event_update"] == update
|
||||
assert result["event_router"] == router
|
||||
|
|
@ -477,7 +482,7 @@ class TestDispatcher:
|
|||
)
|
||||
assert response is UNHANDLED
|
||||
|
||||
async def test_nested_router_listen_update(self):
|
||||
async def test_nested_router_listen_update(self, bot: MockedBot):
|
||||
dp = Dispatcher()
|
||||
router0 = Router()
|
||||
router1 = Router()
|
||||
|
|
@ -499,7 +504,7 @@ class TestDispatcher:
|
|||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
)
|
||||
result = await dp._listen_update(update, test="PASS")
|
||||
result = await dp.feed_update(bot, update, test="PASS")
|
||||
assert isinstance(result, dict)
|
||||
assert result["event_update"] == update
|
||||
assert result["event_router"] == router1
|
||||
|
|
@ -542,7 +547,7 @@ class TestDispatcher:
|
|||
baz=...,
|
||||
)
|
||||
|
||||
assert counter["root.outer_middleware"] == 2
|
||||
assert counter["root.outer_middleware"] == 1
|
||||
assert counter["root.middleware"] == 1
|
||||
assert counter["child.outer_middleware"] == 1
|
||||
assert counter["child.middleware"] == 1
|
||||
|
|
@ -596,7 +601,7 @@ class TestDispatcher:
|
|||
else:
|
||||
mocked_process_update.assert_awaited()
|
||||
|
||||
async def test_exception_handler_catch_exceptions(self):
|
||||
async def test_exception_handler_catch_exceptions(self, bot: MockedBot):
|
||||
dp = Dispatcher()
|
||||
router = Router()
|
||||
dp.include_router(router)
|
||||
|
|
@ -619,20 +624,20 @@ class TestDispatcher:
|
|||
),
|
||||
)
|
||||
with pytest.raises(CustomException, match="KABOOM"):
|
||||
await dp.update.trigger(update, bot=None)
|
||||
await dp.feed_update(bot, update)
|
||||
|
||||
@router.errors()
|
||||
async def error_handler(event: Update, exception: Exception):
|
||||
return "KABOOM"
|
||||
|
||||
response = await dp.update.trigger(update, bot=None)
|
||||
response = await dp.feed_update(bot, update)
|
||||
assert response == "KABOOM"
|
||||
|
||||
@dp.errors()
|
||||
async def root_error_handler(event: Update, exception: Exception):
|
||||
return exception
|
||||
|
||||
response = await dp.update.trigger(update, bot=None)
|
||||
response = await dp.feed_update(bot, update)
|
||||
|
||||
assert isinstance(response, CustomException)
|
||||
assert str(response) == "KABOOM"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue