From c674b5547b744c9a595faea66be315975d89ae08 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Thu, 28 Nov 2019 23:12:44 +0200 Subject: [PATCH] Add token validation util, fix deepcopy of sessions and make bot hashable and comparable --- aiogram/api/client/base.py | 25 +++++++-- aiogram/api/client/session/aiohttp.py | 19 ++++++- aiogram/api/client/session/base.py | 20 ++++--- aiogram/utils/token.py | 42 +++++++++++++++ tests/mocked_bot.py | 2 +- tests/test_api/test_client/test_api_server.py | 8 +-- tests/test_api/test_client/test_base_bot.py | 21 ++++++-- .../test_session/test_aiohttp_session.py | 31 ++++++++++- .../test_session/test_base_session.py | 53 ++++++++++++------- tests/test_dispatcher/test_dispatcher.py | 2 +- tests/test_utils/test_token.py | 41 ++++++++++++++ 11 files changed, 223 insertions(+), 41 deletions(-) create mode 100644 aiogram/utils/token.py create mode 100644 tests/test_utils/test_token.py diff --git a/aiogram/api/client/base.py b/aiogram/api/client/base.py index 5ded4c45..0e2516a0 100644 --- a/aiogram/api/client/base.py +++ b/aiogram/api/client/base.py @@ -1,6 +1,10 @@ -from typing import TypeVar +from __future__ import annotations + +import copy +from typing import TypeVar, Dict, Any from ...utils.mixins import ContextInstanceMixin +from ...utils.token import extract_bot_id, validate_token from ..methods import TelegramMethod from .session.aiohttp import AiohttpSession from .session.base import BaseSession @@ -10,13 +14,20 @@ T = TypeVar("T") class BaseBot(ContextInstanceMixin): def __init__(self, token: str, session: BaseSession = None): + validate_token(token) + if session is None: session = AiohttpSession() + self.session = session - self.token = token + self.__token = token + + @property + def id(self): + return extract_bot_id(self.__token) async def emit(self, method: TelegramMethod[T]) -> T: - return await self.session.make_request(self.token, method) + return await self.session.make_request(self.__token, method) async def close(self): await self.session.close() @@ -26,3 +37,11 @@ class BaseBot(ContextInstanceMixin): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.session.close() + + def __hash__(self): + return hash(self.__token) + + def __eq__(self, other: BaseBot): + if not isinstance(other, BaseBot): + return False + return hash(self) == hash(other) diff --git a/aiogram/api/client/session/aiohttp.py b/aiogram/api/client/session/aiohttp.py index d9f65b2b..442a8bd8 100644 --- a/aiogram/api/client/session/aiohttp.py +++ b/aiogram/api/client/session/aiohttp.py @@ -1,4 +1,7 @@ -from typing import Callable, Optional, TypeVar, cast +from __future__ import annotations + +import copy +from typing import Callable, Optional, TypeVar, cast, Dict, Any from aiohttp import ClientSession, FormData @@ -53,3 +56,17 @@ class AiohttpSession(BaseSession): response = call.build_response(raw_result) self.raise_for_status(response) return cast(T, response.result) + + async def __aenter__(self) -> AiohttpSession: + await self.create_session() + return self + + def __deepcopy__(self, memodict: Dict[str, Any]): + cls = self.__class__ + result = cls.__new__(cls) + memodict[id(self)] = result + for key, value in self.__dict__.items(): + # aiohttp ClientSession cannot be copied. + copied_value = copy.deepcopy(value, memo=memodict) if key != '_session' else None + setattr(result, key, copied_value) + return result diff --git a/aiogram/api/client/session/base.py b/aiogram/api/client/session/base.py index 3999d9cd..4f960b23 100644 --- a/aiogram/api/client/session/base.py +++ b/aiogram/api/client/session/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import datetime import json @@ -13,9 +15,9 @@ class BaseSession(abc.ABC): def __init__( self, api: Optional[TelegramAPIServer] = None, - json_loads: Optional[Callable] = None, - json_dumps: Optional[Callable] = None, - ): + json_loads: Optional[Callable[[Any], Any]] = None, + json_dumps: Optional[Callable[[Any], Any]] = None, + ) -> None: if api is None: api = PRODUCTION if json_loads is None: @@ -27,7 +29,7 @@ class BaseSession(abc.ABC): self.json_loads = json_loads self.json_dumps = json_dumps - def raise_for_status(self, response: Response[T]): + def raise_for_status(self, response: Response[T]) -> None: if response.ok: return raise Exception(response.description) @@ -37,7 +39,7 @@ class BaseSession(abc.ABC): pass @abc.abstractmethod - async def make_request(self, token: str, method: TelegramMethod[T]) -> T: + async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover pass def prepare_value(self, value: Any) -> Union[str, int, bool]: @@ -53,9 +55,15 @@ class BaseSession(abc.ABC): else: return str(value) - def clean_json(self, value: Any): + def clean_json(self, value: Any) -> Any: if isinstance(value, list): return [self.clean_json(v) for v in value if v is not None] elif isinstance(value, dict): return {k: self.clean_json(v) for k, v in value.items() if v is not None} return value + + async def __aenter__(self) -> BaseSession: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() diff --git a/aiogram/utils/token.py b/aiogram/utils/token.py new file mode 100644 index 00000000..c0738467 --- /dev/null +++ b/aiogram/utils/token.py @@ -0,0 +1,42 @@ +from functools import lru_cache + + +class TokenValidationError(Exception): + pass + + +@lru_cache() +def validate_token(token: str) -> bool: + """ + Validate Telegram token + + :param token: + :return: + """ + if not isinstance(token, str): + raise TokenValidationError( + f"Token is invalid! It must be 'str' type instead of {type(token)} type." + ) + + if any(x.isspace() for x in token): + message = "Token is invalid! It can't contains spaces." + raise TokenValidationError(message) + + left, sep, right = token.partition(":") + if (not sep) or (not left.isdigit()) or (not right): + raise TokenValidationError("Token is invalid!") + + return True + + +@lru_cache() +def extract_bot_id(token: str) -> int: + """ + Extract bot ID from Telegram token + + :param token: + :return: + """ + validate_token(token) + raw_bot_id, *_ = token.split(":") + return int(raw_bot_id) diff --git a/tests/mocked_bot.py b/tests/mocked_bot.py index 77da88bc..f7da0971 100644 --- a/tests/mocked_bot.py +++ b/tests/mocked_bot.py @@ -35,7 +35,7 @@ class MockedBot(Bot): session: MockedSession def __init__(self): - super(MockedBot, self).__init__("TOKEN", session=MockedSession()) + super(MockedBot, self).__init__("42:TEST", session=MockedSession()) def add_result_for( self, diff --git a/tests/test_api/test_client/test_api_server.py b/tests/test_api/test_client/test_api_server.py index 3102568f..74b6a785 100644 --- a/tests/test_api/test_client/test_api_server.py +++ b/tests/test_api/test_client/test_api_server.py @@ -3,9 +3,9 @@ from aiogram.api.client.telegram import PRODUCTION class TestAPIServer: def test_method_url(self): - method_url = PRODUCTION.api_url(token="TOKEN", method="apiMethod") - assert method_url == "https://api.telegram.org/botTOKEN/apiMethod" + method_url = PRODUCTION.api_url(token="42:TEST", method="apiMethod") + assert method_url == "https://api.telegram.org/bot42:TEST/apiMethod" def test_file_url(self): - file_url = PRODUCTION.file_url(token="TOKEN", path="path") - assert file_url == "https://api.telegram.org/file/botTOKEN/path" + file_url = PRODUCTION.file_url(token="42:TEST", path="path") + assert file_url == "https://api.telegram.org/file/bot42:TEST/path" diff --git a/tests/test_api/test_client/test_base_bot.py b/tests/test_api/test_client/test_base_bot.py index 04ea8540..7e8f2a85 100644 --- a/tests/test_api/test_client/test_base_bot.py +++ b/tests/test_api/test_client/test_base_bot.py @@ -1,3 +1,5 @@ +import copy + import pytest from asynctest import CoroutineMock, patch @@ -9,12 +11,21 @@ from aiogram.api.methods import GetMe class TestBaseBot: def test_init(self): - base_bot = BaseBot("TOKEN") + base_bot = BaseBot("42:TEST") assert isinstance(base_bot.session, AiohttpSession) + assert base_bot.id == 42 + + def test_hashable(self): + base_bot = BaseBot("42:TEST") + assert hash(base_bot) == hash("42:TEST") + + def test_equals(self): + base_bot = BaseBot("42:TEST") + assert base_bot == BaseBot("42:TEST") @pytest.mark.asyncio async def test_emit(self): - base_bot = BaseBot("TOKEN") + base_bot = BaseBot("42:TEST") method = GetMe() @@ -23,11 +34,11 @@ class TestBaseBot: new_callable=CoroutineMock, ) as mocked_make_request: await base_bot.emit(method) - mocked_make_request.assert_awaited_with("TOKEN", method) + mocked_make_request.assert_awaited_with("42:TEST", method) @pytest.mark.asyncio async def test_close(self): - base_bot = BaseBot("TOKEN", session=AiohttpSession()) + base_bot = BaseBot("42:TEST", session=AiohttpSession()) await base_bot.session.create_session() with patch( @@ -41,6 +52,6 @@ class TestBaseBot: with patch( "aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock ) as mocked_close: - async with BaseBot("TOKEN", session=AiohttpSession()) as bot: + async with BaseBot("42:TEST", session=AiohttpSession()) as bot: assert isinstance(bot, BaseBot) mocked_close.assert_awaited() diff --git a/tests/test_api/test_client/test_session/test_aiohttp_session.py b/tests/test_api/test_client/test_session/test_aiohttp_session.py index 35284e62..85d9e67c 100644 --- a/tests/test_api/test_client/test_session/test_aiohttp_session.py +++ b/tests/test_api/test_client/test_session/test_aiohttp_session.py @@ -1,3 +1,6 @@ +import copy +from typing import AsyncContextManager + import aiohttp import pytest from aresponses import ResponsesMockServer @@ -74,7 +77,7 @@ class TestAiohttpSession: async def test_make_request(self, aresponses: ResponsesMockServer): aresponses.add( aresponses.ANY, - "/botTOKEN/method", + "/bot42:TEST/method", "post", aresponses.Response( status=200, @@ -95,8 +98,32 @@ class TestAiohttpSession: with patch( "aiogram.api.client.session.base.BaseSession.raise_for_status" ) as patched_raise_for_status: - result = await session.make_request("TOKEN", call) + result = await session.make_request("42:TEST", call) assert isinstance(result, int) assert result == 42 assert patched_raise_for_status.called_once() + + @pytest.mark.asyncio + async def test_context_manager(self): + session = AiohttpSession() + assert isinstance(session, AsyncContextManager) + + with patch( + "aiogram.api.client.session.aiohttp.AiohttpSession.create_session", + new_callable=CoroutineMock, + ) as mocked_create_session, patch( + "aiogram.api.client.session.aiohttp.AiohttpSession.close", new_callable=CoroutineMock + ) as mocked_close: + async with session as ctx: + assert session == ctx + mocked_close.awaited_once() + mocked_create_session.awaited_once() + + @pytest.mark.asyncio + async def test_deepcopy(self): + # Session should be copied without aiohttp.ClientSession + async with AiohttpSession() as session: + cloned_session = copy.deepcopy(session) + assert cloned_session != session + assert cloned_session._session is None diff --git a/tests/test_api/test_client/test_session/test_base_session.py b/tests/test_api/test_client/test_session/test_base_session.py index 94bb92f4..91cb5508 100644 --- a/tests/test_api/test_client/test_session/test_base_session.py +++ b/tests/test_api/test_client/test_session/test_base_session.py @@ -1,23 +1,27 @@ import datetime +from typing import AsyncContextManager import pytest +from asynctest import CoroutineMock, patch -from aiogram.api.client.session.base import BaseSession +from aiogram.api.client.session.base import BaseSession, T from aiogram.api.client.telegram import PRODUCTION, TelegramAPIServer -from aiogram.api.methods import GetMe, Response +from aiogram.api.methods import GetMe, Response, TelegramMethod from aiogram.utils.mixins import DataMixin +class CustomSession(BaseSession): + async def close(self): + pass + + async def make_request(self, token: str, method: TelegramMethod[T]) -> None: # type: ignore + assert isinstance(token, str) + assert isinstance(method, TelegramMethod) + + class TestBaseSession(DataMixin): - def setup(self): - self["__abstractmethods__"] = BaseSession.__abstractmethods__ - BaseSession.__abstractmethods__ = set() - - def teardown(self): - BaseSession.__abstractmethods__ = self["__abstractmethods__"] - def test_init_api(self): - session = BaseSession() + session = CustomSession() assert session.api == PRODUCTION def test_init_custom_api(self): @@ -25,11 +29,11 @@ class TestBaseSession(DataMixin): base="http://example.com/{token}/{method}", file="http://example.com/{token}/file/{path{", ) - session = BaseSession(api=api) + session = CustomSession(api=api) assert session.api == api def test_prepare_value(self): - session = BaseSession() + session = CustomSession() now = datetime.datetime.now() @@ -41,7 +45,7 @@ class TestBaseSession(DataMixin): assert session.prepare_value(42) == "42" def test_clean_json(self): - session = BaseSession() + session = CustomSession() cleaned_dict = session.clean_json({"key": "value", "null": None}) assert "key" in cleaned_dict @@ -54,7 +58,7 @@ class TestBaseSession(DataMixin): assert cleaned_list[0] == "kaboom" def test_clean_json_with_nested_json(self): - session = BaseSession() + session = CustomSession() cleaned = session.clean_json( { @@ -75,12 +79,12 @@ class TestBaseSession(DataMixin): assert cleaned["nested_dict"] == {"key": "value"} def test_clean_json_not_json(self): - session = BaseSession() + session = CustomSession() assert session.clean_json(42) == 42 def test_raise_for_status(self): - session = BaseSession() + session = CustomSession() session.raise_for_status(Response[bool](ok=True, result=True)) with pytest.raises(Exception): @@ -88,6 +92,19 @@ class TestBaseSession(DataMixin): @pytest.mark.asyncio async def test_make_request(self): - session = BaseSession() + session = CustomSession() - assert await session.make_request("TOKEN", GetMe()) is None + assert await session.make_request("42:TEST", GetMe()) is None + + @pytest.mark.asyncio + async def test_context_manager(self): + session = CustomSession() + assert isinstance(session, AsyncContextManager) + + with patch( + "tests.test_api.test_client.test_session.test_base_session.CustomSession.close", + new_callable=CoroutineMock, + ) as mocked_close: + async with session as ctx: + assert session == ctx + mocked_close.awaited_once() diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index 8a9ecec9..bc590b69 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -20,7 +20,7 @@ class TestDispatcher: @pytest.mark.asyncio async def test_feed_update(self): dp = Dispatcher() - bot = Bot("TOKEN") + bot = Bot("42:TEST") @dp.message_handler() async def my_handler(message: Message, **kwargs): diff --git a/tests/test_utils/test_token.py b/tests/test_utils/test_token.py new file mode 100644 index 00000000..9b4f419e --- /dev/null +++ b/tests/test_utils/test_token.py @@ -0,0 +1,41 @@ +from unittest.mock import patch + +import pytest + +from aiogram.utils.token import TokenValidationError, validate_token, extract_bot_id + +BOT_ID = 123456789 +VALID_TOKEN = '123456789:AABBCCDDEEFFaabbccddeeff-1234567890' +INVALID_TOKENS = [ + '123456789:AABBCCDDEEFFaabbccddeeff 123456789', # space is exists + 'ABC:AABBCCDDEEFFaabbccddeeff123456789', # left part is not digit + ':AABBCCDDEEFFaabbccddeeff123456789', # there is no left part + '123456789:', # there is no right part + 'ABC AABBCCDDEEFFaabbccddeeff123456789', # there is no ':' separator + None, # is None + 12345678, # is digit + (42, 'TEST'), # is tuple +] + + +@pytest.fixture(params=INVALID_TOKENS, name='invalid_token') +def invalid_token_fixture(request): + return request.param + + +class TestCheckToken: + def test_valid(self): + assert validate_token(VALID_TOKEN) is True + + def test_invalid_token(self, invalid_token): + with pytest.raises(TokenValidationError): + validate_token(invalid_token) + + +class TestExtractBotId: + def test_extract_bot_id(self): + with patch("aiogram.utils.token.validate_token") as mocked_validate_token: + result = extract_bot_id(VALID_TOKEN) + + mocked_validate_token.assert_called_once_with(VALID_TOKEN) + assert result == BOT_ID