Add token validation util, fix deepcopy of sessions and make bot hashable and comparable

This commit is contained in:
Alex Root Junior 2019-11-28 23:12:44 +02:00
parent 9adc2f91bd
commit c674b5547b
11 changed files with 223 additions and 41 deletions

View file

@ -1,6 +1,10 @@
from typing import TypeVar
from __future__ import annotations
import copy
from typing import TypeVar, Dict, Any
from ...utils.mixins import ContextInstanceMixin
from ...utils.token import extract_bot_id, validate_token
from ..methods import TelegramMethod
from .session.aiohttp import AiohttpSession
from .session.base import BaseSession
@ -10,13 +14,20 @@ T = TypeVar("T")
class BaseBot(ContextInstanceMixin):
def __init__(self, token: str, session: BaseSession = None):
validate_token(token)
if session is None:
session = AiohttpSession()
self.session = session
self.token = token
self.__token = token
@property
def id(self):
return extract_bot_id(self.__token)
async def emit(self, method: TelegramMethod[T]) -> T:
return await self.session.make_request(self.token, method)
return await self.session.make_request(self.__token, method)
async def close(self):
await self.session.close()
@ -26,3 +37,11 @@ class BaseBot(ContextInstanceMixin):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.session.close()
def __hash__(self):
return hash(self.__token)
def __eq__(self, other: BaseBot):
if not isinstance(other, BaseBot):
return False
return hash(self) == hash(other)

View file

@ -1,4 +1,7 @@
from typing import Callable, Optional, TypeVar, cast
from __future__ import annotations
import copy
from typing import Callable, Optional, TypeVar, cast, Dict, Any
from aiohttp import ClientSession, FormData
@ -53,3 +56,17 @@ class AiohttpSession(BaseSession):
response = call.build_response(raw_result)
self.raise_for_status(response)
return cast(T, response.result)
async def __aenter__(self) -> AiohttpSession:
await self.create_session()
return self
def __deepcopy__(self, memodict: Dict[str, Any]):
cls = self.__class__
result = cls.__new__(cls)
memodict[id(self)] = result
for key, value in self.__dict__.items():
# aiohttp ClientSession cannot be copied.
copied_value = copy.deepcopy(value, memo=memodict) if key != '_session' else None
setattr(result, key, copied_value)
return result

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import abc
import datetime
import json
@ -13,9 +15,9 @@ class BaseSession(abc.ABC):
def __init__(
self,
api: Optional[TelegramAPIServer] = None,
json_loads: Optional[Callable] = None,
json_dumps: Optional[Callable] = None,
):
json_loads: Optional[Callable[[Any], Any]] = None,
json_dumps: Optional[Callable[[Any], Any]] = None,
) -> None:
if api is None:
api = PRODUCTION
if json_loads is None:
@ -27,7 +29,7 @@ class BaseSession(abc.ABC):
self.json_loads = json_loads
self.json_dumps = json_dumps
def raise_for_status(self, response: Response[T]):
def raise_for_status(self, response: Response[T]) -> None:
if response.ok:
return
raise Exception(response.description)
@ -37,7 +39,7 @@ class BaseSession(abc.ABC):
pass
@abc.abstractmethod
async def make_request(self, token: str, method: TelegramMethod[T]) -> T:
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
pass
def prepare_value(self, value: Any) -> Union[str, int, bool]:
@ -53,9 +55,15 @@ class BaseSession(abc.ABC):
else:
return str(value)
def clean_json(self, value: Any):
def clean_json(self, value: Any) -> Any:
if isinstance(value, list):
return [self.clean_json(v) for v in value if v is not None]
elif isinstance(value, dict):
return {k: self.clean_json(v) for k, v in value.items() if v is not None}
return value
async def __aenter__(self) -> BaseSession:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()

42
aiogram/utils/token.py Normal file
View file

@ -0,0 +1,42 @@
from functools import lru_cache
class TokenValidationError(Exception):
pass
@lru_cache()
def validate_token(token: str) -> bool:
"""
Validate Telegram token
:param token:
:return:
"""
if not isinstance(token, str):
raise TokenValidationError(
f"Token is invalid! It must be 'str' type instead of {type(token)} type."
)
if any(x.isspace() for x in token):
message = "Token is invalid! It can't contains spaces."
raise TokenValidationError(message)
left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (not right):
raise TokenValidationError("Token is invalid!")
return True
@lru_cache()
def extract_bot_id(token: str) -> int:
"""
Extract bot ID from Telegram token
:param token:
:return:
"""
validate_token(token)
raw_bot_id, *_ = token.split(":")
return int(raw_bot_id)

View file

@ -35,7 +35,7 @@ class MockedBot(Bot):
session: MockedSession
def __init__(self):
super(MockedBot, self).__init__("TOKEN", session=MockedSession())
super(MockedBot, self).__init__("42:TEST", session=MockedSession())
def add_result_for(
self,

View file

@ -3,9 +3,9 @@ from aiogram.api.client.telegram import PRODUCTION
class TestAPIServer:
def test_method_url(self):
method_url = PRODUCTION.api_url(token="TOKEN", method="apiMethod")
assert method_url == "https://api.telegram.org/botTOKEN/apiMethod"
method_url = PRODUCTION.api_url(token="42:TEST", method="apiMethod")
assert method_url == "https://api.telegram.org/bot42:TEST/apiMethod"
def test_file_url(self):
file_url = PRODUCTION.file_url(token="TOKEN", path="path")
assert file_url == "https://api.telegram.org/file/botTOKEN/path"
file_url = PRODUCTION.file_url(token="42:TEST", path="path")
assert file_url == "https://api.telegram.org/file/bot42:TEST/path"

View file

@ -1,3 +1,5 @@
import copy
import pytest
from asynctest import CoroutineMock, patch
@ -9,12 +11,21 @@ from aiogram.api.methods import GetMe
class TestBaseBot:
def test_init(self):
base_bot = BaseBot("TOKEN")
base_bot = BaseBot("42:TEST")
assert isinstance(base_bot.session, AiohttpSession)
assert base_bot.id == 42
def test_hashable(self):
base_bot = BaseBot("42:TEST")
assert hash(base_bot) == hash("42:TEST")
def test_equals(self):
base_bot = BaseBot("42:TEST")
assert base_bot == BaseBot("42:TEST")
@pytest.mark.asyncio
async def test_emit(self):
base_bot = BaseBot("TOKEN")
base_bot = BaseBot("42:TEST")
method = GetMe()
@ -23,11 +34,11 @@ class TestBaseBot:
new_callable=CoroutineMock,
) as mocked_make_request:
await base_bot.emit(method)
mocked_make_request.assert_awaited_with("TOKEN", method)
mocked_make_request.assert_awaited_with("42:TEST", method)
@pytest.mark.asyncio
async def test_close(self):
base_bot = BaseBot("TOKEN", session=AiohttpSession())
base_bot = BaseBot("42:TEST", session=AiohttpSession())
await base_bot.session.create_session()
with patch(
@ -41,6 +52,6 @@ class TestBaseBot:
with patch(
"aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock
) as mocked_close:
async with BaseBot("TOKEN", session=AiohttpSession()) as bot:
async with BaseBot("42:TEST", session=AiohttpSession()) as bot:
assert isinstance(bot, BaseBot)
mocked_close.assert_awaited()

View file

@ -1,3 +1,6 @@
import copy
from typing import AsyncContextManager
import aiohttp
import pytest
from aresponses import ResponsesMockServer
@ -74,7 +77,7 @@ class TestAiohttpSession:
async def test_make_request(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY,
"/botTOKEN/method",
"/bot42:TEST/method",
"post",
aresponses.Response(
status=200,
@ -95,8 +98,32 @@ class TestAiohttpSession:
with patch(
"aiogram.api.client.session.base.BaseSession.raise_for_status"
) as patched_raise_for_status:
result = await session.make_request("TOKEN", call)
result = await session.make_request("42:TEST", call)
assert isinstance(result, int)
assert result == 42
assert patched_raise_for_status.called_once()
@pytest.mark.asyncio
async def test_context_manager(self):
session = AiohttpSession()
assert isinstance(session, AsyncContextManager)
with patch(
"aiogram.api.client.session.aiohttp.AiohttpSession.create_session",
new_callable=CoroutineMock,
) as mocked_create_session, patch(
"aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock
) as mocked_close:
async with session as ctx:
assert session == ctx
mocked_close.awaited_once()
mocked_create_session.awaited_once()
@pytest.mark.asyncio
async def test_deepcopy(self):
# Session should be copied without aiohttp.ClientSession
async with AiohttpSession() as session:
cloned_session = copy.deepcopy(session)
assert cloned_session != session
assert cloned_session._session is None

View file

@ -1,23 +1,27 @@
import datetime
from typing import AsyncContextManager
import pytest
from asynctest import CoroutineMock, patch
from aiogram.api.client.session.base import BaseSession
from aiogram.api.client.session.base import BaseSession, T
from aiogram.api.client.telegram import PRODUCTION, TelegramAPIServer
from aiogram.api.methods import GetMe, Response
from aiogram.api.methods import GetMe, Response, TelegramMethod
from aiogram.utils.mixins import DataMixin
class CustomSession(BaseSession):
async def close(self):
pass
async def make_request(self, token: str, method: TelegramMethod[T]) -> None: # type: ignore
assert isinstance(token, str)
assert isinstance(method, TelegramMethod)
class TestBaseSession(DataMixin):
def setup(self):
self["__abstractmethods__"] = BaseSession.__abstractmethods__
BaseSession.__abstractmethods__ = set()
def teardown(self):
BaseSession.__abstractmethods__ = self["__abstractmethods__"]
def test_init_api(self):
session = BaseSession()
session = CustomSession()
assert session.api == PRODUCTION
def test_init_custom_api(self):
@ -25,11 +29,11 @@ class TestBaseSession(DataMixin):
base="http://example.com/{token}/{method}",
file="http://example.com/{token}/file/{path{",
)
session = BaseSession(api=api)
session = CustomSession(api=api)
assert session.api == api
def test_prepare_value(self):
session = BaseSession()
session = CustomSession()
now = datetime.datetime.now()
@ -41,7 +45,7 @@ class TestBaseSession(DataMixin):
assert session.prepare_value(42) == "42"
def test_clean_json(self):
session = BaseSession()
session = CustomSession()
cleaned_dict = session.clean_json({"key": "value", "null": None})
assert "key" in cleaned_dict
@ -54,7 +58,7 @@ class TestBaseSession(DataMixin):
assert cleaned_list[0] == "kaboom"
def test_clean_json_with_nested_json(self):
session = BaseSession()
session = CustomSession()
cleaned = session.clean_json(
{
@ -75,12 +79,12 @@ class TestBaseSession(DataMixin):
assert cleaned["nested_dict"] == {"key": "value"}
def test_clean_json_not_json(self):
session = BaseSession()
session = CustomSession()
assert session.clean_json(42) == 42
def test_raise_for_status(self):
session = BaseSession()
session = CustomSession()
session.raise_for_status(Response[bool](ok=True, result=True))
with pytest.raises(Exception):
@ -88,6 +92,19 @@ class TestBaseSession(DataMixin):
@pytest.mark.asyncio
async def test_make_request(self):
session = BaseSession()
session = CustomSession()
assert await session.make_request("TOKEN", GetMe()) is None
assert await session.make_request("42:TEST", GetMe()) is None
@pytest.mark.asyncio
async def test_context_manager(self):
session = CustomSession()
assert isinstance(session, AsyncContextManager)
with patch(
"tests.test_api.test_client.test_session.test_base_session.CustomSession.close",
new_callable=CoroutineMock,
) as mocked_close:
async with session as ctx:
assert session == ctx
mocked_close.awaited_once()

View file

@ -20,7 +20,7 @@ class TestDispatcher:
@pytest.mark.asyncio
async def test_feed_update(self):
dp = Dispatcher()
bot = Bot("TOKEN")
bot = Bot("42:TEST")
@dp.message_handler()
async def my_handler(message: Message, **kwargs):

View file

@ -0,0 +1,41 @@
from unittest.mock import patch
import pytest
from aiogram.utils.token import TokenValidationError, validate_token, extract_bot_id
BOT_ID = 123456789
VALID_TOKEN = '123456789:AABBCCDDEEFFaabbccddeeff-1234567890'
INVALID_TOKENS = [
'123456789:AABBCCDDEEFFaabbccddeeff 123456789', # space is exists
'ABC:AABBCCDDEEFFaabbccddeeff123456789', # left part is not digit
':AABBCCDDEEFFaabbccddeeff123456789', # there is no left part
'123456789:', # there is no right part
'ABC AABBCCDDEEFFaabbccddeeff123456789', # there is no ':' separator
None, # is None
12345678, # is digit
(42, 'TEST'), # is tuple
]
@pytest.fixture(params=INVALID_TOKENS, name='invalid_token')
def invalid_token_fixture(request):
return request.param
class TestCheckToken:
def test_valid(self):
assert validate_token(VALID_TOKEN) is True
def test_invalid_token(self, invalid_token):
with pytest.raises(TokenValidationError):
validate_token(invalid_token)
class TestExtractBotId:
def test_extract_bot_id(self):
with patch("aiogram.utils.token.validate_token") as mocked_validate_token:
result = extract_bot_id(VALID_TOKEN)
mocked_validate_token.assert_called_once_with(VALID_TOKEN)
assert result == BOT_ID