Merge pull request #318 from aiogram/dev-3.x-refactor-sessions

Dev 3.x refactor sessions
This commit is contained in:
Alex Root Junior 2020-05-09 23:49:54 +03:00 committed by GitHub
commit 05dd42712d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 151 additions and 44 deletions

View file

@ -3,7 +3,6 @@ from __future__ import annotations
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Iterable,
List,
@ -15,11 +14,11 @@ from typing import (
cast,
)
from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector
from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector
from aiogram.api.methods import Request, TelegramMethod
from .base import PRODUCTION, BaseSession, TelegramAPIServer
from .base import BaseSession
T = TypeVar("T")
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
class AiohttpSession(BaseSession):
def __init__(
self,
api: TelegramAPIServer = PRODUCTION,
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[_ProxyType] = None,
):
super(AiohttpSession, self).__init__(
api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy
)
def __init__(self, proxy: Optional[_ProxyType] = None):
self._session: Optional[ClientSession] = None
self._connector_type: Type[TCPConnector] = TCPConnector
self._connector_init: Dict[str, Any] = {}
self._should_reset_connector = True # flag determines connector state
self._proxy: Optional[_ProxyType] = None
if self.proxy:
if proxy is not None:
try:
self._connector_type, self._connector_init = _prepare_connector(
cast(_ProxyType, self.proxy)
)
self._setup_proxy_connector(proxy)
except ImportError as exc: # pragma: no cover
raise UserWarning(
"In order to use aiohttp client for proxy requests, install "
"https://pypi.org/project/aiohttp-socks/"
) from exc
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
self._connector_type, self._connector_init = _prepare_connector(proxy)
self._proxy = proxy
@property
def proxy(self) -> Optional[_ProxyType]:
return self._proxy
@proxy.setter
def proxy(self, proxy: _ProxyType) -> None:
self._setup_proxy_connector(proxy)
self._should_reset_connector = True
async def create_session(self) -> ClientSession:
if self._should_reset_connector:
await self.close()
if self._session is None or self._session.closed:
self._session = ClientSession(connector=self._connector_type(**self._connector_init))
self._should_reset_connector = False
return self._session
@ -125,7 +132,9 @@ class AiohttpSession(BaseSession):
url = self.api.api_url(token=token, method=request.method)
form = self.build_form_data(request)
async with session.post(url, data=form) as resp:
async with session.post(
url, data=form, timeout=call.request_timeout or self.timeout
) as resp:
raw_result = await resp.json(loads=self.json_loads)
response = call.build_response(raw_result)
@ -136,9 +145,8 @@ class AiohttpSession(BaseSession):
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]:
session = await self.create_session()
client_timeout = ClientTimeout(total=timeout)
async with session.get(url, timeout=client_timeout) as resp:
async with session.get(url, timeout=timeout) as resp:
async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk

View file

@ -4,7 +4,7 @@ import abc
import datetime
import json
from types import TracebackType
from typing import Any, AsyncGenerator, Callable, Optional, Type, TypeVar, Union
from typing import Any, AsyncGenerator, Callable, ClassVar, Optional, Type, TypeVar, Union
from aiogram.utils.exceptions import TelegramAPIError
@ -12,30 +12,58 @@ from ...methods import Response, TelegramMethod
from ..telegram import PRODUCTION, TelegramAPIServer
T = TypeVar("T")
PT = TypeVar("PT")
_JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]
class BaseSession(abc.ABC):
def __init__(
self,
api: Optional[TelegramAPIServer] = None,
json_loads: Optional[Callable[..., Any]] = None,
json_dumps: Optional[Callable[..., str]] = None,
proxy: Optional[PT] = None,
) -> None:
if api is None:
api = PRODUCTION
if json_loads is None:
json_loads = json.loads
if json_dumps is None:
json_dumps = json.dumps
# global session timeout
default_timeout: ClassVar[float] = 60.0
self.api = api
self.json_loads = json_loads
self.json_dumps = json_dumps
self.proxy = proxy
_api: TelegramAPIServer
_json_loads: _JsonLoads
_json_dumps: _JsonDumps
_timeout: float
def raise_for_status(self, response: Response[T]) -> None:
@property
def api(self) -> TelegramAPIServer:
return getattr(self, "_api", PRODUCTION) # type: ignore
@api.setter
def api(self, value: TelegramAPIServer) -> None:
self._api = value
@property
def json_loads(self) -> _JsonLoads:
return getattr(self, "_json_loads", json.loads) # type: ignore
@json_loads.setter
def json_loads(self, value: _JsonLoads) -> None:
self._json_loads = value # type: ignore
@property
def json_dumps(self) -> _JsonDumps:
return getattr(self, "_json_dumps", json.dumps) # type: ignore
@json_dumps.setter
def json_dumps(self, value: _JsonDumps) -> None:
self._json_dumps = value # type: ignore
@property
def timeout(self) -> float:
return getattr(self, "_timeout", self.__class__.default_timeout) # type: ignore
@timeout.setter
def timeout(self, value: float) -> None:
self._timeout = value
@timeout.deleter
def timeout(self) -> None:
if hasattr(self, "_timeout"):
del self._timeout
@classmethod
def raise_for_status(cls, response: Response[T]) -> None:
if response.ok:
return
raise TelegramAPIError(response.description)

View file

@ -55,6 +55,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]):
def build_request(self) -> Request: # pragma: no cover
pass
request_timeout: Optional[float] = None
def dict(self, **kwargs: Any) -> Any:
# override dict of pydantic.BaseModel to overcome exporting request_timeout field
exclude = kwargs.pop("exclude", set())
if isinstance(exclude, set):
exclude.add("request_timeout")
return super().dict(exclude=exclude, **kwargs)
def build_response(self, data: Dict[str, Any]) -> Response[T]:
# noinspection PyTypeChecker
return Response[self.__returning__](**data) # type: ignore

View file

@ -1,6 +1,6 @@
# Aiohttp session
AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp]('https://pypi.org/project/aiohttp/')
AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp](https://pypi.org/project/aiohttp/ "PyPi repository"){target=_blank}
Currently `AiohttpSession` is a default session used in `aiogram.Bot`
@ -17,7 +17,7 @@ Bot('token', session=session)
## Proxy requests in AiohttpSession
In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks]('https://pypi.org/project/aiohttp-socks/')
In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks](https://pypi.org/project/aiohttp-socks/ "PyPi repository"){target=_blank}
Binding session to bot:
```python

View file

@ -35,6 +35,7 @@ markdown_extensions:
- pymdownx.inlinehilite
- markdown_include.include:
base_path: docs
- attr_list
nav:
- index.md

View file

@ -83,6 +83,22 @@ class TestAiohttpSession:
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
@pytest.mark.asyncio
async def test_reset_connector(self):
session = AiohttpSession()
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
assert session._should_reset_connector is False
assert session.proxy is None
session.proxy = "socks5://auth:auth@proxy.url/"
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
@pytest.mark.asyncio
async def test_close_session(self):
session = AiohttpSession()

View file

@ -1,4 +1,5 @@
import datetime
import json
from typing import AsyncContextManager, AsyncGenerator
import pytest
@ -35,13 +36,56 @@ class TestBaseSession:
session = CustomSession()
assert session.api == PRODUCTION
def test_default_props(self):
session = CustomSession()
assert session.api == PRODUCTION
assert session.json_loads == json.loads
assert session.json_dumps == json.dumps
def custom_loads(*_):
return json.loads
def custom_dumps(*_):
return json.dumps
session.json_dumps = custom_dumps
assert session.json_dumps == custom_dumps == session._json_dumps
session.json_loads = custom_loads
assert session.json_loads == custom_loads == session._json_loads
different_session = CustomSession()
assert all(
not hasattr(different_session, attr) for attr in ("_json_loads", "_json_dumps", "_api")
)
def test_timeout(self):
session = CustomSession()
assert session.timeout == session.default_timeout == CustomSession.default_timeout
session.default_timeout = float(65.0_0) # mypy will complain
assert session.timeout != session.default_timeout
CustomSession.default_timeout = float(68.0_0)
assert session.timeout == CustomSession.default_timeout
session.timeout = float(71.0_0)
assert session.timeout != session.default_timeout
del session.timeout
CustomSession.default_timeout = session.default_timeout + 100
assert (
session.timeout != BaseSession.default_timeout
and session.timeout == CustomSession.default_timeout
)
def test_init_custom_api(self):
api = TelegramAPIServer(
base="http://example.com/{token}/{method}",
file="http://example.com/{token}/file/{path{",
file="http://example.com/{token}/file/{path}",
)
session = CustomSession(api=api)
session = CustomSession()
session.api = api
assert session.api == api
assert "example.com" in session.api.base
def test_prepare_value(self):
session = CustomSession()