diff --git a/aiogram/api/client/session/aiohttp.py b/aiogram/api/client/session/aiohttp.py index c3c1bba5..be5861ce 100644 --- a/aiogram/api/client/session/aiohttp.py +++ b/aiogram/api/client/session/aiohttp.py @@ -3,7 +3,6 @@ from __future__ import annotations from typing import ( Any, AsyncGenerator, - Callable, Dict, Iterable, List, @@ -15,11 +14,11 @@ from typing import ( cast, ) -from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector +from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector from aiogram.api.methods import Request, TelegramMethod -from .base import PRODUCTION, BaseSession, TelegramAPIServer +from .base import BaseSession T = TypeVar("T") _ProxyBasic = Union[str, Tuple[str, BasicAuth]] @@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"] class AiohttpSession(BaseSession): - def __init__( - self, - api: TelegramAPIServer = PRODUCTION, - json_loads: Optional[Callable[..., Any]] = None, - json_dumps: Optional[Callable[..., str]] = None, - proxy: Optional[_ProxyType] = None, - ): - super(AiohttpSession, self).__init__( - api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy - ) + def __init__(self, proxy: Optional[_ProxyType] = None): self._session: Optional[ClientSession] = None self._connector_type: Type[TCPConnector] = TCPConnector self._connector_init: Dict[str, Any] = {} + self._should_reset_connector = True # flag determines connector state + self._proxy: Optional[_ProxyType] = None - if self.proxy: + if proxy is not None: try: - self._connector_type, self._connector_init = _prepare_connector( - cast(_ProxyType, self.proxy) - ) + self._setup_proxy_connector(proxy) except ImportError as exc: # pragma: no cover raise UserWarning( "In order to use aiohttp client for proxy requests, install " "https://pypi.org/project/aiohttp-socks/" ) from exc + def _setup_proxy_connector(self, proxy: _ProxyType) -> None: + self._connector_type, self._connector_init = _prepare_connector(proxy) + self._proxy = proxy + + @property + def proxy(self) -> Optional[_ProxyType]: + return self._proxy + + @proxy.setter + def proxy(self, proxy: _ProxyType) -> None: + self._setup_proxy_connector(proxy) + self._should_reset_connector = True + async def create_session(self) -> ClientSession: + if self._should_reset_connector: + await self.close() + if self._session is None or self._session.closed: self._session = ClientSession(connector=self._connector_type(**self._connector_init)) + self._should_reset_connector = False return self._session @@ -125,7 +132,7 @@ class AiohttpSession(BaseSession): url = self.api.api_url(token=token, method=request.method) form = self.build_form_data(request) - async with session.post(url, data=form) as resp: + async with session.post(url, data=form, timeout=call.request_timeout) as resp: raw_result = await resp.json(loads=self.json_loads) response = call.build_response(raw_result) @@ -136,9 +143,8 @@ class AiohttpSession(BaseSession): self, url: str, timeout: int, chunk_size: int ) -> AsyncGenerator[bytes, None]: session = await self.create_session() - client_timeout = ClientTimeout(total=timeout) - async with session.get(url, timeout=client_timeout) as resp: + async with session.get(url, timeout=timeout) as resp: async for chunk in resp.content.iter_chunked(chunk_size): yield chunk diff --git a/aiogram/api/client/session/base.py b/aiogram/api/client/session/base.py index 83e7f3ff..450b958d 100644 --- a/aiogram/api/client/session/base.py +++ b/aiogram/api/client/session/base.py @@ -16,26 +16,42 @@ PT = TypeVar("PT") class BaseSession(abc.ABC): - def __init__( - self, - api: Optional[TelegramAPIServer] = None, - json_loads: Optional[Callable[..., Any]] = None, - json_dumps: Optional[Callable[..., str]] = None, - proxy: Optional[PT] = None, - ) -> None: - if api is None: - api = PRODUCTION - if json_loads is None: - json_loads = json.loads - if json_dumps is None: - json_dumps = json.dumps + _api: TelegramAPIServer + _json_loads: Callable[..., Any] + _json_dumps: Callable[..., str] - self.api = api - self.json_loads = json_loads - self.json_dumps = json_dumps - self.proxy = proxy + @property + def api(self) -> TelegramAPIServer: # pragma: no cover + if not hasattr(self, "_api"): + return PRODUCTION + return self._api - def raise_for_status(self, response: Response[T]) -> None: + @api.setter + def api(self, value: TelegramAPIServer) -> None: # pragma: no cover + self._api = value + + @property + def json_loads(self) -> Callable[..., Any]: # pragma: no cover + if not hasattr(self, "_json_loads"): + return json.loads + return self._json_loads + + @json_loads.setter + def json_loads(self, value: Callable[..., Any]) -> None: # pragma: no cover + self._json_loads = value # type: ignore + + @property + def json_dumps(self) -> Callable[..., str]: # pragma: no cover + if not hasattr(self, "_json_dumps"): + return json.dumps + return self._json_dumps + + @json_dumps.setter + def json_dumps(self, value: Callable[..., str]) -> None: # pragma: no cover + self._json_dumps = value # type: ignore + + @classmethod + def raise_for_status(cls, response: Response[T]) -> None: if response.ok: return raise TelegramAPIError(response.description) diff --git a/aiogram/api/methods/base.py b/aiogram/api/methods/base.py index 72eafa05..fc57a2ff 100644 --- a/aiogram/api/methods/base.py +++ b/aiogram/api/methods/base.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: # pragma: no cover from ..client.bot import Bot T = TypeVar("T") +DEFAULT_REQUEST_TIMEOUT_SECONDS = 60.0 class Request(BaseModel): @@ -55,6 +56,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]): def build_request(self) -> Request: # pragma: no cover pass + request_timeout: float = DEFAULT_REQUEST_TIMEOUT_SECONDS + + def dict(self, **kwargs: Any) -> Any: + # override dict of pydantic.BaseModel to overcome exporting request_timeout field + exclude = kwargs.pop("exclude", set()) + if isinstance(exclude, set): + exclude.add("request_timeout") + + return super().dict(exclude=exclude, **kwargs) + def build_response(self, data: Dict[str, Any]) -> Response[T]: # noinspection PyTypeChecker return Response[self.__returning__](**data) # type: ignore diff --git a/tests/test_api/test_client/test_session/test_aiohttp_session.py b/tests/test_api/test_client/test_session/test_aiohttp_session.py index 2587a686..e7716f9a 100644 --- a/tests/test_api/test_client/test_session/test_aiohttp_session.py +++ b/tests/test_api/test_client/test_session/test_aiohttp_session.py @@ -83,6 +83,22 @@ class TestAiohttpSession: aiohttp_session = await session.create_session() assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector) + @pytest.mark.asyncio + async def test_reset_connector(self): + session = AiohttpSession() + assert session._should_reset_connector + await session.create_session() + assert session._should_reset_connector is False + await session.close() + assert session._should_reset_connector is False + + assert session.proxy is None + session.proxy = "socks5://auth:auth@proxy.url/" + assert session._should_reset_connector + await session.create_session() + assert session._should_reset_connector is False + await session.close() + @pytest.mark.asyncio async def test_close_session(self): session = AiohttpSession() diff --git a/tests/test_api/test_client/test_session/test_base_session.py b/tests/test_api/test_client/test_session/test_base_session.py index 4c86f9da..9d775123 100644 --- a/tests/test_api/test_client/test_session/test_base_session.py +++ b/tests/test_api/test_client/test_session/test_base_session.py @@ -40,8 +40,10 @@ class TestBaseSession: base="http://example.com/{token}/{method}", file="http://example.com/{token}/file/{path{", ) - session = CustomSession(api=api) + session = CustomSession() + session.api = api assert session.api == api + assert "example.com" in session.api.base def test_prepare_value(self): session = CustomSession()