diff --git a/aiogram/api/client/base.py b/aiogram/api/client/base.py deleted file mode 100644 index c4f7aff3..00000000 --- a/aiogram/api/client/base.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations - -from contextlib import asynccontextmanager -from typing import ( - Any, - AsyncIterator, - Optional, - TypeVar, -) - -from ...utils.mixins import ( - ContextInstance, - ContextInstanceMixin, -) -from ...utils.token import extract_bot_id, validate_token -from ..methods import TelegramMethod -from .session.aiohttp import AiohttpSession -from .session.base import BaseSession - -T = TypeVar("T") - - -class BaseBot(ContextInstanceMixin[ContextInstance]): - """ - Base class for bots - """ - - def __init__( - self, token: str, session: Optional[BaseSession] = None, parse_mode: Optional[str] = None - ) -> None: - validate_token(token) - - if session is None: - session = AiohttpSession() - - self.session = session - self.parse_mode = parse_mode - self.__token = token - - @property - def id(self) -> int: - """ - Get bot ID from token - - :return: - """ - return extract_bot_id(self.__token) - - async def __call__(self, method: TelegramMethod[T]) -> T: - """ - Call API method - - :param method: - :return: - """ - return await self.session.make_request(self.__token, method) - - async def close(self) -> None: - """ - Close bot session - """ - await self.session.close() - - @asynccontextmanager - async def context(self, auto_close: bool = True) -> AsyncIterator["BaseBot[ContextInstance]"]: - """ - Generate bot context - - :param auto_close: - :return: - """ - # TODO: because set_current expects Bot, not BaseBot — this check fails - token = self.set_current(self) # type: ignore - try: - yield self - finally: - if auto_close: - await self.close() - self.reset_current(token) - - def __hash__(self) -> int: - """ - Get hash for the token - - :return: - """ - return hash(self.__token) - - def __eq__(self, other: Any) -> bool: - """ - Compare current bot with another bot instance - - :param other: - :return: - """ - if not isinstance(other, BaseBot): - return False - return hash(self) == hash(other) diff --git a/aiogram/api/client/bot.py b/aiogram/api/client/bot.py index 51058f1d..5e58b4f8 100644 --- a/aiogram/api/client/bot.py +++ b/aiogram/api/client/bot.py @@ -1,8 +1,20 @@ +from __future__ import annotations + import datetime -from typing import List, Optional, Union +from contextlib import asynccontextmanager +from typing import ( + List, + Optional, + Union, + TypeVar, + AsyncIterator, + Any, +) from async_lru import alru_cache +from .session.aiohttp import AiohttpSession +from .session.base import BaseSession from ..methods import ( AddStickerToSet, AnswerCallbackQuery, @@ -70,6 +82,7 @@ from ..methods import ( UnbanChatMember, UnpinChatMessage, UploadStickerFile, + TelegramMethod, ) from ..types import ( Chat, @@ -98,14 +111,93 @@ from ..types import ( UserProfilePhotos, WebhookInfo, ) -from .base import BaseBot +from ...utils.mixins import ( + ContextInstanceMixin, +) +from ...utils.token import ( + validate_token, + extract_bot_id, +) + +T = TypeVar("T") -class Bot(BaseBot["Bot"]): +class Bot(ContextInstanceMixin["Bot"]): """ Class where located all API methods """ + def __init__( + self, token: str, session: Optional[BaseSession] = None, parse_mode: Optional[str] = None + ) -> None: + validate_token(token) + + if session is None: + session = AiohttpSession() + + self.session = session + self.parse_mode = parse_mode + self.__token = token + + @property + def id(self) -> int: + """ + Get bot ID from token + + :return: + """ + return extract_bot_id(self.__token) + + async def __call__(self, method: TelegramMethod[T]) -> T: + """ + Call API method + + :param method: + :return: + """ + return await self.session.make_request(self.__token, method) + + async def close(self) -> None: + """ + Close bot session + """ + await self.session.close() + + @asynccontextmanager + async def context(self, auto_close: bool = True) -> AsyncIterator[Bot]: + """ + Generate bot context + + :param auto_close: + :return: + """ + token = self.set_current(self) + try: + yield self + finally: + if auto_close: + await self.close() + self.reset_current(token) + + def __hash__(self) -> int: + """ + Get hash for the token + + :return: + """ + return hash(self.__token) + + def __eq__(self, other: Any) -> bool: + """ + Compare current bot with another bot instance + + :param other: + :return: + """ + if not isinstance(other, Bot): + return False + return hash(self) == hash(other) + @alru_cache() # type: ignore async def me(self) -> User: return await self.get_me() diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index fa24f259..5df6f28d 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -10,6 +10,7 @@ from typing import ( Optional, Tuple, Union, + Type, ) from aiogram.dispatcher.filters.base import BaseFilter @@ -19,7 +20,7 @@ CallbackType = Callable[[Any], Awaitable[Any]] SyncFilter = Callable[[Any], Any] AsyncFilter = Callable[[Any], Awaitable[Any]] FilterType = Union[SyncFilter, AsyncFilter, BaseFilter] -HandlerType = Union[FilterType, BaseHandler] +HandlerType = Union[FilterType, Type[BaseHandler]] @dataclass @@ -42,8 +43,7 @@ class CallableMixin: return {k: v for k, v in kwargs.items() if k in self.spec.args} async def call(self, *args: Any, **kwargs: Any) -> Any: - # TODO: what we should do if callback is BaseHandler? - wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) # type: ignore + wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) if self.awaitable: return await wrapped() return wrapped() @@ -61,7 +61,6 @@ class HandlerObject(CallableMixin): def __post_init__(self) -> None: super(HandlerObject, self).__post_init__() - # TODO: by types callback must be Callable or BaseHandler, not Type[BaseHandler] if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): # type: ignore self.awaitable = True diff --git a/aiogram/dispatcher/handler/base.py b/aiogram/dispatcher/handler/base.py index 21bc248b..45d3f28e 100644 --- a/aiogram/dispatcher/handler/base.py +++ b/aiogram/dispatcher/handler/base.py @@ -31,15 +31,13 @@ class BaseHandler(BaseHandlerMixin[T], ABC): self.data: Dict[str, Any] = kwargs @property - def bot(self) -> Optional[Bot]: + def bot(self) -> Bot: if "bot" in self.data: - # TODO: remove cast return cast(Bot, self.data["bot"]) - return Bot.get_current() + return Bot.get_current(no_error=False) @property def update(self) -> Update: - # TODO: remove cast return cast(Update, self.data["update"]) @abstractmethod diff --git a/tests/test_api/test_client/test_base_bot.py b/tests/test_api/test_client/test_base_bot.py index 652f0918..a254bfaf 100644 --- a/tests/test_api/test_client/test_base_bot.py +++ b/tests/test_api/test_client/test_base_bot.py @@ -1,6 +1,6 @@ import pytest -from aiogram.api.client.base import BaseBot +from aiogram import Bot from aiogram.api.client.session.aiohttp import AiohttpSession from aiogram.api.methods import GetMe @@ -12,22 +12,22 @@ except ImportError: class TestBaseBot: def test_init(self): - base_bot = BaseBot("42:TEST") + base_bot = Bot("42:TEST") assert isinstance(base_bot.session, AiohttpSession) assert base_bot.id == 42 def test_hashable(self): - base_bot = BaseBot("42:TEST") + base_bot = Bot("42:TEST") assert hash(base_bot) == hash("42:TEST") def test_equals(self): - base_bot = BaseBot("42:TEST") - assert base_bot == BaseBot("42:TEST") + base_bot = Bot("42:TEST") + assert base_bot == Bot("42:TEST") assert base_bot != "42:TEST" @pytest.mark.asyncio async def test_emit(self): - base_bot = BaseBot("42:TEST") + base_bot = Bot("42:TEST") method = GetMe() @@ -40,7 +40,7 @@ class TestBaseBot: @pytest.mark.asyncio async def test_close(self): - base_bot = BaseBot("42:TEST", session=AiohttpSession()) + base_bot = Bot("42:TEST", session=AiohttpSession()) await base_bot.session.create_session() with patch( @@ -55,10 +55,10 @@ class TestBaseBot: with patch( "aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock ) as mocked_close: - async with BaseBot("42:TEST", session=AiohttpSession()).context( + async with Bot("42:TEST", session=AiohttpSession()).context( auto_close=close ) as bot: - assert isinstance(bot, BaseBot) + assert isinstance(bot, Bot) if close: mocked_close.assert_awaited() else: diff --git a/tests/test_dispatcher/test_handler/test_base.py b/tests/test_dispatcher/test_handler/test_base.py index 5e8ef0f3..21063b62 100644 --- a/tests/test_dispatcher/test_handler/test_base.py +++ b/tests/test_dispatcher/test_handler/test_base.py @@ -23,7 +23,6 @@ class TestBaseClassBasedHandler: assert handler.event == event assert handler.data["key"] == 42 - assert hasattr(handler, "bot") assert not hasattr(handler, "filters") assert await handler == 42 @@ -33,7 +32,8 @@ class TestBaseClassBasedHandler: handler = MyHandler(event=event, key=42) bot = Bot("42:TEST") - assert handler.bot is None + with pytest.raises(LookupError): + handler.bot Bot.set_current(bot) assert handler.bot == bot