mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
refactor(sessions):
remove BaseSession's initializer, add timeout ommitable field to base method model
This commit is contained in:
parent
15bcc0ba9f
commit
ea6a02bf97
5 changed files with 90 additions and 39 deletions
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
|
|
@ -15,11 +14,11 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector
|
||||
from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector
|
||||
|
||||
from aiogram.api.methods import Request, TelegramMethod
|
||||
|
||||
from .base import PRODUCTION, BaseSession, TelegramAPIServer
|
||||
from .base import BaseSession
|
||||
|
||||
T = TypeVar("T")
|
||||
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
|
||||
|
|
@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
|
|||
|
||||
|
||||
class AiohttpSession(BaseSession):
|
||||
def __init__(
|
||||
self,
|
||||
api: TelegramAPIServer = PRODUCTION,
|
||||
json_loads: Optional[Callable[..., Any]] = None,
|
||||
json_dumps: Optional[Callable[..., str]] = None,
|
||||
proxy: Optional[_ProxyType] = None,
|
||||
):
|
||||
super(AiohttpSession, self).__init__(
|
||||
api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy
|
||||
)
|
||||
def __init__(self, proxy: Optional[_ProxyType] = None):
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._connector_type: Type[TCPConnector] = TCPConnector
|
||||
self._connector_init: Dict[str, Any] = {}
|
||||
self._should_reset_connector = True # flag determines connector state
|
||||
self._proxy: Optional[_ProxyType] = None
|
||||
|
||||
if self.proxy:
|
||||
if proxy is not None:
|
||||
try:
|
||||
self._connector_type, self._connector_init = _prepare_connector(
|
||||
cast(_ProxyType, self.proxy)
|
||||
)
|
||||
self._setup_proxy_connector(proxy)
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise UserWarning(
|
||||
"In order to use aiohttp client for proxy requests, install "
|
||||
"https://pypi.org/project/aiohttp-socks/"
|
||||
) from exc
|
||||
|
||||
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
|
||||
self._connector_type, self._connector_init = _prepare_connector(proxy)
|
||||
self._proxy = proxy
|
||||
|
||||
@property
|
||||
def proxy(self) -> Optional[_ProxyType]:
|
||||
return self._proxy
|
||||
|
||||
@proxy.setter
|
||||
def proxy(self, proxy: _ProxyType) -> None:
|
||||
self._setup_proxy_connector(proxy)
|
||||
self._should_reset_connector = True
|
||||
|
||||
async def create_session(self) -> ClientSession:
|
||||
if self._should_reset_connector:
|
||||
await self.close()
|
||||
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = ClientSession(connector=self._connector_type(**self._connector_init))
|
||||
self._should_reset_connector = False
|
||||
|
||||
return self._session
|
||||
|
||||
|
|
@ -125,7 +132,7 @@ class AiohttpSession(BaseSession):
|
|||
url = self.api.api_url(token=token, method=request.method)
|
||||
form = self.build_form_data(request)
|
||||
|
||||
async with session.post(url, data=form) as resp:
|
||||
async with session.post(url, data=form, timeout=call.request_timeout) as resp:
|
||||
raw_result = await resp.json(loads=self.json_loads)
|
||||
|
||||
response = call.build_response(raw_result)
|
||||
|
|
@ -136,9 +143,8 @@ class AiohttpSession(BaseSession):
|
|||
self, url: str, timeout: int, chunk_size: int
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
session = await self.create_session()
|
||||
client_timeout = ClientTimeout(total=timeout)
|
||||
|
||||
async with session.get(url, timeout=client_timeout) as resp:
|
||||
async with session.get(url, timeout=timeout) as resp:
|
||||
async for chunk in resp.content.iter_chunked(chunk_size):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
|
|
@ -16,26 +16,42 @@ PT = TypeVar("PT")
|
|||
|
||||
|
||||
class BaseSession(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
api: Optional[TelegramAPIServer] = None,
|
||||
json_loads: Optional[Callable[..., Any]] = None,
|
||||
json_dumps: Optional[Callable[..., str]] = None,
|
||||
proxy: Optional[PT] = None,
|
||||
) -> None:
|
||||
if api is None:
|
||||
api = PRODUCTION
|
||||
if json_loads is None:
|
||||
json_loads = json.loads
|
||||
if json_dumps is None:
|
||||
json_dumps = json.dumps
|
||||
_api: TelegramAPIServer
|
||||
_json_loads: Callable[..., Any]
|
||||
_json_dumps: Callable[..., str]
|
||||
|
||||
self.api = api
|
||||
self.json_loads = json_loads
|
||||
self.json_dumps = json_dumps
|
||||
self.proxy = proxy
|
||||
@property
|
||||
def api(self) -> TelegramAPIServer: # pragma: no cover
|
||||
if not hasattr(self, "_api"):
|
||||
return PRODUCTION
|
||||
return self._api
|
||||
|
||||
def raise_for_status(self, response: Response[T]) -> None:
|
||||
@api.setter
|
||||
def api(self, value: TelegramAPIServer) -> None: # pragma: no cover
|
||||
self._api = value
|
||||
|
||||
@property
|
||||
def json_loads(self) -> Callable[..., Any]: # pragma: no cover
|
||||
if not hasattr(self, "_json_loads"):
|
||||
return json.loads
|
||||
return self._json_loads
|
||||
|
||||
@json_loads.setter
|
||||
def json_loads(self, value: Callable[..., Any]) -> None: # pragma: no cover
|
||||
self._json_loads = value # type: ignore
|
||||
|
||||
@property
|
||||
def json_dumps(self) -> Callable[..., str]: # pragma: no cover
|
||||
if not hasattr(self, "_json_dumps"):
|
||||
return json.dumps
|
||||
return self._json_dumps
|
||||
|
||||
@json_dumps.setter
|
||||
def json_dumps(self, value: Callable[..., str]) -> None: # pragma: no cover
|
||||
self._json_dumps = value # type: ignore
|
||||
|
||||
@classmethod
|
||||
def raise_for_status(cls, response: Response[T]) -> None:
|
||||
if response.ok:
|
||||
return
|
||||
raise TelegramAPIError(response.description)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ if TYPE_CHECKING: # pragma: no cover
|
|||
from ..client.bot import Bot
|
||||
|
||||
T = TypeVar("T")
|
||||
DEFAULT_REQUEST_TIMEOUT_SECONDS = 60.0
|
||||
|
||||
|
||||
class Request(BaseModel):
|
||||
|
|
@ -55,6 +56,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]):
|
|||
def build_request(self) -> Request: # pragma: no cover
|
||||
pass
|
||||
|
||||
request_timeout: float = DEFAULT_REQUEST_TIMEOUT_SECONDS
|
||||
|
||||
def dict(self, **kwargs: Any) -> Any:
|
||||
# override dict of pydantic.BaseModel to overcome exporting request_timeout field
|
||||
exclude = kwargs.pop("exclude", set())
|
||||
if isinstance(exclude, set):
|
||||
exclude.add("request_timeout")
|
||||
|
||||
return super().dict(exclude=exclude, **kwargs)
|
||||
|
||||
def build_response(self, data: Dict[str, Any]) -> Response[T]:
|
||||
# noinspection PyTypeChecker
|
||||
return Response[self.__returning__](**data) # type: ignore
|
||||
|
|
|
|||
|
|
@ -83,6 +83,22 @@ class TestAiohttpSession:
|
|||
aiohttp_session = await session.create_session()
|
||||
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_connector(self):
|
||||
session = AiohttpSession()
|
||||
assert session._should_reset_connector
|
||||
await session.create_session()
|
||||
assert session._should_reset_connector is False
|
||||
await session.close()
|
||||
assert session._should_reset_connector is False
|
||||
|
||||
assert session.proxy is None
|
||||
session.proxy = "socks5://auth:auth@proxy.url/"
|
||||
assert session._should_reset_connector
|
||||
await session.create_session()
|
||||
assert session._should_reset_connector is False
|
||||
await session.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_session(self):
|
||||
session = AiohttpSession()
|
||||
|
|
|
|||
|
|
@ -40,8 +40,10 @@ class TestBaseSession:
|
|||
base="http://example.com/{token}/{method}",
|
||||
file="http://example.com/{token}/file/{path{",
|
||||
)
|
||||
session = CustomSession(api=api)
|
||||
session = CustomSession()
|
||||
session.api = api
|
||||
assert session.api == api
|
||||
assert "example.com" in session.api.base
|
||||
|
||||
def test_prepare_value(self):
|
||||
session = CustomSession()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue