From b82a1a6fb0663ad5baa54c50128afb1c62d12ceb Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Tue, 3 Dec 2019 00:03:15 +0200 Subject: [PATCH] Add prototype of class-based handlers --- aiogram/dispatcher/event/handler.py | 16 ++++++- aiogram/dispatcher/event/observer.py | 6 +-- aiogram/dispatcher/handler/__init__.py | 0 aiogram/dispatcher/handler/base.py | 37 +++++++++++++++ aiogram/dispatcher/handler/message.py | 16 +++++++ aiogram/utils/mixins.py | 8 +++- tests/conftest.py | 3 +- .../test_event/test_handler.py | 25 ++++++++++ .../test_dispatcher/test_handler/__init__.py | 0 .../test_dispatcher/test_handler/test_base.py | 47 +++++++++++++++++++ .../test_handler/test_message.py | 28 +++++++++++ 11 files changed, 178 insertions(+), 8 deletions(-) create mode 100644 aiogram/dispatcher/handler/__init__.py create mode 100644 aiogram/dispatcher/handler/base.py create mode 100644 aiogram/dispatcher/handler/message.py create mode 100644 tests/test_dispatcher/test_handler/__init__.py create mode 100644 tests/test_dispatcher/test_handler/test_base.py create mode 100644 tests/test_dispatcher/test_handler/test_message.py diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 08052fd3..cf2248ce 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -4,16 +4,18 @@ from functools import partial from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union from aiogram.dispatcher.filters.base import BaseFilter +from aiogram.dispatcher.handler.base import BaseHandler CallbackType = Callable[[Any], Awaitable[Any]] SyncFilter = Callable[[Any], Any] AsyncFilter = Callable[[Any], Awaitable[Any]] FilterType = Union[SyncFilter, AsyncFilter, BaseFilter] +HandlerType = Union[CallbackType, BaseHandler] @dataclass class CallableMixin: - callback: Callable + callback: HandlerType awaitable: bool = field(init=False) spec: inspect.FullArgSpec = field(init=False) @@ -44,9 +46,19 @@ class FilterObject(CallableMixin): @dataclass class HandlerObject(CallableMixin): - callback: CallbackType + callback: HandlerType filters: List[FilterObject] + def __post_init__(self): + super(HandlerObject, self).__post_init__() + + if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): + self.awaitable = True + if hasattr(self.callback, "filters"): + self.filters.extend( + FilterObject(event_filter) for event_filter in self.callback.filters + ) + async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: for event_filter in self.filters: check = await event_filter.call(*args, **kwargs) diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index 65d35eb1..299c8938 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Type from pydantic import ValidationError from ..filters.base import BaseFilter -from .handler import CallbackType, FilterObject, FilterType, HandlerObject +from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType if TYPE_CHECKING: # pragma: no cover from aiogram.dispatcher.router import Router @@ -24,7 +24,7 @@ class EventObserver: def __init__(self): self.handlers: List[HandlerObject] = [] - def register(self, callback: CallbackType, *filters: FilterType): + def register(self, callback: HandlerType, *filters: FilterType): """ Register callback with filters @@ -91,7 +91,7 @@ class TelegramEventObserver(EventObserver): yield filter_ registry.append(filter_) - def register(self, callback: CallbackType, *filters: FilterType, **bound_filters: Any): + def register(self, callback: HandlerType, *filters: FilterType, **bound_filters: Any): resolved_filters = self.resolve_filters(bound_filters) return super().register(callback, *filters, *resolved_filters) diff --git a/aiogram/dispatcher/handler/__init__.py b/aiogram/dispatcher/handler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/dispatcher/handler/base.py b/aiogram/dispatcher/handler/base.py new file mode 100644 index 00000000..bfc16575 --- /dev/null +++ b/aiogram/dispatcher/handler/base.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union + +from aiogram import Bot +from aiogram.api.types import TelegramObject + +if TYPE_CHECKING: # pragma: no cover + from aiogram.dispatcher.event.handler import FilterType # NOQA: F401 + + +class BaseHandlerMixin: + event: TelegramObject + data: Dict[str, Any] + + +class HandlerBotMixin(BaseHandlerMixin): + @property + def bot(self) -> Bot: + if "bot" in self.data: + return self.data["bot"] + return Bot.get_current() + + +class BaseHandler(HandlerBotMixin, ABC): + event: TelegramObject + filters: Union[List["FilterType"], Tuple["FilterType"]] + + def __init__(self, event: TelegramObject, **kwargs: Any) -> None: + self.event = event + self.data = kwargs + + @abstractmethod + async def handle(self) -> Any: # pragma: no cover + pass + + def __await__(self): + return self.handle().__await__() diff --git a/aiogram/dispatcher/handler/message.py b/aiogram/dispatcher/handler/message.py new file mode 100644 index 00000000..25b9df6e --- /dev/null +++ b/aiogram/dispatcher/handler/message.py @@ -0,0 +1,16 @@ +from abc import ABC + +from aiogram.api.types import Message +from aiogram.dispatcher.handler.base import BaseHandler + + +class MessageHandler(BaseHandler, ABC): + event: Message + + @property + def from_user(self): + return self.event.from_user + + @property + def chat(self): + return self.event.chat diff --git a/aiogram/utils/mixins.py b/aiogram/utils/mixins.py index eaaeb1a3..ca17b8d8 100644 --- a/aiogram/utils/mixins.py +++ b/aiogram/utils/mixins.py @@ -45,9 +45,13 @@ class ContextInstanceMixin: return cls.__context_instance.get() @classmethod - def set_current(cls: Type[T], value: T): + def set_current(cls: Type[T], value: T) -> contextvars.Token: if not isinstance(value, cls): raise TypeError( f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}" ) - cls.__context_instance.set(value) + return cls.__context_instance.set(value) + + @classmethod + def reset_current(cls: Type[T], token: contextvars.Token): + cls.__context_instance.reset(token) diff --git a/tests/conftest.py b/tests/conftest.py index cb02148f..beeb19a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,5 +7,6 @@ from tests.mocked_bot import MockedBot @pytest.fixture() def bot(): bot = MockedBot() - Bot.set_current(bot) + token = Bot.set_current(bot) yield bot + Bot.reset_current(token) diff --git a/tests/test_dispatcher/test_event/test_handler.py b/tests/test_dispatcher/test_event/test_handler.py index 2607ab20..bde90ed9 100644 --- a/tests/test_dispatcher/test_event/test_handler.py +++ b/tests/test_dispatcher/test_event/test_handler.py @@ -3,8 +3,11 @@ from typing import Any, Dict, Union import pytest +from aiogram.api.types import Update from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject +from aiogram.dispatcher.filters import Text from aiogram.dispatcher.filters.base import BaseFilter +from aiogram.dispatcher.handler.base import BaseHandler def callback1(foo: int, bar: int, baz: int): @@ -174,3 +177,25 @@ class TestHandlerObject: ) result, data = await handler.check(42, foo=True) assert not result + + @pytest.mark.asyncio + async def test_class_based_handler(self): + class MyFilter(BaseFilter): + async def __call__(self, event): + return True + + class MyHandler(BaseHandler): + event: Update + filters = [MyFilter()] + + async def handle(self) -> Any: + return self.event.update_id + + handler = HandlerObject(MyHandler, filters=[FilterObject(lambda event: True)]) + + assert handler.awaitable + assert handler.callback == MyHandler + assert len(handler.filters) == 2 + assert handler.filters[1].callback == MyFilter() + result = await handler.call(Update(update_id=42)) + assert result == 42 diff --git a/tests/test_dispatcher/test_handler/__init__.py b/tests/test_dispatcher/test_handler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_dispatcher/test_handler/test_base.py b/tests/test_dispatcher/test_handler/test_base.py new file mode 100644 index 00000000..b385d4e4 --- /dev/null +++ b/tests/test_dispatcher/test_handler/test_base.py @@ -0,0 +1,47 @@ +import asyncio +from typing import Any + +import pytest + +from aiogram import Bot +from aiogram.api.types import Update +from aiogram.dispatcher.handler.base import BaseHandler + + +class MyHandler(BaseHandler): + async def handle(self) -> Any: + await asyncio.sleep(0.1) + return 42 + + +class TestBaseClassBasedHandler: + @pytest.mark.asyncio + async def test_base_handler(self): + event = Update(update_id=42) + handler = MyHandler(event=event, key=42) + + assert handler.event == event + assert handler.data["key"] == 42 + assert hasattr(handler, "bot") + assert not hasattr(handler, "filters") + assert await handler == 42 + + @pytest.mark.asyncio + async def test_bot_mixin_from_context(self): + event = Update(update_id=42) + handler = MyHandler(event=event, key=42) + bot = Bot("42:TEST") + + assert handler.bot is None + + Bot.set_current(bot) + assert handler.bot == bot + + @pytest.mark.asyncio + async def test_bot_mixin_from_data(self): + event = Update(update_id=42) + bot = Bot("42:TEST") + handler = MyHandler(event=event, key=42, bot=bot) + + assert "bot" in handler.data + assert handler.bot == bot diff --git a/tests/test_dispatcher/test_handler/test_message.py b/tests/test_dispatcher/test_handler/test_message.py new file mode 100644 index 00000000..e5fd3194 --- /dev/null +++ b/tests/test_dispatcher/test_handler/test_message.py @@ -0,0 +1,28 @@ +import datetime +from typing import Any + +import pytest + +from aiogram.api.types import Chat, Message, User +from aiogram.dispatcher.handler.message import MessageHandler + + +class MyHandler(MessageHandler): + async def handle(self) -> Any: + return self.event.text + + +class TestClassBasedMessageHandler: + @pytest.mark.asyncio + async def test_message_handler(self): + event = Message( + message_id=42, + date=datetime.datetime.now(), + text="test", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ) + handler = MyHandler(event=event) + + assert handler.from_user == event.from_user + assert handler.chat == event.chat