refactor(sessions):

remove BaseSession's initializer, add timeout ommitable field to base
method model
This commit is contained in:
mpa 2020-05-03 00:53:25 +04:00
parent 15bcc0ba9f
commit ea6a02bf97
5 changed files with 90 additions and 39 deletions

View file

@ -3,7 +3,6 @@ from __future__ import annotations
from typing import ( from typing import (
Any, Any,
AsyncGenerator, AsyncGenerator,
Callable,
Dict, Dict,
Iterable, Iterable,
List, List,
@ -15,11 +14,11 @@ from typing import (
cast, cast,
) )
from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector
from aiogram.api.methods import Request, TelegramMethod from aiogram.api.methods import Request, TelegramMethod
from .base import PRODUCTION, BaseSession, TelegramAPIServer from .base import BaseSession
T = TypeVar("T") T = TypeVar("T")
_ProxyBasic = Union[str, Tuple[str, BasicAuth]] _ProxyBasic = Union[str, Tuple[str, BasicAuth]]
@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
class AiohttpSession(BaseSession): class AiohttpSession(BaseSession):
def __init__( def __init__(self, proxy: Optional[_ProxyType] = None):
self,
api: TelegramAPIServer = PRODUCTION,
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[_ProxyType] = None,
):
super(AiohttpSession, self).__init__(
api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy
)
self._session: Optional[ClientSession] = None self._session: Optional[ClientSession] = None
self._connector_type: Type[TCPConnector] = TCPConnector self._connector_type: Type[TCPConnector] = TCPConnector
self._connector_init: Dict[str, Any] = {} self._connector_init: Dict[str, Any] = {}
self._should_reset_connector = True # flag determines connector state
self._proxy: Optional[_ProxyType] = None
if self.proxy: if proxy is not None:
try: try:
self._connector_type, self._connector_init = _prepare_connector( self._setup_proxy_connector(proxy)
cast(_ProxyType, self.proxy)
)
except ImportError as exc: # pragma: no cover except ImportError as exc: # pragma: no cover
raise UserWarning( raise UserWarning(
"In order to use aiohttp client for proxy requests, install " "In order to use aiohttp client for proxy requests, install "
"https://pypi.org/project/aiohttp-socks/" "https://pypi.org/project/aiohttp-socks/"
) from exc ) from exc
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
self._connector_type, self._connector_init = _prepare_connector(proxy)
self._proxy = proxy
@property
def proxy(self) -> Optional[_ProxyType]:
return self._proxy
@proxy.setter
def proxy(self, proxy: _ProxyType) -> None:
self._setup_proxy_connector(proxy)
self._should_reset_connector = True
async def create_session(self) -> ClientSession: async def create_session(self) -> ClientSession:
if self._should_reset_connector:
await self.close()
if self._session is None or self._session.closed: if self._session is None or self._session.closed:
self._session = ClientSession(connector=self._connector_type(**self._connector_init)) self._session = ClientSession(connector=self._connector_type(**self._connector_init))
self._should_reset_connector = False
return self._session return self._session
@ -125,7 +132,7 @@ class AiohttpSession(BaseSession):
url = self.api.api_url(token=token, method=request.method) url = self.api.api_url(token=token, method=request.method)
form = self.build_form_data(request) form = self.build_form_data(request)
async with session.post(url, data=form) as resp: async with session.post(url, data=form, timeout=call.request_timeout) as resp:
raw_result = await resp.json(loads=self.json_loads) raw_result = await resp.json(loads=self.json_loads)
response = call.build_response(raw_result) response = call.build_response(raw_result)
@ -136,9 +143,8 @@ class AiohttpSession(BaseSession):
self, url: str, timeout: int, chunk_size: int self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
session = await self.create_session() session = await self.create_session()
client_timeout = ClientTimeout(total=timeout)
async with session.get(url, timeout=client_timeout) as resp: async with session.get(url, timeout=timeout) as resp:
async for chunk in resp.content.iter_chunked(chunk_size): async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk yield chunk

View file

@ -16,26 +16,42 @@ PT = TypeVar("PT")
class BaseSession(abc.ABC): class BaseSession(abc.ABC):
def __init__( _api: TelegramAPIServer
self, _json_loads: Callable[..., Any]
api: Optional[TelegramAPIServer] = None, _json_dumps: Callable[..., str]
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[PT] = None,
) -> None:
if api is None:
api = PRODUCTION
if json_loads is None:
json_loads = json.loads
if json_dumps is None:
json_dumps = json.dumps
self.api = api @property
self.json_loads = json_loads def api(self) -> TelegramAPIServer: # pragma: no cover
self.json_dumps = json_dumps if not hasattr(self, "_api"):
self.proxy = proxy return PRODUCTION
return self._api
def raise_for_status(self, response: Response[T]) -> None: @api.setter
def api(self, value: TelegramAPIServer) -> None: # pragma: no cover
self._api = value
@property
def json_loads(self) -> Callable[..., Any]: # pragma: no cover
if not hasattr(self, "_json_loads"):
return json.loads
return self._json_loads
@json_loads.setter
def json_loads(self, value: Callable[..., Any]) -> None: # pragma: no cover
self._json_loads = value # type: ignore
@property
def json_dumps(self) -> Callable[..., str]: # pragma: no cover
if not hasattr(self, "_json_dumps"):
return json.dumps
return self._json_dumps
@json_dumps.setter
def json_dumps(self, value: Callable[..., str]) -> None: # pragma: no cover
self._json_dumps = value # type: ignore
@classmethod
def raise_for_status(cls, response: Response[T]) -> None:
if response.ok: if response.ok:
return return
raise TelegramAPIError(response.description) raise TelegramAPIError(response.description)

View file

@ -13,6 +13,7 @@ if TYPE_CHECKING: # pragma: no cover
from ..client.bot import Bot from ..client.bot import Bot
T = TypeVar("T") T = TypeVar("T")
DEFAULT_REQUEST_TIMEOUT_SECONDS = 60.0
class Request(BaseModel): class Request(BaseModel):
@ -55,6 +56,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]):
def build_request(self) -> Request: # pragma: no cover def build_request(self) -> Request: # pragma: no cover
pass pass
request_timeout: float = DEFAULT_REQUEST_TIMEOUT_SECONDS
def dict(self, **kwargs: Any) -> Any:
# override dict of pydantic.BaseModel to overcome exporting request_timeout field
exclude = kwargs.pop("exclude", set())
if isinstance(exclude, set):
exclude.add("request_timeout")
return super().dict(exclude=exclude, **kwargs)
def build_response(self, data: Dict[str, Any]) -> Response[T]: def build_response(self, data: Dict[str, Any]) -> Response[T]:
# noinspection PyTypeChecker # noinspection PyTypeChecker
return Response[self.__returning__](**data) # type: ignore return Response[self.__returning__](**data) # type: ignore

View file

@ -83,6 +83,22 @@ class TestAiohttpSession:
aiohttp_session = await session.create_session() aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector) assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
@pytest.mark.asyncio
async def test_reset_connector(self):
session = AiohttpSession()
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
assert session._should_reset_connector is False
assert session.proxy is None
session.proxy = "socks5://auth:auth@proxy.url/"
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_close_session(self): async def test_close_session(self):
session = AiohttpSession() session = AiohttpSession()

View file

@ -40,8 +40,10 @@ class TestBaseSession:
base="http://example.com/{token}/{method}", base="http://example.com/{token}/{method}",
file="http://example.com/{token}/file/{path{", file="http://example.com/{token}/file/{path{",
) )
session = CustomSession(api=api) session = CustomSession()
session.api = api
assert session.api == api assert session.api == api
assert "example.com" in session.api.base
def test_prepare_value(self): def test_prepare_value(self):
session = CustomSession() session = CustomSession()