mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 18:01:04 +00:00
Merge pull request #318 from aiogram/dev-3.x-refactor-sessions
Dev 3.x refactor sessions
This commit is contained in:
commit
05dd42712d
7 changed files with 151 additions and 44 deletions
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
|
|
@ -15,11 +14,11 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector
|
||||
from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector
|
||||
|
||||
from aiogram.api.methods import Request, TelegramMethod
|
||||
|
||||
from .base import PRODUCTION, BaseSession, TelegramAPIServer
|
||||
from .base import BaseSession
|
||||
|
||||
T = TypeVar("T")
|
||||
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
|
||||
|
|
@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
|
|||
|
||||
|
||||
class AiohttpSession(BaseSession):
|
||||
def __init__(
|
||||
self,
|
||||
api: TelegramAPIServer = PRODUCTION,
|
||||
json_loads: Optional[Callable[..., Any]] = None,
|
||||
json_dumps: Optional[Callable[..., str]] = None,
|
||||
proxy: Optional[_ProxyType] = None,
|
||||
):
|
||||
super(AiohttpSession, self).__init__(
|
||||
api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy
|
||||
)
|
||||
def __init__(self, proxy: Optional[_ProxyType] = None):
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._connector_type: Type[TCPConnector] = TCPConnector
|
||||
self._connector_init: Dict[str, Any] = {}
|
||||
self._should_reset_connector = True # flag determines connector state
|
||||
self._proxy: Optional[_ProxyType] = None
|
||||
|
||||
if self.proxy:
|
||||
if proxy is not None:
|
||||
try:
|
||||
self._connector_type, self._connector_init = _prepare_connector(
|
||||
cast(_ProxyType, self.proxy)
|
||||
)
|
||||
self._setup_proxy_connector(proxy)
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise UserWarning(
|
||||
"In order to use aiohttp client for proxy requests, install "
|
||||
"https://pypi.org/project/aiohttp-socks/"
|
||||
) from exc
|
||||
|
||||
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
|
||||
self._connector_type, self._connector_init = _prepare_connector(proxy)
|
||||
self._proxy = proxy
|
||||
|
||||
@property
|
||||
def proxy(self) -> Optional[_ProxyType]:
|
||||
return self._proxy
|
||||
|
||||
@proxy.setter
|
||||
def proxy(self, proxy: _ProxyType) -> None:
|
||||
self._setup_proxy_connector(proxy)
|
||||
self._should_reset_connector = True
|
||||
|
||||
async def create_session(self) -> ClientSession:
|
||||
if self._should_reset_connector:
|
||||
await self.close()
|
||||
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = ClientSession(connector=self._connector_type(**self._connector_init))
|
||||
self._should_reset_connector = False
|
||||
|
||||
return self._session
|
||||
|
||||
|
|
@ -125,7 +132,9 @@ class AiohttpSession(BaseSession):
|
|||
url = self.api.api_url(token=token, method=request.method)
|
||||
form = self.build_form_data(request)
|
||||
|
||||
async with session.post(url, data=form) as resp:
|
||||
async with session.post(
|
||||
url, data=form, timeout=call.request_timeout or self.timeout
|
||||
) as resp:
|
||||
raw_result = await resp.json(loads=self.json_loads)
|
||||
|
||||
response = call.build_response(raw_result)
|
||||
|
|
@ -136,9 +145,8 @@ class AiohttpSession(BaseSession):
|
|||
self, url: str, timeout: int, chunk_size: int
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
session = await self.create_session()
|
||||
client_timeout = ClientTimeout(total=timeout)
|
||||
|
||||
async with session.get(url, timeout=client_timeout) as resp:
|
||||
async with session.get(url, timeout=timeout) as resp:
|
||||
async for chunk in resp.content.iter_chunked(chunk_size):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import abc
|
|||
import datetime
|
||||
import json
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncGenerator, Callable, Optional, Type, TypeVar, Union
|
||||
from typing import Any, AsyncGenerator, Callable, ClassVar, Optional, Type, TypeVar, Union
|
||||
|
||||
from aiogram.utils.exceptions import TelegramAPIError
|
||||
|
||||
|
|
@ -12,30 +12,58 @@ from ...methods import Response, TelegramMethod
|
|||
from ..telegram import PRODUCTION, TelegramAPIServer
|
||||
|
||||
T = TypeVar("T")
|
||||
PT = TypeVar("PT")
|
||||
_JsonLoads = Callable[..., Any]
|
||||
_JsonDumps = Callable[..., str]
|
||||
|
||||
|
||||
class BaseSession(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
api: Optional[TelegramAPIServer] = None,
|
||||
json_loads: Optional[Callable[..., Any]] = None,
|
||||
json_dumps: Optional[Callable[..., str]] = None,
|
||||
proxy: Optional[PT] = None,
|
||||
) -> None:
|
||||
if api is None:
|
||||
api = PRODUCTION
|
||||
if json_loads is None:
|
||||
json_loads = json.loads
|
||||
if json_dumps is None:
|
||||
json_dumps = json.dumps
|
||||
# global session timeout
|
||||
default_timeout: ClassVar[float] = 60.0
|
||||
|
||||
self.api = api
|
||||
self.json_loads = json_loads
|
||||
self.json_dumps = json_dumps
|
||||
self.proxy = proxy
|
||||
_api: TelegramAPIServer
|
||||
_json_loads: _JsonLoads
|
||||
_json_dumps: _JsonDumps
|
||||
_timeout: float
|
||||
|
||||
def raise_for_status(self, response: Response[T]) -> None:
|
||||
@property
|
||||
def api(self) -> TelegramAPIServer:
|
||||
return getattr(self, "_api", PRODUCTION) # type: ignore
|
||||
|
||||
@api.setter
|
||||
def api(self, value: TelegramAPIServer) -> None:
|
||||
self._api = value
|
||||
|
||||
@property
|
||||
def json_loads(self) -> _JsonLoads:
|
||||
return getattr(self, "_json_loads", json.loads) # type: ignore
|
||||
|
||||
@json_loads.setter
|
||||
def json_loads(self, value: _JsonLoads) -> None:
|
||||
self._json_loads = value # type: ignore
|
||||
|
||||
@property
|
||||
def json_dumps(self) -> _JsonDumps:
|
||||
return getattr(self, "_json_dumps", json.dumps) # type: ignore
|
||||
|
||||
@json_dumps.setter
|
||||
def json_dumps(self, value: _JsonDumps) -> None:
|
||||
self._json_dumps = value # type: ignore
|
||||
|
||||
@property
|
||||
def timeout(self) -> float:
|
||||
return getattr(self, "_timeout", self.__class__.default_timeout) # type: ignore
|
||||
|
||||
@timeout.setter
|
||||
def timeout(self, value: float) -> None:
|
||||
self._timeout = value
|
||||
|
||||
@timeout.deleter
|
||||
def timeout(self) -> None:
|
||||
if hasattr(self, "_timeout"):
|
||||
del self._timeout
|
||||
|
||||
@classmethod
|
||||
def raise_for_status(cls, response: Response[T]) -> None:
|
||||
if response.ok:
|
||||
return
|
||||
raise TelegramAPIError(response.description)
|
||||
|
|
|
|||
|
|
@ -55,6 +55,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]):
|
|||
def build_request(self) -> Request: # pragma: no cover
|
||||
pass
|
||||
|
||||
request_timeout: Optional[float] = None
|
||||
|
||||
def dict(self, **kwargs: Any) -> Any:
|
||||
# override dict of pydantic.BaseModel to overcome exporting request_timeout field
|
||||
exclude = kwargs.pop("exclude", set())
|
||||
if isinstance(exclude, set):
|
||||
exclude.add("request_timeout")
|
||||
|
||||
return super().dict(exclude=exclude, **kwargs)
|
||||
|
||||
def build_response(self, data: Dict[str, Any]) -> Response[T]:
|
||||
# noinspection PyTypeChecker
|
||||
return Response[self.__returning__](**data) # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Aiohttp session
|
||||
|
||||
AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp]('https://pypi.org/project/aiohttp/')
|
||||
AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp](https://pypi.org/project/aiohttp/ "PyPi repository"){target=_blank}
|
||||
|
||||
Currently `AiohttpSession` is a default session used in `aiogram.Bot`
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ Bot('token', session=session)
|
|||
|
||||
## Proxy requests in AiohttpSession
|
||||
|
||||
In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks]('https://pypi.org/project/aiohttp-socks/')
|
||||
In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks](https://pypi.org/project/aiohttp-socks/ "PyPi repository"){target=_blank}
|
||||
|
||||
Binding session to bot:
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ markdown_extensions:
|
|||
- pymdownx.inlinehilite
|
||||
- markdown_include.include:
|
||||
base_path: docs
|
||||
- attr_list
|
||||
|
||||
nav:
|
||||
- index.md
|
||||
|
|
|
|||
|
|
@ -83,6 +83,22 @@ class TestAiohttpSession:
|
|||
aiohttp_session = await session.create_session()
|
||||
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_connector(self):
|
||||
session = AiohttpSession()
|
||||
assert session._should_reset_connector
|
||||
await session.create_session()
|
||||
assert session._should_reset_connector is False
|
||||
await session.close()
|
||||
assert session._should_reset_connector is False
|
||||
|
||||
assert session.proxy is None
|
||||
session.proxy = "socks5://auth:auth@proxy.url/"
|
||||
assert session._should_reset_connector
|
||||
await session.create_session()
|
||||
assert session._should_reset_connector is False
|
||||
await session.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_session(self):
|
||||
session = AiohttpSession()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import datetime
|
||||
import json
|
||||
from typing import AsyncContextManager, AsyncGenerator
|
||||
|
||||
import pytest
|
||||
|
|
@ -35,13 +36,56 @@ class TestBaseSession:
|
|||
session = CustomSession()
|
||||
assert session.api == PRODUCTION
|
||||
|
||||
def test_default_props(self):
|
||||
session = CustomSession()
|
||||
assert session.api == PRODUCTION
|
||||
assert session.json_loads == json.loads
|
||||
assert session.json_dumps == json.dumps
|
||||
|
||||
def custom_loads(*_):
|
||||
return json.loads
|
||||
|
||||
def custom_dumps(*_):
|
||||
return json.dumps
|
||||
|
||||
session.json_dumps = custom_dumps
|
||||
assert session.json_dumps == custom_dumps == session._json_dumps
|
||||
session.json_loads = custom_loads
|
||||
assert session.json_loads == custom_loads == session._json_loads
|
||||
|
||||
different_session = CustomSession()
|
||||
assert all(
|
||||
not hasattr(different_session, attr) for attr in ("_json_loads", "_json_dumps", "_api")
|
||||
)
|
||||
|
||||
def test_timeout(self):
|
||||
session = CustomSession()
|
||||
assert session.timeout == session.default_timeout == CustomSession.default_timeout
|
||||
|
||||
session.default_timeout = float(65.0_0) # mypy will complain
|
||||
assert session.timeout != session.default_timeout
|
||||
|
||||
CustomSession.default_timeout = float(68.0_0)
|
||||
assert session.timeout == CustomSession.default_timeout
|
||||
|
||||
session.timeout = float(71.0_0)
|
||||
assert session.timeout != session.default_timeout
|
||||
del session.timeout
|
||||
CustomSession.default_timeout = session.default_timeout + 100
|
||||
assert (
|
||||
session.timeout != BaseSession.default_timeout
|
||||
and session.timeout == CustomSession.default_timeout
|
||||
)
|
||||
|
||||
def test_init_custom_api(self):
|
||||
api = TelegramAPIServer(
|
||||
base="http://example.com/{token}/{method}",
|
||||
file="http://example.com/{token}/file/{path{",
|
||||
file="http://example.com/{token}/file/{path}",
|
||||
)
|
||||
session = CustomSession(api=api)
|
||||
session = CustomSession()
|
||||
session.api = api
|
||||
assert session.api == api
|
||||
assert "example.com" in session.api.base
|
||||
|
||||
def test_prepare_value(self):
|
||||
session = CustomSession()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue