Merge BaseBot to Bot class

This commit is contained in:
Boger 2020-03-25 15:35:32 +03:00
parent fac69e52b7
commit a823e275a7
6 changed files with 111 additions and 120 deletions

View file

@ -1,98 +0,0 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncIterator,
Optional,
TypeVar,
)
from ...utils.mixins import (
ContextInstance,
ContextInstanceMixin,
)
from ...utils.token import extract_bot_id, validate_token
from ..methods import TelegramMethod
from .session.aiohttp import AiohttpSession
from .session.base import BaseSession
T = TypeVar("T")
class BaseBot(ContextInstanceMixin[ContextInstance]):
"""
Base class for bots
"""
def __init__(
self, token: str, session: Optional[BaseSession] = None, parse_mode: Optional[str] = None
) -> None:
validate_token(token)
if session is None:
session = AiohttpSession()
self.session = session
self.parse_mode = parse_mode
self.__token = token
@property
def id(self) -> int:
"""
Get bot ID from token
:return:
"""
return extract_bot_id(self.__token)
async def __call__(self, method: TelegramMethod[T]) -> T:
"""
Call API method
:param method:
:return:
"""
return await self.session.make_request(self.__token, method)
async def close(self) -> None:
"""
Close bot session
"""
await self.session.close()
@asynccontextmanager
async def context(self, auto_close: bool = True) -> AsyncIterator["BaseBot[ContextInstance]"]:
"""
Generate bot context
:param auto_close:
:return:
"""
# TODO: because set_current expects Bot, not BaseBot — this check fails
token = self.set_current(self) # type: ignore
try:
yield self
finally:
if auto_close:
await self.close()
self.reset_current(token)
def __hash__(self) -> int:
"""
Get hash for the token
:return:
"""
return hash(self.__token)
def __eq__(self, other: Any) -> bool:
"""
Compare current bot with another bot instance
:param other:
:return:
"""
if not isinstance(other, BaseBot):
return False
return hash(self) == hash(other)

View file

@ -1,8 +1,20 @@
from __future__ import annotations
import datetime
from typing import List, Optional, Union
from contextlib import asynccontextmanager
from typing import (
List,
Optional,
Union,
TypeVar,
AsyncIterator,
Any,
)
from async_lru import alru_cache
from .session.aiohttp import AiohttpSession
from .session.base import BaseSession
from ..methods import (
AddStickerToSet,
AnswerCallbackQuery,
@ -70,6 +82,7 @@ from ..methods import (
UnbanChatMember,
UnpinChatMessage,
UploadStickerFile,
TelegramMethod,
)
from ..types import (
Chat,
@ -98,14 +111,93 @@ from ..types import (
UserProfilePhotos,
WebhookInfo,
)
from .base import BaseBot
from ...utils.mixins import (
ContextInstanceMixin,
)
from ...utils.token import (
validate_token,
extract_bot_id,
)
T = TypeVar("T")
class Bot(BaseBot["Bot"]):
class Bot(ContextInstanceMixin["Bot"]):
"""
Class where located all API methods
"""
def __init__(
self, token: str, session: Optional[BaseSession] = None, parse_mode: Optional[str] = None
) -> None:
validate_token(token)
if session is None:
session = AiohttpSession()
self.session = session
self.parse_mode = parse_mode
self.__token = token
@property
def id(self) -> int:
"""
Get bot ID from token
:return:
"""
return extract_bot_id(self.__token)
async def __call__(self, method: TelegramMethod[T]) -> T:
"""
Call API method
:param method:
:return:
"""
return await self.session.make_request(self.__token, method)
async def close(self) -> None:
"""
Close bot session
"""
await self.session.close()
@asynccontextmanager
async def context(self, auto_close: bool = True) -> AsyncIterator[Bot]:
"""
Generate bot context
:param auto_close:
:return:
"""
token = self.set_current(self)
try:
yield self
finally:
if auto_close:
await self.close()
self.reset_current(token)
def __hash__(self) -> int:
"""
Get hash for the token
:return:
"""
return hash(self.__token)
def __eq__(self, other: Any) -> bool:
"""
Compare current bot with another bot instance
:param other:
:return:
"""
if not isinstance(other, Bot):
return False
return hash(self) == hash(other)
@alru_cache() # type: ignore
async def me(self) -> User:
return await self.get_me()

View file

@ -10,6 +10,7 @@ from typing import (
Optional,
Tuple,
Union,
Type,
)
from aiogram.dispatcher.filters.base import BaseFilter
@ -19,7 +20,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
SyncFilter = Callable[[Any], Any]
AsyncFilter = Callable[[Any], Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[FilterType, BaseHandler]
HandlerType = Union[FilterType, Type[BaseHandler]]
@dataclass
@ -42,8 +43,7 @@ class CallableMixin:
return {k: v for k, v in kwargs.items() if k in self.spec.args}
async def call(self, *args: Any, **kwargs: Any) -> Any:
# TODO: what we should do if callback is BaseHandler?
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) # type: ignore
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
if self.awaitable:
return await wrapped()
return wrapped()
@ -61,7 +61,6 @@ class HandlerObject(CallableMixin):
def __post_init__(self) -> None:
super(HandlerObject, self).__post_init__()
# TODO: by types callback must be Callable or BaseHandler, not Type[BaseHandler]
if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): # type: ignore
self.awaitable = True

View file

@ -31,15 +31,13 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
self.data: Dict[str, Any] = kwargs
@property
def bot(self) -> Optional[Bot]:
def bot(self) -> Bot:
if "bot" in self.data:
# TODO: remove cast
return cast(Bot, self.data["bot"])
return Bot.get_current()
return Bot.get_current(no_error=False)
@property
def update(self) -> Update:
# TODO: remove cast
return cast(Update, self.data["update"])
@abstractmethod

View file

@ -1,6 +1,6 @@
import pytest
from aiogram.api.client.base import BaseBot
from aiogram import Bot
from aiogram.api.client.session.aiohttp import AiohttpSession
from aiogram.api.methods import GetMe
@ -12,22 +12,22 @@ except ImportError:
class TestBaseBot:
def test_init(self):
base_bot = BaseBot("42:TEST")
base_bot = Bot("42:TEST")
assert isinstance(base_bot.session, AiohttpSession)
assert base_bot.id == 42
def test_hashable(self):
base_bot = BaseBot("42:TEST")
base_bot = Bot("42:TEST")
assert hash(base_bot) == hash("42:TEST")
def test_equals(self):
base_bot = BaseBot("42:TEST")
assert base_bot == BaseBot("42:TEST")
base_bot = Bot("42:TEST")
assert base_bot == Bot("42:TEST")
assert base_bot != "42:TEST"
@pytest.mark.asyncio
async def test_emit(self):
base_bot = BaseBot("42:TEST")
base_bot = Bot("42:TEST")
method = GetMe()
@ -40,7 +40,7 @@ class TestBaseBot:
@pytest.mark.asyncio
async def test_close(self):
base_bot = BaseBot("42:TEST", session=AiohttpSession())
base_bot = Bot("42:TEST", session=AiohttpSession())
await base_bot.session.create_session()
with patch(
@ -55,10 +55,10 @@ class TestBaseBot:
with patch(
"aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock
) as mocked_close:
async with BaseBot("42:TEST", session=AiohttpSession()).context(
async with Bot("42:TEST", session=AiohttpSession()).context(
auto_close=close
) as bot:
assert isinstance(bot, BaseBot)
assert isinstance(bot, Bot)
if close:
mocked_close.assert_awaited()
else:

View file

@ -23,7 +23,6 @@ class TestBaseClassBasedHandler:
assert handler.event == event
assert handler.data["key"] == 42
assert hasattr(handler, "bot")
assert not hasattr(handler, "filters")
assert await handler == 42
@ -33,7 +32,8 @@ class TestBaseClassBasedHandler:
handler = MyHandler(event=event, key=42)
bot = Bot("42:TEST")
assert handler.bot is None
with pytest.raises(LookupError):
handler.bot
Bot.set_current(bot)
assert handler.bot == bot