mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-09 01:15:31 +00:00
Merge BaseBot to Bot class
This commit is contained in:
parent
fac69e52b7
commit
a823e275a7
6 changed files with 111 additions and 120 deletions
|
|
@ -1,98 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
AsyncIterator,
|
|
||||||
Optional,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ...utils.mixins import (
|
|
||||||
ContextInstance,
|
|
||||||
ContextInstanceMixin,
|
|
||||||
)
|
|
||||||
from ...utils.token import extract_bot_id, validate_token
|
|
||||||
from ..methods import TelegramMethod
|
|
||||||
from .session.aiohttp import AiohttpSession
|
|
||||||
from .session.base import BaseSession
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBot(ContextInstanceMixin[ContextInstance]):
|
|
||||||
"""
|
|
||||||
Base class for bots
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, token: str, session: Optional[BaseSession] = None, parse_mode: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
validate_token(token)
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
session = AiohttpSession()
|
|
||||||
|
|
||||||
self.session = session
|
|
||||||
self.parse_mode = parse_mode
|
|
||||||
self.__token = token
|
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self) -> int:
|
|
||||||
"""
|
|
||||||
Get bot ID from token
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return extract_bot_id(self.__token)
|
|
||||||
|
|
||||||
async def __call__(self, method: TelegramMethod[T]) -> T:
|
|
||||||
"""
|
|
||||||
Call API method
|
|
||||||
|
|
||||||
:param method:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return await self.session.make_request(self.__token, method)
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""
|
|
||||||
Close bot session
|
|
||||||
"""
|
|
||||||
await self.session.close()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def context(self, auto_close: bool = True) -> AsyncIterator["BaseBot[ContextInstance]"]:
|
|
||||||
"""
|
|
||||||
Generate bot context
|
|
||||||
|
|
||||||
:param auto_close:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# TODO: because set_current expects Bot, not BaseBot — this check fails
|
|
||||||
token = self.set_current(self) # type: ignore
|
|
||||||
try:
|
|
||||||
yield self
|
|
||||||
finally:
|
|
||||||
if auto_close:
|
|
||||||
await self.close()
|
|
||||||
self.reset_current(token)
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
"""
|
|
||||||
Get hash for the token
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return hash(self.__token)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""
|
|
||||||
Compare current bot with another bot instance
|
|
||||||
|
|
||||||
:param other:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if not isinstance(other, BaseBot):
|
|
||||||
return False
|
|
||||||
return hash(self) == hash(other)
|
|
||||||
|
|
@ -1,8 +1,20 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import List, Optional, Union
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import (
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
TypeVar,
|
||||||
|
AsyncIterator,
|
||||||
|
Any,
|
||||||
|
)
|
||||||
|
|
||||||
from async_lru import alru_cache
|
from async_lru import alru_cache
|
||||||
|
|
||||||
|
from .session.aiohttp import AiohttpSession
|
||||||
|
from .session.base import BaseSession
|
||||||
from ..methods import (
|
from ..methods import (
|
||||||
AddStickerToSet,
|
AddStickerToSet,
|
||||||
AnswerCallbackQuery,
|
AnswerCallbackQuery,
|
||||||
|
|
@ -70,6 +82,7 @@ from ..methods import (
|
||||||
UnbanChatMember,
|
UnbanChatMember,
|
||||||
UnpinChatMessage,
|
UnpinChatMessage,
|
||||||
UploadStickerFile,
|
UploadStickerFile,
|
||||||
|
TelegramMethod,
|
||||||
)
|
)
|
||||||
from ..types import (
|
from ..types import (
|
||||||
Chat,
|
Chat,
|
||||||
|
|
@ -98,14 +111,93 @@ from ..types import (
|
||||||
UserProfilePhotos,
|
UserProfilePhotos,
|
||||||
WebhookInfo,
|
WebhookInfo,
|
||||||
)
|
)
|
||||||
from .base import BaseBot
|
from ...utils.mixins import (
|
||||||
|
ContextInstanceMixin,
|
||||||
|
)
|
||||||
|
from ...utils.token import (
|
||||||
|
validate_token,
|
||||||
|
extract_bot_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Bot(BaseBot["Bot"]):
|
class Bot(ContextInstanceMixin["Bot"]):
|
||||||
"""
|
"""
|
||||||
Class where located all API methods
|
Class where located all API methods
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, token: str, session: Optional[BaseSession] = None, parse_mode: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
validate_token(token)
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
session = AiohttpSession()
|
||||||
|
|
||||||
|
self.session = session
|
||||||
|
self.parse_mode = parse_mode
|
||||||
|
self.__token = token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> int:
|
||||||
|
"""
|
||||||
|
Get bot ID from token
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return extract_bot_id(self.__token)
|
||||||
|
|
||||||
|
async def __call__(self, method: TelegramMethod[T]) -> T:
|
||||||
|
"""
|
||||||
|
Call API method
|
||||||
|
|
||||||
|
:param method:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.session.make_request(self.__token, method)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Close bot session
|
||||||
|
"""
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def context(self, auto_close: bool = True) -> AsyncIterator[Bot]:
|
||||||
|
"""
|
||||||
|
Generate bot context
|
||||||
|
|
||||||
|
:param auto_close:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
token = self.set_current(self)
|
||||||
|
try:
|
||||||
|
yield self
|
||||||
|
finally:
|
||||||
|
if auto_close:
|
||||||
|
await self.close()
|
||||||
|
self.reset_current(token)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get hash for the token
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return hash(self.__token)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Compare current bot with another bot instance
|
||||||
|
|
||||||
|
:param other:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not isinstance(other, Bot):
|
||||||
|
return False
|
||||||
|
return hash(self) == hash(other)
|
||||||
|
|
||||||
@alru_cache() # type: ignore
|
@alru_cache() # type: ignore
|
||||||
async def me(self) -> User:
|
async def me(self) -> User:
|
||||||
return await self.get_me()
|
return await self.get_me()
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
Type,
|
||||||
)
|
)
|
||||||
|
|
||||||
from aiogram.dispatcher.filters.base import BaseFilter
|
from aiogram.dispatcher.filters.base import BaseFilter
|
||||||
|
|
@ -19,7 +20,7 @@ CallbackType = Callable[[Any], Awaitable[Any]]
|
||||||
SyncFilter = Callable[[Any], Any]
|
SyncFilter = Callable[[Any], Any]
|
||||||
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
||||||
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
||||||
HandlerType = Union[FilterType, BaseHandler]
|
HandlerType = Union[FilterType, Type[BaseHandler]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -42,8 +43,7 @@ class CallableMixin:
|
||||||
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
||||||
|
|
||||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
# TODO: what we should do if callback is BaseHandler?
|
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
|
||||||
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) # type: ignore
|
|
||||||
if self.awaitable:
|
if self.awaitable:
|
||||||
return await wrapped()
|
return await wrapped()
|
||||||
return wrapped()
|
return wrapped()
|
||||||
|
|
@ -61,7 +61,6 @@ class HandlerObject(CallableMixin):
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super(HandlerObject, self).__post_init__()
|
super(HandlerObject, self).__post_init__()
|
||||||
# TODO: by types callback must be Callable or BaseHandler, not Type[BaseHandler]
|
|
||||||
if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): # type: ignore
|
if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): # type: ignore
|
||||||
self.awaitable = True
|
self.awaitable = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,15 +31,13 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
|
||||||
self.data: Dict[str, Any] = kwargs
|
self.data: Dict[str, Any] = kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self) -> Optional[Bot]:
|
def bot(self) -> Bot:
|
||||||
if "bot" in self.data:
|
if "bot" in self.data:
|
||||||
# TODO: remove cast
|
|
||||||
return cast(Bot, self.data["bot"])
|
return cast(Bot, self.data["bot"])
|
||||||
return Bot.get_current()
|
return Bot.get_current(no_error=False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def update(self) -> Update:
|
def update(self) -> Update:
|
||||||
# TODO: remove cast
|
|
||||||
return cast(Update, self.data["update"])
|
return cast(Update, self.data["update"])
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aiogram.api.client.base import BaseBot
|
from aiogram import Bot
|
||||||
from aiogram.api.client.session.aiohttp import AiohttpSession
|
from aiogram.api.client.session.aiohttp import AiohttpSession
|
||||||
from aiogram.api.methods import GetMe
|
from aiogram.api.methods import GetMe
|
||||||
|
|
||||||
|
|
@ -12,22 +12,22 @@ except ImportError:
|
||||||
|
|
||||||
class TestBaseBot:
|
class TestBaseBot:
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
base_bot = BaseBot("42:TEST")
|
base_bot = Bot("42:TEST")
|
||||||
assert isinstance(base_bot.session, AiohttpSession)
|
assert isinstance(base_bot.session, AiohttpSession)
|
||||||
assert base_bot.id == 42
|
assert base_bot.id == 42
|
||||||
|
|
||||||
def test_hashable(self):
|
def test_hashable(self):
|
||||||
base_bot = BaseBot("42:TEST")
|
base_bot = Bot("42:TEST")
|
||||||
assert hash(base_bot) == hash("42:TEST")
|
assert hash(base_bot) == hash("42:TEST")
|
||||||
|
|
||||||
def test_equals(self):
|
def test_equals(self):
|
||||||
base_bot = BaseBot("42:TEST")
|
base_bot = Bot("42:TEST")
|
||||||
assert base_bot == BaseBot("42:TEST")
|
assert base_bot == Bot("42:TEST")
|
||||||
assert base_bot != "42:TEST"
|
assert base_bot != "42:TEST"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_emit(self):
|
async def test_emit(self):
|
||||||
base_bot = BaseBot("42:TEST")
|
base_bot = Bot("42:TEST")
|
||||||
|
|
||||||
method = GetMe()
|
method = GetMe()
|
||||||
|
|
||||||
|
|
@ -40,7 +40,7 @@ class TestBaseBot:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_close(self):
|
async def test_close(self):
|
||||||
base_bot = BaseBot("42:TEST", session=AiohttpSession())
|
base_bot = Bot("42:TEST", session=AiohttpSession())
|
||||||
await base_bot.session.create_session()
|
await base_bot.session.create_session()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
|
|
@ -55,10 +55,10 @@ class TestBaseBot:
|
||||||
with patch(
|
with patch(
|
||||||
"aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock
|
"aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock
|
||||||
) as mocked_close:
|
) as mocked_close:
|
||||||
async with BaseBot("42:TEST", session=AiohttpSession()).context(
|
async with Bot("42:TEST", session=AiohttpSession()).context(
|
||||||
auto_close=close
|
auto_close=close
|
||||||
) as bot:
|
) as bot:
|
||||||
assert isinstance(bot, BaseBot)
|
assert isinstance(bot, Bot)
|
||||||
if close:
|
if close:
|
||||||
mocked_close.assert_awaited()
|
mocked_close.assert_awaited()
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ class TestBaseClassBasedHandler:
|
||||||
|
|
||||||
assert handler.event == event
|
assert handler.event == event
|
||||||
assert handler.data["key"] == 42
|
assert handler.data["key"] == 42
|
||||||
assert hasattr(handler, "bot")
|
|
||||||
assert not hasattr(handler, "filters")
|
assert not hasattr(handler, "filters")
|
||||||
assert await handler == 42
|
assert await handler == 42
|
||||||
|
|
||||||
|
|
@ -33,7 +32,8 @@ class TestBaseClassBasedHandler:
|
||||||
handler = MyHandler(event=event, key=42)
|
handler = MyHandler(event=event, key=42)
|
||||||
bot = Bot("42:TEST")
|
bot = Bot("42:TEST")
|
||||||
|
|
||||||
assert handler.bot is None
|
with pytest.raises(LookupError):
|
||||||
|
handler.bot
|
||||||
|
|
||||||
Bot.set_current(bot)
|
Bot.set_current(bot)
|
||||||
assert handler.bot == bot
|
assert handler.bot == bot
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue