From 23dbd88487318eba5e2ba8687170a6f5c1359f9f Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Mon, 25 Nov 2019 23:21:01 +0200 Subject: [PATCH] Full coverage of observers --- aiogram/dispatcher/event/observer.py | 1 + .../test_event/test_observer.py | 92 +++++++++++++++---- 2 files changed, 77 insertions(+), 16 deletions(-) diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index cf9d00aa..f40a8014 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -37,6 +37,7 @@ class EventObserver: callback=callback, filters=[FilterObject(filter_) for filter_ in filters] ) ) + return callback async def trigger(self, *args, **kwargs): for handler in self.handlers: diff --git a/tests/test_dispatcher/test_event/test_observer.py b/tests/test_dispatcher/test_event/test_observer.py index 008c0ceb..ec563d9d 100644 --- a/tests/test_dispatcher/test_event/test_observer.py +++ b/tests/test_dispatcher/test_event/test_observer.py @@ -21,6 +21,21 @@ async def pipe_handler(*args, **kwargs): return args, kwargs +class MyFilter1(BaseFilter): + test: str + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return True + + +class MyFilter2(MyFilter1): + pass + + +class MyFilter3(MyFilter1): + pass + + class TestEventObserver: @pytest.mark.parametrize( "count,handler,filters", @@ -133,18 +148,6 @@ class TestTelegramEventObserver: router1.include_router(router2) router2.include_router(router3) - class MyFilter1(BaseFilter): - test: str - - async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: - return True - - class MyFilter2(MyFilter1): - pass - - class MyFilter3(MyFilter1): - pass - router1.message_handler.bind_filter(MyFilter1) router1.message_handler.bind_filter(MyFilter2) router2.message_handler.bind_filter(MyFilter2) @@ -159,14 +162,71 @@ class TestTelegramEventObserver: assert filters_chain3 == [MyFilter3, MyFilter2, MyFilter1] def test_resolve_filters(self): - pass + router = Router() + observer = router.message_handler + observer.bind_filter(MyFilter1) + + resolved = observer.resolve_filters({"test": "PASS"}) + assert isinstance(resolved, list) + assert len(resolved) == 1 + assert isinstance(resolved[0], MyFilter1) + assert resolved[0].test == "PASS" + + # Unknown filter + with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"): + assert observer.resolve_filters({"@bad": "very"}) + + # Unknown filter + with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"): + assert observer.resolve_filters({"test": "ok", "@bad": "very"}) + + # Bad argument type + with pytest.raises(ValueError, match="Unknown filters: {'test'}"): + assert observer.resolve_filters({"test": ...}) def test_register(self): - pass + router = Router() + observer = router.message_handler + observer.bind_filter(MyFilter1) + + assert observer.register(my_handler) == my_handler + assert isinstance(observer.handlers[0], HandlerObject) + assert not observer.handlers[0].filters + + f = MyFilter1(test="ok") + observer.register(my_handler, f) + assert isinstance(observer.handlers[1], HandlerObject) + assert len(observer.handlers[1].filters) == 1 + assert observer.handlers[1].filters[0].callback == f + + observer.register(my_handler, test="PASS") + assert isinstance(observer.handlers[2], HandlerObject) + assert len(observer.handlers[2].filters) == 1 + assert observer.handlers[2].filters[0].callback == MyFilter1(test="PASS") + + observer.register(my_handler, f, test="PASS") + assert isinstance(observer.handlers[3], HandlerObject) + assert len(observer.handlers[3].filters) == 2 + assert observer.handlers[3].filters[0].callback == f + assert observer.handlers[3].filters[1].callback == MyFilter1(test="PASS") def test_register_decorator(self): - pass + router = Router() + observer = router.message_handler + + @observer() + async def my_handler(event: Any): + pass + + assert len(observer.handlers) == 1 + assert observer.handlers[0].callback == my_handler @pytest.mark.asyncio async def test_trigger(self): - pass + router = Router() + observer = router.message_handler + observer.bind_filter(MyFilter1) + observer.register(my_handler, test="ok") + + results = [result async for result in observer.trigger(42)] + assert results == [42]