refactor(sessions):

remove BaseSession's initializer, add timeout ommitable field to base
method model
This commit is contained in:
mpa 2020-05-03 00:53:25 +04:00
parent 15bcc0ba9f
commit ea6a02bf97
5 changed files with 90 additions and 39 deletions

View file

@ -3,7 +3,6 @@ from __future__ import annotations
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Iterable,
List,
@ -15,11 +14,11 @@ from typing import (
cast,
)
from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector
from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector
from aiogram.api.methods import Request, TelegramMethod
from .base import PRODUCTION, BaseSession, TelegramAPIServer
from .base import BaseSession
T = TypeVar("T")
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
class AiohttpSession(BaseSession):
def __init__(
self,
api: TelegramAPIServer = PRODUCTION,
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[_ProxyType] = None,
):
super(AiohttpSession, self).__init__(
api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy
)
def __init__(self, proxy: Optional[_ProxyType] = None):
self._session: Optional[ClientSession] = None
self._connector_type: Type[TCPConnector] = TCPConnector
self._connector_init: Dict[str, Any] = {}
self._should_reset_connector = True # flag determines connector state
self._proxy: Optional[_ProxyType] = None
if self.proxy:
if proxy is not None:
try:
self._connector_type, self._connector_init = _prepare_connector(
cast(_ProxyType, self.proxy)
)
self._setup_proxy_connector(proxy)
except ImportError as exc: # pragma: no cover
raise UserWarning(
"In order to use aiohttp client for proxy requests, install "
"https://pypi.org/project/aiohttp-socks/"
) from exc
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
self._connector_type, self._connector_init = _prepare_connector(proxy)
self._proxy = proxy
@property
def proxy(self) -> Optional[_ProxyType]:
return self._proxy
@proxy.setter
def proxy(self, proxy: _ProxyType) -> None:
self._setup_proxy_connector(proxy)
self._should_reset_connector = True
async def create_session(self) -> ClientSession:
if self._should_reset_connector:
await self.close()
if self._session is None or self._session.closed:
self._session = ClientSession(connector=self._connector_type(**self._connector_init))
self._should_reset_connector = False
return self._session
@ -125,7 +132,7 @@ class AiohttpSession(BaseSession):
url = self.api.api_url(token=token, method=request.method)
form = self.build_form_data(request)
async with session.post(url, data=form) as resp:
async with session.post(url, data=form, timeout=call.request_timeout) as resp:
raw_result = await resp.json(loads=self.json_loads)
response = call.build_response(raw_result)
@ -136,9 +143,8 @@ class AiohttpSession(BaseSession):
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]:
session = await self.create_session()
client_timeout = ClientTimeout(total=timeout)
async with session.get(url, timeout=client_timeout) as resp:
async with session.get(url, timeout=timeout) as resp:
async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk

View file

@ -16,26 +16,42 @@ PT = TypeVar("PT")
class BaseSession(abc.ABC):
def __init__(
self,
api: Optional[TelegramAPIServer] = None,
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[PT] = None,
) -> None:
if api is None:
api = PRODUCTION
if json_loads is None:
json_loads = json.loads
if json_dumps is None:
json_dumps = json.dumps
_api: TelegramAPIServer
_json_loads: Callable[..., Any]
_json_dumps: Callable[..., str]
self.api = api
self.json_loads = json_loads
self.json_dumps = json_dumps
self.proxy = proxy
@property
def api(self) -> TelegramAPIServer: # pragma: no cover
if not hasattr(self, "_api"):
return PRODUCTION
return self._api
def raise_for_status(self, response: Response[T]) -> None:
@api.setter
def api(self, value: TelegramAPIServer) -> None: # pragma: no cover
self._api = value
@property
def json_loads(self) -> Callable[..., Any]: # pragma: no cover
if not hasattr(self, "_json_loads"):
return json.loads
return self._json_loads
@json_loads.setter
def json_loads(self, value: Callable[..., Any]) -> None: # pragma: no cover
self._json_loads = value # type: ignore
@property
def json_dumps(self) -> Callable[..., str]: # pragma: no cover
if not hasattr(self, "_json_dumps"):
return json.dumps
return self._json_dumps
@json_dumps.setter
def json_dumps(self, value: Callable[..., str]) -> None: # pragma: no cover
self._json_dumps = value # type: ignore
@classmethod
def raise_for_status(cls, response: Response[T]) -> None:
if response.ok:
return
raise TelegramAPIError(response.description)

View file

@ -13,6 +13,7 @@ if TYPE_CHECKING: # pragma: no cover
from ..client.bot import Bot
T = TypeVar("T")
DEFAULT_REQUEST_TIMEOUT_SECONDS = 60.0
class Request(BaseModel):
@ -55,6 +56,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]):
def build_request(self) -> Request: # pragma: no cover
pass
request_timeout: float = DEFAULT_REQUEST_TIMEOUT_SECONDS
def dict(self, **kwargs: Any) -> Any:
# override dict of pydantic.BaseModel to overcome exporting request_timeout field
exclude = kwargs.pop("exclude", set())
if isinstance(exclude, set):
exclude.add("request_timeout")
return super().dict(exclude=exclude, **kwargs)
def build_response(self, data: Dict[str, Any]) -> Response[T]:
# noinspection PyTypeChecker
return Response[self.__returning__](**data) # type: ignore

View file

@ -83,6 +83,22 @@ class TestAiohttpSession:
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
@pytest.mark.asyncio
async def test_reset_connector(self):
session = AiohttpSession()
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
assert session._should_reset_connector is False
assert session.proxy is None
session.proxy = "socks5://auth:auth@proxy.url/"
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
@pytest.mark.asyncio
async def test_close_session(self):
session = AiohttpSession()

View file

@ -40,8 +40,10 @@ class TestBaseSession:
base="http://example.com/{token}/{method}",
file="http://example.com/{token}/file/{path{",
)
session = CustomSession(api=api)
session = CustomSession()
session.api = api
assert session.api == api
assert "example.com" in session.api.base
def test_prepare_value(self):
session = CustomSession()