diff --git a/aiogram/api/client/session/aiohttp.py b/aiogram/api/client/session/aiohttp.py index c3c1bba5..774e7186 100644 --- a/aiogram/api/client/session/aiohttp.py +++ b/aiogram/api/client/session/aiohttp.py @@ -3,7 +3,6 @@ from __future__ import annotations from typing import ( Any, AsyncGenerator, - Callable, Dict, Iterable, List, @@ -15,11 +14,11 @@ from typing import ( cast, ) -from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector +from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector from aiogram.api.methods import Request, TelegramMethod -from .base import PRODUCTION, BaseSession, TelegramAPIServer +from .base import BaseSession T = TypeVar("T") _ProxyBasic = Union[str, Tuple[str, BasicAuth]] @@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"] class AiohttpSession(BaseSession): - def __init__( - self, - api: TelegramAPIServer = PRODUCTION, - json_loads: Optional[Callable[..., Any]] = None, - json_dumps: Optional[Callable[..., str]] = None, - proxy: Optional[_ProxyType] = None, - ): - super(AiohttpSession, self).__init__( - api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy - ) + def __init__(self, proxy: Optional[_ProxyType] = None): self._session: Optional[ClientSession] = None self._connector_type: Type[TCPConnector] = TCPConnector self._connector_init: Dict[str, Any] = {} + self._should_reset_connector = True # flag determines connector state + self._proxy: Optional[_ProxyType] = None - if self.proxy: + if proxy is not None: try: - self._connector_type, self._connector_init = _prepare_connector( - cast(_ProxyType, self.proxy) - ) + self._setup_proxy_connector(proxy) except ImportError as exc: # pragma: no cover raise UserWarning( "In order to use aiohttp client for proxy requests, install " "https://pypi.org/project/aiohttp-socks/" ) from exc + def _setup_proxy_connector(self, proxy: _ProxyType) -> None: + self._connector_type, self._connector_init = _prepare_connector(proxy) + self._proxy = proxy + + @property + def proxy(self) -> Optional[_ProxyType]: + return self._proxy + + @proxy.setter + def proxy(self, proxy: _ProxyType) -> None: + self._setup_proxy_connector(proxy) + self._should_reset_connector = True + async def create_session(self) -> ClientSession: + if self._should_reset_connector: + await self.close() + if self._session is None or self._session.closed: self._session = ClientSession(connector=self._connector_type(**self._connector_init)) + self._should_reset_connector = False return self._session @@ -125,7 +132,9 @@ class AiohttpSession(BaseSession): url = self.api.api_url(token=token, method=request.method) form = self.build_form_data(request) - async with session.post(url, data=form) as resp: + async with session.post( + url, data=form, timeout=call.request_timeout or self.timeout + ) as resp: raw_result = await resp.json(loads=self.json_loads) response = call.build_response(raw_result) @@ -136,9 +145,8 @@ class AiohttpSession(BaseSession): self, url: str, timeout: int, chunk_size: int ) -> AsyncGenerator[bytes, None]: session = await self.create_session() - client_timeout = ClientTimeout(total=timeout) - async with session.get(url, timeout=client_timeout) as resp: + async with session.get(url, timeout=timeout) as resp: async for chunk in resp.content.iter_chunked(chunk_size): yield chunk diff --git a/aiogram/api/client/session/base.py b/aiogram/api/client/session/base.py index 83e7f3ff..8213e4c3 100644 --- a/aiogram/api/client/session/base.py +++ b/aiogram/api/client/session/base.py @@ -4,7 +4,7 @@ import abc import datetime import json from types import TracebackType -from typing import Any, AsyncGenerator, Callable, Optional, Type, TypeVar, Union +from typing import Any, AsyncGenerator, Callable, ClassVar, Optional, Type, TypeVar, Union from aiogram.utils.exceptions import TelegramAPIError @@ -12,30 +12,58 @@ from ...methods import Response, TelegramMethod from ..telegram import PRODUCTION, TelegramAPIServer T = TypeVar("T") -PT = TypeVar("PT") +_JsonLoads = Callable[..., Any] +_JsonDumps = Callable[..., str] class BaseSession(abc.ABC): - def __init__( - self, - api: Optional[TelegramAPIServer] = None, - json_loads: Optional[Callable[..., Any]] = None, - json_dumps: Optional[Callable[..., str]] = None, - proxy: Optional[PT] = None, - ) -> None: - if api is None: - api = PRODUCTION - if json_loads is None: - json_loads = json.loads - if json_dumps is None: - json_dumps = json.dumps + # global session timeout + default_timeout: ClassVar[float] = 60.0 - self.api = api - self.json_loads = json_loads - self.json_dumps = json_dumps - self.proxy = proxy + _api: TelegramAPIServer + _json_loads: _JsonLoads + _json_dumps: _JsonDumps + _timeout: float - def raise_for_status(self, response: Response[T]) -> None: + @property + def api(self) -> TelegramAPIServer: + return getattr(self, "_api", PRODUCTION) # type: ignore + + @api.setter + def api(self, value: TelegramAPIServer) -> None: + self._api = value + + @property + def json_loads(self) -> _JsonLoads: + return getattr(self, "_json_loads", json.loads) # type: ignore + + @json_loads.setter + def json_loads(self, value: _JsonLoads) -> None: + self._json_loads = value # type: ignore + + @property + def json_dumps(self) -> _JsonDumps: + return getattr(self, "_json_dumps", json.dumps) # type: ignore + + @json_dumps.setter + def json_dumps(self, value: _JsonDumps) -> None: + self._json_dumps = value # type: ignore + + @property + def timeout(self) -> float: + return getattr(self, "_timeout", self.__class__.default_timeout) # type: ignore + + @timeout.setter + def timeout(self, value: float) -> None: + self._timeout = value + + @timeout.deleter + def timeout(self) -> None: + if hasattr(self, "_timeout"): + del self._timeout + + @classmethod + def raise_for_status(cls, response: Response[T]) -> None: if response.ok: return raise TelegramAPIError(response.description) diff --git a/aiogram/api/methods/base.py b/aiogram/api/methods/base.py index 52402977..8d15edba 100644 --- a/aiogram/api/methods/base.py +++ b/aiogram/api/methods/base.py @@ -55,6 +55,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]): def build_request(self) -> Request: # pragma: no cover pass + request_timeout: Optional[float] = None + + def dict(self, **kwargs: Any) -> Any: + # override dict of pydantic.BaseModel to overcome exporting request_timeout field + exclude = kwargs.pop("exclude", set()) + if isinstance(exclude, set): + exclude.add("request_timeout") + + return super().dict(exclude=exclude, **kwargs) + def build_response(self, data: Dict[str, Any]) -> Response[T]: # noinspection PyTypeChecker return Response[self.__returning__](**data) # type: ignore diff --git a/docs/api/client/session/aiohttp.md b/docs/api/client/session/aiohttp.md index 9eab5ede..223ad468 100644 --- a/docs/api/client/session/aiohttp.md +++ b/docs/api/client/session/aiohttp.md @@ -1,6 +1,6 @@ # Aiohttp session -AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp]('https://pypi.org/project/aiohttp/') +AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp](https://pypi.org/project/aiohttp/ "PyPi repository"){target=_blank} Currently `AiohttpSession` is a default session used in `aiogram.Bot` @@ -17,7 +17,7 @@ Bot('token', session=session) ## Proxy requests in AiohttpSession -In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks]('https://pypi.org/project/aiohttp-socks/') +In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks](https://pypi.org/project/aiohttp-socks/ "PyPi repository"){target=_blank} Binding session to bot: ```python diff --git a/mkdocs.yml b/mkdocs.yml index 0d77e640..527d0561 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -35,6 +35,7 @@ markdown_extensions: - pymdownx.inlinehilite - markdown_include.include: base_path: docs + - attr_list nav: - index.md diff --git a/tests/test_api/test_client/test_session/test_aiohttp_session.py b/tests/test_api/test_client/test_session/test_aiohttp_session.py index 2587a686..e7716f9a 100644 --- a/tests/test_api/test_client/test_session/test_aiohttp_session.py +++ b/tests/test_api/test_client/test_session/test_aiohttp_session.py @@ -83,6 +83,22 @@ class TestAiohttpSession: aiohttp_session = await session.create_session() assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector) + @pytest.mark.asyncio + async def test_reset_connector(self): + session = AiohttpSession() + assert session._should_reset_connector + await session.create_session() + assert session._should_reset_connector is False + await session.close() + assert session._should_reset_connector is False + + assert session.proxy is None + session.proxy = "socks5://auth:auth@proxy.url/" + assert session._should_reset_connector + await session.create_session() + assert session._should_reset_connector is False + await session.close() + @pytest.mark.asyncio async def test_close_session(self): session = AiohttpSession() diff --git a/tests/test_api/test_client/test_session/test_base_session.py b/tests/test_api/test_client/test_session/test_base_session.py index 4c86f9da..35dcfa8e 100644 --- a/tests/test_api/test_client/test_session/test_base_session.py +++ b/tests/test_api/test_client/test_session/test_base_session.py @@ -1,4 +1,5 @@ import datetime +import json from typing import AsyncContextManager, AsyncGenerator import pytest @@ -35,13 +36,56 @@ class TestBaseSession: session = CustomSession() assert session.api == PRODUCTION + def test_default_props(self): + session = CustomSession() + assert session.api == PRODUCTION + assert session.json_loads == json.loads + assert session.json_dumps == json.dumps + + def custom_loads(*_): + return json.loads + + def custom_dumps(*_): + return json.dumps + + session.json_dumps = custom_dumps + assert session.json_dumps == custom_dumps == session._json_dumps + session.json_loads = custom_loads + assert session.json_loads == custom_loads == session._json_loads + + different_session = CustomSession() + assert all( + not hasattr(different_session, attr) for attr in ("_json_loads", "_json_dumps", "_api") + ) + + def test_timeout(self): + session = CustomSession() + assert session.timeout == session.default_timeout == CustomSession.default_timeout + + session.default_timeout = float(65.0_0) # mypy will complain + assert session.timeout != session.default_timeout + + CustomSession.default_timeout = float(68.0_0) + assert session.timeout == CustomSession.default_timeout + + session.timeout = float(71.0_0) + assert session.timeout != session.default_timeout + del session.timeout + CustomSession.default_timeout = session.default_timeout + 100 + assert ( + session.timeout != BaseSession.default_timeout + and session.timeout == CustomSession.default_timeout + ) + def test_init_custom_api(self): api = TelegramAPIServer( base="http://example.com/{token}/{method}", - file="http://example.com/{token}/file/{path{", + file="http://example.com/{token}/file/{path}", ) - session = CustomSession(api=api) + session = CustomSession() + session.api = api assert session.api == api + assert "example.com" in session.api.base def test_prepare_value(self): session = CustomSession()