mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Refactor EventObserver & TelegramEventObserver
This commit is contained in:
parent
3b2df194a9
commit
9907eada32
3 changed files with 86 additions and 84 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from aiogram.dispatcher.filters.base import BaseFilter
|
||||
from aiogram.dispatcher.handler.base import BaseHandler
|
||||
|
|
@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
|
|||
SyncFilter = Callable[[Any], Any]
|
||||
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
||||
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
||||
HandlerType = Union[CallbackType, BaseHandler]
|
||||
HandlerType = Union[FilterType, BaseHandler]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -47,7 +47,7 @@ class FilterObject(CallableMixin):
|
|||
@dataclass
|
||||
class HandlerObject(CallableMixin):
|
||||
callback: HandlerType
|
||||
filters: List[FilterObject]
|
||||
filters: Optional[List[FilterObject]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
super(HandlerObject, self).__post_init__()
|
||||
|
|
@ -56,6 +56,8 @@ class HandlerObject(CallableMixin):
|
|||
self.awaitable = True
|
||||
|
||||
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
||||
if not self.filters:
|
||||
return True, {}
|
||||
for event_filter in self.filters:
|
||||
check = await event_filter.call(*args, **kwargs)
|
||||
if not check:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
|
|
@ -34,15 +35,11 @@ class EventObserver:
|
|||
def __init__(self) -> None:
|
||||
self.handlers: List[HandlerObject] = []
|
||||
|
||||
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
|
||||
def register(self, callback: HandlerType) -> HandlerType:
|
||||
"""
|
||||
Register callback with filters
|
||||
"""
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback, filters=[FilterObject(filter_) for filter_ in filters]
|
||||
)
|
||||
)
|
||||
self.handlers.append(HandlerObject(callback=callback))
|
||||
return callback
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
|
|
@ -51,22 +48,18 @@ class EventObserver:
|
|||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
kwargs_copy = copy.copy(kwargs)
|
||||
result, data = await handler.check(*args, **kwargs)
|
||||
if result:
|
||||
kwargs_copy.update(data)
|
||||
try:
|
||||
yield await handler.call(*args, **kwargs_copy)
|
||||
except SkipHandler:
|
||||
continue
|
||||
try:
|
||||
yield await handler.call(*args, **kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
|
||||
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
|
||||
def __call__(self) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
"""
|
||||
|
||||
def wrapper(callback: CallbackType) -> CallbackType:
|
||||
self.register(callback, *args)
|
||||
self.register(callback)
|
||||
return callback
|
||||
|
||||
return wrapper
|
||||
|
|
@ -148,16 +141,29 @@ class TelegramEventObserver(EventObserver):
|
|||
Register event handler
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
return super().register(callback, *filters, *resolved_filters)
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback,
|
||||
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
|
||||
)
|
||||
)
|
||||
return callback
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Propagate event to handlers and stops propagation on first match.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
||||
yield result
|
||||
break
|
||||
for handler in self.handlers:
|
||||
kwargs_copy = copy.copy(kwargs)
|
||||
result, data = await handler.check(*args, **kwargs)
|
||||
if result:
|
||||
kwargs_copy.update(data)
|
||||
try:
|
||||
yield await handler.call(*args, **kwargs_copy)
|
||||
except SkipHandler:
|
||||
continue
|
||||
break
|
||||
|
||||
def __call__(
|
||||
self, *args: FilterType, **bound_filters: BaseFilter
|
||||
|
|
|
|||
|
|
@ -39,68 +39,38 @@ class MyFilter3(MyFilter1):
|
|||
|
||||
|
||||
class TestEventObserver:
|
||||
@pytest.mark.parametrize(
|
||||
"count,handler,filters",
|
||||
(
|
||||
pytest.param(5, my_handler, []),
|
||||
pytest.param(3, my_handler, [lambda event: True]),
|
||||
pytest.param(
|
||||
2,
|
||||
my_handler,
|
||||
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_register_filters(self, count, handler, filters):
|
||||
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||
def test_register_filters(self, count, handler):
|
||||
observer = EventObserver()
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer.register(wrapped_handler, *filters)
|
||||
observer.register(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert len(registered_handler.filters) == len(filters)
|
||||
assert not registered_handler.filters
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"count,handler,filters",
|
||||
(
|
||||
pytest.param(5, my_handler, []),
|
||||
pytest.param(3, my_handler, [lambda event: True]),
|
||||
pytest.param(
|
||||
2,
|
||||
my_handler,
|
||||
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_register_filters_via_decorator(self, count, handler, filters):
|
||||
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||
def test_register_filters_via_decorator(self, count, handler):
|
||||
observer = EventObserver()
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer(*filters)(wrapped_handler)
|
||||
observer()(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert len(registered_handler.filters) == len(filters)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_rejected(self):
|
||||
observer = EventObserver()
|
||||
observer.register(my_handler, lambda event: False)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == []
|
||||
assert not registered_handler.filters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_accepted_bool(self):
|
||||
observer = EventObserver()
|
||||
observer.register(my_handler, lambda event: True)
|
||||
observer.register(my_handler)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [42]
|
||||
|
|
@ -108,23 +78,12 @@ class TestEventObserver:
|
|||
@pytest.mark.asyncio
|
||||
async def test_trigger_with_skip(self):
|
||||
observer = EventObserver()
|
||||
observer.register(skip_my_handler, lambda event: True)
|
||||
observer.register(my_handler, lambda event: False)
|
||||
observer.register(my_handler, lambda event: True)
|
||||
observer.register(skip_my_handler)
|
||||
observer.register(my_handler)
|
||||
observer.register(my_handler)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [42]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_right_context_in_handlers(self):
|
||||
observer = EventObserver()
|
||||
observer.register(
|
||||
pipe_handler, lambda event: {"a": 1}, lambda event: False
|
||||
) # {"a": 1} should not be in result
|
||||
observer.register(pipe_handler, lambda event: {"b": 2})
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [((42,), {"b": 2})]
|
||||
assert results == [42, 42]
|
||||
|
||||
|
||||
class TestTelegramEventObserver:
|
||||
|
|
@ -144,9 +103,9 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter in event_observer.filters
|
||||
|
||||
def test_resolve_filters_chain(self):
|
||||
router1 = Router()
|
||||
router2 = Router()
|
||||
router3 = Router()
|
||||
router1 = Router(use_builtin_filters=False)
|
||||
router2 = Router(use_builtin_filters=False)
|
||||
router3 = Router(use_builtin_filters=False)
|
||||
router1.include_router(router2)
|
||||
router2.include_router(router3)
|
||||
|
||||
|
|
@ -168,7 +127,7 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter3 in filters_chain3
|
||||
|
||||
def test_resolve_filters(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.bind_filter(MyFilter1)
|
||||
|
||||
|
|
@ -189,7 +148,7 @@ class TestTelegramEventObserver:
|
|||
assert observer.resolve_filters({"test": ...})
|
||||
|
||||
def test_register(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.bind_filter(MyFilter1)
|
||||
|
||||
|
|
@ -214,7 +173,7 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter1(test="PASS") in callbacks
|
||||
|
||||
def test_register_decorator(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
|
||||
@observer()
|
||||
|
|
@ -226,7 +185,7 @@ class TestTelegramEventObserver:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger(self):
|
||||
router = Router()
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.register(my_handler, test="ok")
|
||||
|
|
@ -241,3 +200,38 @@ class TestTelegramEventObserver:
|
|||
|
||||
results = [result async for result in observer.trigger(message)]
|
||||
assert results == [message]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"count,handler,filters",
|
||||
(
|
||||
[5, my_handler, []],
|
||||
[3, my_handler, [lambda event: True]],
|
||||
[2, my_handler, [lambda event: True, lambda event: False, lambda event: {"ok": True}]],
|
||||
),
|
||||
)
|
||||
def test_register_filters_via_decorator(self, count, handler, filters):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer(*filters)(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert len(registered_handler.filters) == len(filters)
|
||||
|
||||
#
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_right_context_in_handlers(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message_handler
|
||||
observer.register(
|
||||
pipe_handler, lambda event: {"a": 1}, lambda event: False
|
||||
) # {"a": 1} should not be in result
|
||||
observer.register(pipe_handler, lambda event: {"b": 2})
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [((42,), {"b": 2})]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue