Merge pull request #318 from aiogram/dev-3.x-refactor-sessions

Dev 3.x refactor sessions
This commit is contained in:
Alex Root Junior 2020-05-09 23:49:54 +03:00 committed by GitHub
commit 05dd42712d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 151 additions and 44 deletions

View file

@ -3,7 +3,6 @@ from __future__ import annotations
from typing import ( from typing import (
Any, Any,
AsyncGenerator, AsyncGenerator,
Callable,
Dict, Dict,
Iterable, Iterable,
List, List,
@ -15,11 +14,11 @@ from typing import (
cast, cast,
) )
from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector
from aiogram.api.methods import Request, TelegramMethod from aiogram.api.methods import Request, TelegramMethod
from .base import PRODUCTION, BaseSession, TelegramAPIServer from .base import BaseSession
T = TypeVar("T") T = TypeVar("T")
_ProxyBasic = Union[str, Tuple[str, BasicAuth]] _ProxyBasic = Union[str, Tuple[str, BasicAuth]]
@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
class AiohttpSession(BaseSession): class AiohttpSession(BaseSession):
def __init__( def __init__(self, proxy: Optional[_ProxyType] = None):
self,
api: TelegramAPIServer = PRODUCTION,
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[_ProxyType] = None,
):
super(AiohttpSession, self).__init__(
api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy
)
self._session: Optional[ClientSession] = None self._session: Optional[ClientSession] = None
self._connector_type: Type[TCPConnector] = TCPConnector self._connector_type: Type[TCPConnector] = TCPConnector
self._connector_init: Dict[str, Any] = {} self._connector_init: Dict[str, Any] = {}
self._should_reset_connector = True # flag determines connector state
self._proxy: Optional[_ProxyType] = None
if self.proxy: if proxy is not None:
try: try:
self._connector_type, self._connector_init = _prepare_connector( self._setup_proxy_connector(proxy)
cast(_ProxyType, self.proxy)
)
except ImportError as exc: # pragma: no cover except ImportError as exc: # pragma: no cover
raise UserWarning( raise UserWarning(
"In order to use aiohttp client for proxy requests, install " "In order to use aiohttp client for proxy requests, install "
"https://pypi.org/project/aiohttp-socks/" "https://pypi.org/project/aiohttp-socks/"
) from exc ) from exc
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
self._connector_type, self._connector_init = _prepare_connector(proxy)
self._proxy = proxy
@property
def proxy(self) -> Optional[_ProxyType]:
return self._proxy
@proxy.setter
def proxy(self, proxy: _ProxyType) -> None:
self._setup_proxy_connector(proxy)
self._should_reset_connector = True
async def create_session(self) -> ClientSession: async def create_session(self) -> ClientSession:
if self._should_reset_connector:
await self.close()
if self._session is None or self._session.closed: if self._session is None or self._session.closed:
self._session = ClientSession(connector=self._connector_type(**self._connector_init)) self._session = ClientSession(connector=self._connector_type(**self._connector_init))
self._should_reset_connector = False
return self._session return self._session
@ -125,7 +132,9 @@ class AiohttpSession(BaseSession):
url = self.api.api_url(token=token, method=request.method) url = self.api.api_url(token=token, method=request.method)
form = self.build_form_data(request) form = self.build_form_data(request)
async with session.post(url, data=form) as resp: async with session.post(
url, data=form, timeout=call.request_timeout or self.timeout
) as resp:
raw_result = await resp.json(loads=self.json_loads) raw_result = await resp.json(loads=self.json_loads)
response = call.build_response(raw_result) response = call.build_response(raw_result)
@ -136,9 +145,8 @@ class AiohttpSession(BaseSession):
self, url: str, timeout: int, chunk_size: int self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
session = await self.create_session() session = await self.create_session()
client_timeout = ClientTimeout(total=timeout)
async with session.get(url, timeout=client_timeout) as resp: async with session.get(url, timeout=timeout) as resp:
async for chunk in resp.content.iter_chunked(chunk_size): async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk yield chunk

View file

@ -4,7 +4,7 @@ import abc
import datetime import datetime
import json import json
from types import TracebackType from types import TracebackType
from typing import Any, AsyncGenerator, Callable, Optional, Type, TypeVar, Union from typing import Any, AsyncGenerator, Callable, ClassVar, Optional, Type, TypeVar, Union
from aiogram.utils.exceptions import TelegramAPIError from aiogram.utils.exceptions import TelegramAPIError
@ -12,30 +12,58 @@ from ...methods import Response, TelegramMethod
from ..telegram import PRODUCTION, TelegramAPIServer from ..telegram import PRODUCTION, TelegramAPIServer
T = TypeVar("T") T = TypeVar("T")
PT = TypeVar("PT") _JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]
class BaseSession(abc.ABC): class BaseSession(abc.ABC):
def __init__( # global session timeout
self, default_timeout: ClassVar[float] = 60.0
api: Optional[TelegramAPIServer] = None,
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[PT] = None,
) -> None:
if api is None:
api = PRODUCTION
if json_loads is None:
json_loads = json.loads
if json_dumps is None:
json_dumps = json.dumps
self.api = api _api: TelegramAPIServer
self.json_loads = json_loads _json_loads: _JsonLoads
self.json_dumps = json_dumps _json_dumps: _JsonDumps
self.proxy = proxy _timeout: float
def raise_for_status(self, response: Response[T]) -> None: @property
def api(self) -> TelegramAPIServer:
return getattr(self, "_api", PRODUCTION) # type: ignore
@api.setter
def api(self, value: TelegramAPIServer) -> None:
self._api = value
@property
def json_loads(self) -> _JsonLoads:
return getattr(self, "_json_loads", json.loads) # type: ignore
@json_loads.setter
def json_loads(self, value: _JsonLoads) -> None:
self._json_loads = value # type: ignore
@property
def json_dumps(self) -> _JsonDumps:
return getattr(self, "_json_dumps", json.dumps) # type: ignore
@json_dumps.setter
def json_dumps(self, value: _JsonDumps) -> None:
self._json_dumps = value # type: ignore
@property
def timeout(self) -> float:
return getattr(self, "_timeout", self.__class__.default_timeout) # type: ignore
@timeout.setter
def timeout(self, value: float) -> None:
self._timeout = value
@timeout.deleter
def timeout(self) -> None:
if hasattr(self, "_timeout"):
del self._timeout
@classmethod
def raise_for_status(cls, response: Response[T]) -> None:
if response.ok: if response.ok:
return return
raise TelegramAPIError(response.description) raise TelegramAPIError(response.description)

View file

@ -55,6 +55,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]):
def build_request(self) -> Request: # pragma: no cover def build_request(self) -> Request: # pragma: no cover
pass pass
request_timeout: Optional[float] = None
def dict(self, **kwargs: Any) -> Any:
# override dict of pydantic.BaseModel to overcome exporting request_timeout field
exclude = kwargs.pop("exclude", set())
if isinstance(exclude, set):
exclude.add("request_timeout")
return super().dict(exclude=exclude, **kwargs)
def build_response(self, data: Dict[str, Any]) -> Response[T]: def build_response(self, data: Dict[str, Any]) -> Response[T]:
# noinspection PyTypeChecker # noinspection PyTypeChecker
return Response[self.__returning__](**data) # type: ignore return Response[self.__returning__](**data) # type: ignore

View file

@ -1,6 +1,6 @@
# Aiohttp session # Aiohttp session
AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp]('https://pypi.org/project/aiohttp/') AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp](https://pypi.org/project/aiohttp/ "PyPi repository"){target=_blank}
Currently `AiohttpSession` is a default session used in `aiogram.Bot` Currently `AiohttpSession` is a default session used in `aiogram.Bot`
@ -17,7 +17,7 @@ Bot('token', session=session)
## Proxy requests in AiohttpSession ## Proxy requests in AiohttpSession
In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks]('https://pypi.org/project/aiohttp-socks/') In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks](https://pypi.org/project/aiohttp-socks/ "PyPi repository"){target=_blank}
Binding session to bot: Binding session to bot:
```python ```python

View file

@ -35,6 +35,7 @@ markdown_extensions:
- pymdownx.inlinehilite - pymdownx.inlinehilite
- markdown_include.include: - markdown_include.include:
base_path: docs base_path: docs
- attr_list
nav: nav:
- index.md - index.md

View file

@ -83,6 +83,22 @@ class TestAiohttpSession:
aiohttp_session = await session.create_session() aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector) assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
@pytest.mark.asyncio
async def test_reset_connector(self):
session = AiohttpSession()
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
assert session._should_reset_connector is False
assert session.proxy is None
session.proxy = "socks5://auth:auth@proxy.url/"
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_close_session(self): async def test_close_session(self):
session = AiohttpSession() session = AiohttpSession()

View file

@ -1,4 +1,5 @@
import datetime import datetime
import json
from typing import AsyncContextManager, AsyncGenerator from typing import AsyncContextManager, AsyncGenerator
import pytest import pytest
@ -35,13 +36,56 @@ class TestBaseSession:
session = CustomSession() session = CustomSession()
assert session.api == PRODUCTION assert session.api == PRODUCTION
def test_default_props(self):
session = CustomSession()
assert session.api == PRODUCTION
assert session.json_loads == json.loads
assert session.json_dumps == json.dumps
def custom_loads(*_):
return json.loads
def custom_dumps(*_):
return json.dumps
session.json_dumps = custom_dumps
assert session.json_dumps == custom_dumps == session._json_dumps
session.json_loads = custom_loads
assert session.json_loads == custom_loads == session._json_loads
different_session = CustomSession()
assert all(
not hasattr(different_session, attr) for attr in ("_json_loads", "_json_dumps", "_api")
)
def test_timeout(self):
session = CustomSession()
assert session.timeout == session.default_timeout == CustomSession.default_timeout
session.default_timeout = float(65.0_0) # mypy will complain
assert session.timeout != session.default_timeout
CustomSession.default_timeout = float(68.0_0)
assert session.timeout == CustomSession.default_timeout
session.timeout = float(71.0_0)
assert session.timeout != session.default_timeout
del session.timeout
CustomSession.default_timeout = session.default_timeout + 100
assert (
session.timeout != BaseSession.default_timeout
and session.timeout == CustomSession.default_timeout
)
def test_init_custom_api(self): def test_init_custom_api(self):
api = TelegramAPIServer( api = TelegramAPIServer(
base="http://example.com/{token}/{method}", base="http://example.com/{token}/{method}",
file="http://example.com/{token}/file/{path{", file="http://example.com/{token}/file/{path}",
) )
session = CustomSession(api=api) session = CustomSession()
session.api = api
assert session.api == api assert session.api == api
assert "example.com" in session.api.base
def test_prepare_value(self): def test_prepare_value(self):
session = CustomSession() session = CustomSession()