Add token validation util, fix deepcopy of sessions and make bot hashable and comparable

This commit is contained in:
Alex Root Junior 2019-11-28 23:12:44 +02:00
parent 9adc2f91bd
commit c674b5547b
11 changed files with 223 additions and 41 deletions

View file

@ -1,6 +1,10 @@
from typing import TypeVar from __future__ import annotations
import copy
from typing import TypeVar, Dict, Any
from ...utils.mixins import ContextInstanceMixin from ...utils.mixins import ContextInstanceMixin
from ...utils.token import extract_bot_id, validate_token
from ..methods import TelegramMethod from ..methods import TelegramMethod
from .session.aiohttp import AiohttpSession from .session.aiohttp import AiohttpSession
from .session.base import BaseSession from .session.base import BaseSession
@ -10,13 +14,20 @@ T = TypeVar("T")
class BaseBot(ContextInstanceMixin): class BaseBot(ContextInstanceMixin):
def __init__(self, token: str, session: BaseSession = None): def __init__(self, token: str, session: BaseSession = None):
validate_token(token)
if session is None: if session is None:
session = AiohttpSession() session = AiohttpSession()
self.session = session self.session = session
self.token = token self.__token = token
@property
def id(self):
return extract_bot_id(self.__token)
async def emit(self, method: TelegramMethod[T]) -> T: async def emit(self, method: TelegramMethod[T]) -> T:
return await self.session.make_request(self.token, method) return await self.session.make_request(self.__token, method)
async def close(self): async def close(self):
await self.session.close() await self.session.close()
@ -26,3 +37,11 @@ class BaseBot(ContextInstanceMixin):
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.session.close() await self.session.close()
def __hash__(self):
return hash(self.__token)
def __eq__(self, other: BaseBot):
if not isinstance(other, BaseBot):
return False
return hash(self) == hash(other)

View file

@ -1,4 +1,7 @@
from typing import Callable, Optional, TypeVar, cast from __future__ import annotations
import copy
from typing import Callable, Optional, TypeVar, cast, Dict, Any
from aiohttp import ClientSession, FormData from aiohttp import ClientSession, FormData
@ -53,3 +56,17 @@ class AiohttpSession(BaseSession):
response = call.build_response(raw_result) response = call.build_response(raw_result)
self.raise_for_status(response) self.raise_for_status(response)
return cast(T, response.result) return cast(T, response.result)
async def __aenter__(self) -> AiohttpSession:
await self.create_session()
return self
def __deepcopy__(self, memodict: Dict[str, Any]):
cls = self.__class__
result = cls.__new__(cls)
memodict[id(self)] = result
for key, value in self.__dict__.items():
# aiohttp ClientSession cannot be copied.
copied_value = copy.deepcopy(value, memo=memodict) if key != '_session' else None
setattr(result, key, copied_value)
return result

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import abc import abc
import datetime import datetime
import json import json
@ -13,9 +15,9 @@ class BaseSession(abc.ABC):
def __init__( def __init__(
self, self,
api: Optional[TelegramAPIServer] = None, api: Optional[TelegramAPIServer] = None,
json_loads: Optional[Callable] = None, json_loads: Optional[Callable[[Any], Any]] = None,
json_dumps: Optional[Callable] = None, json_dumps: Optional[Callable[[Any], Any]] = None,
): ) -> None:
if api is None: if api is None:
api = PRODUCTION api = PRODUCTION
if json_loads is None: if json_loads is None:
@ -27,7 +29,7 @@ class BaseSession(abc.ABC):
self.json_loads = json_loads self.json_loads = json_loads
self.json_dumps = json_dumps self.json_dumps = json_dumps
def raise_for_status(self, response: Response[T]): def raise_for_status(self, response: Response[T]) -> None:
if response.ok: if response.ok:
return return
raise Exception(response.description) raise Exception(response.description)
@ -37,7 +39,7 @@ class BaseSession(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
pass pass
def prepare_value(self, value: Any) -> Union[str, int, bool]: def prepare_value(self, value: Any) -> Union[str, int, bool]:
@ -53,9 +55,15 @@ class BaseSession(abc.ABC):
else: else:
return str(value) return str(value)
def clean_json(self, value: Any): def clean_json(self, value: Any) -> Any:
if isinstance(value, list): if isinstance(value, list):
return [self.clean_json(v) for v in value if v is not None] return [self.clean_json(v) for v in value if v is not None]
elif isinstance(value, dict): elif isinstance(value, dict):
return {k: self.clean_json(v) for k, v in value.items() if v is not None} return {k: self.clean_json(v) for k, v in value.items() if v is not None}
return value return value
async def __aenter__(self) -> BaseSession:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()

42
aiogram/utils/token.py Normal file
View file

@ -0,0 +1,42 @@
from functools import lru_cache
class TokenValidationError(Exception):
pass
@lru_cache()
def validate_token(token: str) -> bool:
"""
Validate Telegram token
:param token:
:return:
"""
if not isinstance(token, str):
raise TokenValidationError(
f"Token is invalid! It must be 'str' type instead of {type(token)} type."
)
if any(x.isspace() for x in token):
message = "Token is invalid! It can't contains spaces."
raise TokenValidationError(message)
left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (not right):
raise TokenValidationError("Token is invalid!")
return True
@lru_cache()
def extract_bot_id(token: str) -> int:
"""
Extract bot ID from Telegram token
:param token:
:return:
"""
validate_token(token)
raw_bot_id, *_ = token.split(":")
return int(raw_bot_id)

View file

@ -35,7 +35,7 @@ class MockedBot(Bot):
session: MockedSession session: MockedSession
def __init__(self): def __init__(self):
super(MockedBot, self).__init__("TOKEN", session=MockedSession()) super(MockedBot, self).__init__("42:TEST", session=MockedSession())
def add_result_for( def add_result_for(
self, self,

View file

@ -3,9 +3,9 @@ from aiogram.api.client.telegram import PRODUCTION
class TestAPIServer: class TestAPIServer:
def test_method_url(self): def test_method_url(self):
method_url = PRODUCTION.api_url(token="TOKEN", method="apiMethod") method_url = PRODUCTION.api_url(token="42:TEST", method="apiMethod")
assert method_url == "https://api.telegram.org/botTOKEN/apiMethod" assert method_url == "https://api.telegram.org/bot42:TEST/apiMethod"
def test_file_url(self): def test_file_url(self):
file_url = PRODUCTION.file_url(token="TOKEN", path="path") file_url = PRODUCTION.file_url(token="42:TEST", path="path")
assert file_url == "https://api.telegram.org/file/botTOKEN/path" assert file_url == "https://api.telegram.org/file/bot42:TEST/path"

View file

@ -1,3 +1,5 @@
import copy
import pytest import pytest
from asynctest import CoroutineMock, patch from asynctest import CoroutineMock, patch
@ -9,12 +11,21 @@ from aiogram.api.methods import GetMe
class TestBaseBot: class TestBaseBot:
def test_init(self): def test_init(self):
base_bot = BaseBot("TOKEN") base_bot = BaseBot("42:TEST")
assert isinstance(base_bot.session, AiohttpSession) assert isinstance(base_bot.session, AiohttpSession)
assert base_bot.id == 42
def test_hashable(self):
base_bot = BaseBot("42:TEST")
assert hash(base_bot) == hash("42:TEST")
def test_equals(self):
base_bot = BaseBot("42:TEST")
assert base_bot == BaseBot("42:TEST")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_emit(self): async def test_emit(self):
base_bot = BaseBot("TOKEN") base_bot = BaseBot("42:TEST")
method = GetMe() method = GetMe()
@ -23,11 +34,11 @@ class TestBaseBot:
new_callable=CoroutineMock, new_callable=CoroutineMock,
) as mocked_make_request: ) as mocked_make_request:
await base_bot.emit(method) await base_bot.emit(method)
mocked_make_request.assert_awaited_with("TOKEN", method) mocked_make_request.assert_awaited_with("42:TEST", method)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_close(self): async def test_close(self):
base_bot = BaseBot("TOKEN", session=AiohttpSession()) base_bot = BaseBot("42:TEST", session=AiohttpSession())
await base_bot.session.create_session() await base_bot.session.create_session()
with patch( with patch(
@ -41,6 +52,6 @@ class TestBaseBot:
with patch( with patch(
"aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock "aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock
) as mocked_close: ) as mocked_close:
async with BaseBot("TOKEN", session=AiohttpSession()) as bot: async with BaseBot("42:TEST", session=AiohttpSession()) as bot:
assert isinstance(bot, BaseBot) assert isinstance(bot, BaseBot)
mocked_close.assert_awaited() mocked_close.assert_awaited()

View file

@ -1,3 +1,6 @@
import copy
from typing import AsyncContextManager
import aiohttp import aiohttp
import pytest import pytest
from aresponses import ResponsesMockServer from aresponses import ResponsesMockServer
@ -74,7 +77,7 @@ class TestAiohttpSession:
async def test_make_request(self, aresponses: ResponsesMockServer): async def test_make_request(self, aresponses: ResponsesMockServer):
aresponses.add( aresponses.add(
aresponses.ANY, aresponses.ANY,
"/botTOKEN/method", "/bot42:TEST/method",
"post", "post",
aresponses.Response( aresponses.Response(
status=200, status=200,
@ -95,8 +98,32 @@ class TestAiohttpSession:
with patch( with patch(
"aiogram.api.client.session.base.BaseSession.raise_for_status" "aiogram.api.client.session.base.BaseSession.raise_for_status"
) as patched_raise_for_status: ) as patched_raise_for_status:
result = await session.make_request("TOKEN", call) result = await session.make_request("42:TEST", call)
assert isinstance(result, int) assert isinstance(result, int)
assert result == 42 assert result == 42
assert patched_raise_for_status.called_once() assert patched_raise_for_status.called_once()
@pytest.mark.asyncio
async def test_context_manager(self):
session = AiohttpSession()
assert isinstance(session, AsyncContextManager)
with patch(
"aiogram.api.client.session.aiohttp.AiohttpSession.create_session",
new_callable=CoroutineMock,
) as mocked_create_session, patch(
"aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock
) as mocked_close:
async with session as ctx:
assert session == ctx
mocked_close.awaited_once()
mocked_create_session.awaited_once()
@pytest.mark.asyncio
async def test_deepcopy(self):
# Session should be copied without aiohttp.ClientSession
async with AiohttpSession() as session:
cloned_session = copy.deepcopy(session)
assert cloned_session != session
assert cloned_session._session is None

View file

@ -1,23 +1,27 @@
import datetime import datetime
from typing import AsyncContextManager
import pytest import pytest
from asynctest import CoroutineMock, patch
from aiogram.api.client.session.base import BaseSession from aiogram.api.client.session.base import BaseSession, T
from aiogram.api.client.telegram import PRODUCTION, TelegramAPIServer from aiogram.api.client.telegram import PRODUCTION, TelegramAPIServer
from aiogram.api.methods import GetMe, Response from aiogram.api.methods import GetMe, Response, TelegramMethod
from aiogram.utils.mixins import DataMixin from aiogram.utils.mixins import DataMixin
class CustomSession(BaseSession):
async def close(self):
pass
async def make_request(self, token: str, method: TelegramMethod[T]) -> None: # type: ignore
assert isinstance(token, str)
assert isinstance(method, TelegramMethod)
class TestBaseSession(DataMixin): class TestBaseSession(DataMixin):
def setup(self):
self["__abstractmethods__"] = BaseSession.__abstractmethods__
BaseSession.__abstractmethods__ = set()
def teardown(self):
BaseSession.__abstractmethods__ = self["__abstractmethods__"]
def test_init_api(self): def test_init_api(self):
session = BaseSession() session = CustomSession()
assert session.api == PRODUCTION assert session.api == PRODUCTION
def test_init_custom_api(self): def test_init_custom_api(self):
@ -25,11 +29,11 @@ class TestBaseSession(DataMixin):
base="http://example.com/{token}/{method}", base="http://example.com/{token}/{method}",
file="http://example.com/{token}/file/{path{", file="http://example.com/{token}/file/{path{",
) )
session = BaseSession(api=api) session = CustomSession(api=api)
assert session.api == api assert session.api == api
def test_prepare_value(self): def test_prepare_value(self):
session = BaseSession() session = CustomSession()
now = datetime.datetime.now() now = datetime.datetime.now()
@ -41,7 +45,7 @@ class TestBaseSession(DataMixin):
assert session.prepare_value(42) == "42" assert session.prepare_value(42) == "42"
def test_clean_json(self): def test_clean_json(self):
session = BaseSession() session = CustomSession()
cleaned_dict = session.clean_json({"key": "value", "null": None}) cleaned_dict = session.clean_json({"key": "value", "null": None})
assert "key" in cleaned_dict assert "key" in cleaned_dict
@ -54,7 +58,7 @@ class TestBaseSession(DataMixin):
assert cleaned_list[0] == "kaboom" assert cleaned_list[0] == "kaboom"
def test_clean_json_with_nested_json(self): def test_clean_json_with_nested_json(self):
session = BaseSession() session = CustomSession()
cleaned = session.clean_json( cleaned = session.clean_json(
{ {
@ -75,12 +79,12 @@ class TestBaseSession(DataMixin):
assert cleaned["nested_dict"] == {"key": "value"} assert cleaned["nested_dict"] == {"key": "value"}
def test_clean_json_not_json(self): def test_clean_json_not_json(self):
session = BaseSession() session = CustomSession()
assert session.clean_json(42) == 42 assert session.clean_json(42) == 42
def test_raise_for_status(self): def test_raise_for_status(self):
session = BaseSession() session = CustomSession()
session.raise_for_status(Response[bool](ok=True, result=True)) session.raise_for_status(Response[bool](ok=True, result=True))
with pytest.raises(Exception): with pytest.raises(Exception):
@ -88,6 +92,19 @@ class TestBaseSession(DataMixin):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_make_request(self): async def test_make_request(self):
session = BaseSession() session = CustomSession()
assert await session.make_request("TOKEN", GetMe()) is None assert await session.make_request("42:TEST", GetMe()) is None
@pytest.mark.asyncio
async def test_context_manager(self):
session = CustomSession()
assert isinstance(session, AsyncContextManager)
with patch(
"tests.test_api.test_client.test_session.test_base_session.CustomSession.close",
new_callable=CoroutineMock,
) as mocked_close:
async with session as ctx:
assert session == ctx
mocked_close.awaited_once()

View file

@ -20,7 +20,7 @@ class TestDispatcher:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_feed_update(self): async def test_feed_update(self):
dp = Dispatcher() dp = Dispatcher()
bot = Bot("TOKEN") bot = Bot("42:TEST")
@dp.message_handler() @dp.message_handler()
async def my_handler(message: Message, **kwargs): async def my_handler(message: Message, **kwargs):

View file

@ -0,0 +1,41 @@
from unittest.mock import patch
import pytest
from aiogram.utils.token import TokenValidationError, validate_token, extract_bot_id
BOT_ID = 123456789
VALID_TOKEN = '123456789:AABBCCDDEEFFaabbccddeeff-1234567890'
INVALID_TOKENS = [
'123456789:AABBCCDDEEFFaabbccddeeff 123456789', # space is exists
'ABC:AABBCCDDEEFFaabbccddeeff123456789', # left part is not digit
':AABBCCDDEEFFaabbccddeeff123456789', # there is no left part
'123456789:', # there is no right part
'ABC AABBCCDDEEFFaabbccddeeff123456789', # there is no ':' separator
None, # is None
12345678, # is digit
(42, 'TEST'), # is tuple
]
@pytest.fixture(params=INVALID_TOKENS, name='invalid_token')
def invalid_token_fixture(request):
return request.param
class TestCheckToken:
def test_valid(self):
assert validate_token(VALID_TOKEN) is True
def test_invalid_token(self, invalid_token):
with pytest.raises(TokenValidationError):
validate_token(invalid_token)
class TestExtractBotId:
def test_extract_bot_id(self):
with patch("aiogram.utils.token.validate_token") as mocked_validate_token:
result = extract_bot_id(VALID_TOKEN)
mocked_validate_token.assert_called_once_with(VALID_TOKEN)
assert result == BOT_ID