Refactor EventObserver & TelegramEventObserver

This commit is contained in:
Alex Root Junior 2020-01-13 21:17:28 +02:00
parent 3b2df194a9
commit 9907eada32
3 changed files with 86 additions and 84 deletions

View file

@ -1,7 +1,7 @@
import inspect
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
@ -10,7 +10,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
SyncFilter = Callable[[Any], Any]
AsyncFilter = Callable[[Any], Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[CallbackType, BaseHandler]
HandlerType = Union[FilterType, BaseHandler]
@dataclass
@ -47,7 +47,7 @@ class FilterObject(CallableMixin):
@dataclass
class HandlerObject(CallableMixin):
callback: HandlerType
filters: List[FilterObject]
filters: Optional[List[FilterObject]] = None
def __post_init__(self):
super(HandlerObject, self).__post_init__()
@ -56,6 +56,8 @@ class HandlerObject(CallableMixin):
self.awaitable = True
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
if not self.filters:
return True, {}
for event_filter in self.filters:
check = await event_filter.call(*args, **kwargs)
if not check:

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import copy
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
@ -34,15 +35,11 @@ class EventObserver:
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
def register(self, callback: HandlerType) -> HandlerType:
"""
Register callback with filters
"""
self.handlers.append(
HandlerObject(
callback=callback, filters=[FilterObject(filter_) for filter_ in filters]
)
)
self.handlers.append(HandlerObject(callback=callback))
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
@ -51,22 +48,18 @@ class EventObserver:
Handler will be called when all its filters is pass.
"""
for handler in self.handlers:
kwargs_copy = copy.copy(kwargs)
result, data = await handler.check(*args, **kwargs)
if result:
kwargs_copy.update(data)
try:
yield await handler.call(*args, **kwargs_copy)
except SkipHandler:
continue
try:
yield await handler.call(*args, **kwargs)
except SkipHandler:
continue
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
def __call__(self) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback, *args)
self.register(callback)
return callback
return wrapper
@ -148,16 +141,29 @@ class TelegramEventObserver(EventObserver):
Register event handler
"""
resolved_filters = self.resolve_filters(bound_filters)
return super().register(callback, *filters, *resolved_filters)
self.handlers.append(
HandlerObject(
callback=callback,
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
)
)
return callback
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
"""
Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass.
"""
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
yield result
break
for handler in self.handlers:
kwargs_copy = copy.copy(kwargs)
result, data = await handler.check(*args, **kwargs)
if result:
kwargs_copy.update(data)
try:
yield await handler.call(*args, **kwargs_copy)
except SkipHandler:
continue
break
def __call__(
self, *args: FilterType, **bound_filters: BaseFilter

View file

@ -39,68 +39,38 @@ class MyFilter3(MyFilter1):
class TestEventObserver:
@pytest.mark.parametrize(
"count,handler,filters",
(
pytest.param(5, my_handler, []),
pytest.param(3, my_handler, [lambda event: True]),
pytest.param(
2,
my_handler,
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
),
),
)
def test_register_filters(self, count, handler, filters):
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer.register(wrapped_handler, *filters)
observer.register(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
assert not registered_handler.filters
@pytest.mark.parametrize(
"count,handler,filters",
(
pytest.param(5, my_handler, []),
pytest.param(3, my_handler, [lambda event: True]),
pytest.param(
2,
my_handler,
[lambda event: True, lambda event: False, lambda event: {"ok": True}],
),
),
)
def test_register_filters_via_decorator(self, count, handler, filters):
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters_via_decorator(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer(*filters)(wrapped_handler)
observer()(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
@pytest.mark.asyncio
async def test_trigger_rejected(self):
observer = EventObserver()
observer.register(my_handler, lambda event: False)
results = [result async for result in observer.trigger(42)]
assert results == []
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger_accepted_bool(self):
observer = EventObserver()
observer.register(my_handler, lambda event: True)
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42]
@ -108,23 +78,12 @@ class TestEventObserver:
@pytest.mark.asyncio
async def test_trigger_with_skip(self):
observer = EventObserver()
observer.register(skip_my_handler, lambda event: True)
observer.register(my_handler, lambda event: False)
observer.register(my_handler, lambda event: True)
observer.register(skip_my_handler)
observer.register(my_handler)
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42]
@pytest.mark.asyncio
async def test_trigger_right_context_in_handlers(self):
observer = EventObserver()
observer.register(
pipe_handler, lambda event: {"a": 1}, lambda event: False
) # {"a": 1} should not be in result
observer.register(pipe_handler, lambda event: {"b": 2})
results = [result async for result in observer.trigger(42)]
assert results == [((42,), {"b": 2})]
assert results == [42, 42]
class TestTelegramEventObserver:
@ -144,9 +103,9 @@ class TestTelegramEventObserver:
assert MyFilter in event_observer.filters
def test_resolve_filters_chain(self):
router1 = Router()
router2 = Router()
router3 = Router()
router1 = Router(use_builtin_filters=False)
router2 = Router(use_builtin_filters=False)
router3 = Router(use_builtin_filters=False)
router1.include_router(router2)
router2.include_router(router3)
@ -168,7 +127,7 @@ class TestTelegramEventObserver:
assert MyFilter3 in filters_chain3
def test_resolve_filters(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.bind_filter(MyFilter1)
@ -189,7 +148,7 @@ class TestTelegramEventObserver:
assert observer.resolve_filters({"test": ...})
def test_register(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.bind_filter(MyFilter1)
@ -214,7 +173,7 @@ class TestTelegramEventObserver:
assert MyFilter1(test="PASS") in callbacks
def test_register_decorator(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
@observer()
@ -226,7 +185,7 @@ class TestTelegramEventObserver:
@pytest.mark.asyncio
async def test_trigger(self):
router = Router()
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.bind_filter(MyFilter1)
observer.register(my_handler, test="ok")
@ -241,3 +200,38 @@ class TestTelegramEventObserver:
results = [result async for result in observer.trigger(message)]
assert results == [message]
@pytest.mark.parametrize(
"count,handler,filters",
(
[5, my_handler, []],
[3, my_handler, [lambda event: True]],
[2, my_handler, [lambda event: True, lambda event: False, lambda event: {"ok": True}]],
),
)
def test_register_filters_via_decorator(self, count, handler, filters):
router = Router(use_builtin_filters=False)
observer = router.message_handler
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer(*filters)(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
#
@pytest.mark.asyncio
async def test_trigger_right_context_in_handlers(self):
router = Router(use_builtin_filters=False)
observer = router.message_handler
observer.register(
pipe_handler, lambda event: {"a": 1}, lambda event: False
) # {"a": 1} should not be in result
observer.register(pipe_handler, lambda event: {"b": 2})
results = [result async for result in observer.trigger(42)]
assert results == [((42,), {"b": 2})]