Partially add tests for telegram event observer

This commit is contained in:
Alex Root Junior 2019-11-24 22:31:53 +02:00
parent 55f6c29ba6
commit 8204b6af52
3 changed files with 72 additions and 9 deletions

View file

@ -70,18 +70,18 @@ class TelegramEventObserver(EventObserver):
self.filters: List[Type[BaseFilter]] = []
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
if not isinstance(bound_filter, BaseFilter):
pass
if not issubclass(bound_filter, BaseFilter):
raise TypeError(
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
)
self.filters.append(bound_filter)
def _resolve_filters_chain(self):
registry: List[FilterType] = []
routers: List[Router] = []
router = self.router
while router and router not in routers:
while router:
observer = router.observers[self.event_name]
routers.append(router)
router = router.parent_router
for filter_ in observer.filters:

View file

@ -16,7 +16,7 @@ class BaseFilter(ABC, BaseModel):
@abstractmethod
async def __call__(
self, *args: Any, **kwargs: Any
) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]:
) -> Union[bool, Dict[str, Any]]:
pass
def __await__(self):

View file

@ -1,10 +1,11 @@
import functools
from typing import Any, NoReturn
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
import pytest
from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.observer import EventObserver, SkipHandler
from aiogram.dispatcher.event.observer import EventObserver, SkipHandler, TelegramEventObserver
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.router import Router
async def my_handler(event: Any, index: int = 0) -> Any:
@ -106,3 +107,65 @@ class TestEventObserver:
results = [result async for result in observer.trigger(42)]
assert results == [((42,), {"b": 2})]
class TestTelegramEventObserver:
def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test")
with pytest.raises(TypeError):
event_observer.bind_filter(object) # type: ignore
class MyFilter(BaseFilter):
async def __call__(
self, *args: Any, **kwargs: Any
) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]:
pass
event_observer.bind_filter(MyFilter)
assert event_observer.filters
assert event_observer.filters[0] == MyFilter
def test_resolve_filters_chain(self):
router1 = Router()
router2 = Router()
router3 = Router()
router1.include_router(router2)
router2.include_router(router3)
class MyFilter1(BaseFilter):
test: str
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class MyFilter2(MyFilter1):
pass
class MyFilter3(MyFilter1):
pass
router1.message_handler.bind_filter(MyFilter1)
router1.message_handler.bind_filter(MyFilter2)
router2.message_handler.bind_filter(MyFilter2)
router3.message_handler.bind_filter(MyFilter3)
filters_chain1 = list(router1.message_handler._resolve_filters_chain())
filters_chain2 = list(router2.message_handler._resolve_filters_chain())
filters_chain3 = list(router3.message_handler._resolve_filters_chain())
assert filters_chain1 == [MyFilter1, MyFilter2]
assert filters_chain2 == [MyFilter2, MyFilter1]
assert filters_chain3 == [MyFilter3, MyFilter2, MyFilter1]
def test_resolve_filters(self):
pass
def test_register(self):
pass
def test_register_decorator(self):
pass
@pytest.mark.asyncio
async def test_trigger(self):
pass