mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 18:01:04 +00:00
Merge pull request #318 from aiogram/dev-3.x-refactor-sessions
Dev 3.x refactor sessions
This commit is contained in:
commit
05dd42712d
7 changed files with 151 additions and 44 deletions
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
|
@ -15,11 +14,11 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from aiohttp import BasicAuth, ClientSession, ClientTimeout, FormData, TCPConnector
|
from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector
|
||||||
|
|
||||||
from aiogram.api.methods import Request, TelegramMethod
|
from aiogram.api.methods import Request, TelegramMethod
|
||||||
|
|
||||||
from .base import PRODUCTION, BaseSession, TelegramAPIServer
|
from .base import BaseSession
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
|
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
|
||||||
|
|
@ -72,34 +71,42 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
|
||||||
|
|
||||||
|
|
||||||
class AiohttpSession(BaseSession):
|
class AiohttpSession(BaseSession):
|
||||||
def __init__(
|
def __init__(self, proxy: Optional[_ProxyType] = None):
|
||||||
self,
|
|
||||||
api: TelegramAPIServer = PRODUCTION,
|
|
||||||
json_loads: Optional[Callable[..., Any]] = None,
|
|
||||||
json_dumps: Optional[Callable[..., str]] = None,
|
|
||||||
proxy: Optional[_ProxyType] = None,
|
|
||||||
):
|
|
||||||
super(AiohttpSession, self).__init__(
|
|
||||||
api=api, json_loads=json_loads, json_dumps=json_dumps, proxy=proxy
|
|
||||||
)
|
|
||||||
self._session: Optional[ClientSession] = None
|
self._session: Optional[ClientSession] = None
|
||||||
self._connector_type: Type[TCPConnector] = TCPConnector
|
self._connector_type: Type[TCPConnector] = TCPConnector
|
||||||
self._connector_init: Dict[str, Any] = {}
|
self._connector_init: Dict[str, Any] = {}
|
||||||
|
self._should_reset_connector = True # flag determines connector state
|
||||||
|
self._proxy: Optional[_ProxyType] = None
|
||||||
|
|
||||||
if self.proxy:
|
if proxy is not None:
|
||||||
try:
|
try:
|
||||||
self._connector_type, self._connector_init = _prepare_connector(
|
self._setup_proxy_connector(proxy)
|
||||||
cast(_ProxyType, self.proxy)
|
|
||||||
)
|
|
||||||
except ImportError as exc: # pragma: no cover
|
except ImportError as exc: # pragma: no cover
|
||||||
raise UserWarning(
|
raise UserWarning(
|
||||||
"In order to use aiohttp client for proxy requests, install "
|
"In order to use aiohttp client for proxy requests, install "
|
||||||
"https://pypi.org/project/aiohttp-socks/"
|
"https://pypi.org/project/aiohttp-socks/"
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
|
||||||
|
self._connector_type, self._connector_init = _prepare_connector(proxy)
|
||||||
|
self._proxy = proxy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def proxy(self) -> Optional[_ProxyType]:
|
||||||
|
return self._proxy
|
||||||
|
|
||||||
|
@proxy.setter
|
||||||
|
def proxy(self, proxy: _ProxyType) -> None:
|
||||||
|
self._setup_proxy_connector(proxy)
|
||||||
|
self._should_reset_connector = True
|
||||||
|
|
||||||
async def create_session(self) -> ClientSession:
|
async def create_session(self) -> ClientSession:
|
||||||
|
if self._should_reset_connector:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
if self._session is None or self._session.closed:
|
if self._session is None or self._session.closed:
|
||||||
self._session = ClientSession(connector=self._connector_type(**self._connector_init))
|
self._session = ClientSession(connector=self._connector_type(**self._connector_init))
|
||||||
|
self._should_reset_connector = False
|
||||||
|
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
|
|
@ -125,7 +132,9 @@ class AiohttpSession(BaseSession):
|
||||||
url = self.api.api_url(token=token, method=request.method)
|
url = self.api.api_url(token=token, method=request.method)
|
||||||
form = self.build_form_data(request)
|
form = self.build_form_data(request)
|
||||||
|
|
||||||
async with session.post(url, data=form) as resp:
|
async with session.post(
|
||||||
|
url, data=form, timeout=call.request_timeout or self.timeout
|
||||||
|
) as resp:
|
||||||
raw_result = await resp.json(loads=self.json_loads)
|
raw_result = await resp.json(loads=self.json_loads)
|
||||||
|
|
||||||
response = call.build_response(raw_result)
|
response = call.build_response(raw_result)
|
||||||
|
|
@ -136,9 +145,8 @@ class AiohttpSession(BaseSession):
|
||||||
self, url: str, timeout: int, chunk_size: int
|
self, url: str, timeout: int, chunk_size: int
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[bytes, None]:
|
||||||
session = await self.create_session()
|
session = await self.create_session()
|
||||||
client_timeout = ClientTimeout(total=timeout)
|
|
||||||
|
|
||||||
async with session.get(url, timeout=client_timeout) as resp:
|
async with session.get(url, timeout=timeout) as resp:
|
||||||
async for chunk in resp.content.iter_chunked(chunk_size):
|
async for chunk in resp.content.iter_chunked(chunk_size):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import abc
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, AsyncGenerator, Callable, Optional, Type, TypeVar, Union
|
from typing import Any, AsyncGenerator, Callable, ClassVar, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
from aiogram.utils.exceptions import TelegramAPIError
|
from aiogram.utils.exceptions import TelegramAPIError
|
||||||
|
|
||||||
|
|
@ -12,30 +12,58 @@ from ...methods import Response, TelegramMethod
|
||||||
from ..telegram import PRODUCTION, TelegramAPIServer
|
from ..telegram import PRODUCTION, TelegramAPIServer
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
PT = TypeVar("PT")
|
_JsonLoads = Callable[..., Any]
|
||||||
|
_JsonDumps = Callable[..., str]
|
||||||
|
|
||||||
|
|
||||||
class BaseSession(abc.ABC):
|
class BaseSession(abc.ABC):
|
||||||
def __init__(
|
# global session timeout
|
||||||
self,
|
default_timeout: ClassVar[float] = 60.0
|
||||||
api: Optional[TelegramAPIServer] = None,
|
|
||||||
json_loads: Optional[Callable[..., Any]] = None,
|
|
||||||
json_dumps: Optional[Callable[..., str]] = None,
|
|
||||||
proxy: Optional[PT] = None,
|
|
||||||
) -> None:
|
|
||||||
if api is None:
|
|
||||||
api = PRODUCTION
|
|
||||||
if json_loads is None:
|
|
||||||
json_loads = json.loads
|
|
||||||
if json_dumps is None:
|
|
||||||
json_dumps = json.dumps
|
|
||||||
|
|
||||||
self.api = api
|
_api: TelegramAPIServer
|
||||||
self.json_loads = json_loads
|
_json_loads: _JsonLoads
|
||||||
self.json_dumps = json_dumps
|
_json_dumps: _JsonDumps
|
||||||
self.proxy = proxy
|
_timeout: float
|
||||||
|
|
||||||
def raise_for_status(self, response: Response[T]) -> None:
|
@property
|
||||||
|
def api(self) -> TelegramAPIServer:
|
||||||
|
return getattr(self, "_api", PRODUCTION) # type: ignore
|
||||||
|
|
||||||
|
@api.setter
|
||||||
|
def api(self, value: TelegramAPIServer) -> None:
|
||||||
|
self._api = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def json_loads(self) -> _JsonLoads:
|
||||||
|
return getattr(self, "_json_loads", json.loads) # type: ignore
|
||||||
|
|
||||||
|
@json_loads.setter
|
||||||
|
def json_loads(self, value: _JsonLoads) -> None:
|
||||||
|
self._json_loads = value # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def json_dumps(self) -> _JsonDumps:
|
||||||
|
return getattr(self, "_json_dumps", json.dumps) # type: ignore
|
||||||
|
|
||||||
|
@json_dumps.setter
|
||||||
|
def json_dumps(self, value: _JsonDumps) -> None:
|
||||||
|
self._json_dumps = value # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timeout(self) -> float:
|
||||||
|
return getattr(self, "_timeout", self.__class__.default_timeout) # type: ignore
|
||||||
|
|
||||||
|
@timeout.setter
|
||||||
|
def timeout(self, value: float) -> None:
|
||||||
|
self._timeout = value
|
||||||
|
|
||||||
|
@timeout.deleter
|
||||||
|
def timeout(self) -> None:
|
||||||
|
if hasattr(self, "_timeout"):
|
||||||
|
del self._timeout
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def raise_for_status(cls, response: Response[T]) -> None:
|
||||||
if response.ok:
|
if response.ok:
|
||||||
return
|
return
|
||||||
raise TelegramAPIError(response.description)
|
raise TelegramAPIError(response.description)
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,16 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[T]):
|
||||||
def build_request(self) -> Request: # pragma: no cover
|
def build_request(self) -> Request: # pragma: no cover
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
request_timeout: Optional[float] = None
|
||||||
|
|
||||||
|
def dict(self, **kwargs: Any) -> Any:
|
||||||
|
# override dict of pydantic.BaseModel to overcome exporting request_timeout field
|
||||||
|
exclude = kwargs.pop("exclude", set())
|
||||||
|
if isinstance(exclude, set):
|
||||||
|
exclude.add("request_timeout")
|
||||||
|
|
||||||
|
return super().dict(exclude=exclude, **kwargs)
|
||||||
|
|
||||||
def build_response(self, data: Dict[str, Any]) -> Response[T]:
|
def build_response(self, data: Dict[str, Any]) -> Response[T]:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return Response[self.__returning__](**data) # type: ignore
|
return Response[self.__returning__](**data) # type: ignore
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Aiohttp session
|
# Aiohttp session
|
||||||
|
|
||||||
AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp]('https://pypi.org/project/aiohttp/')
|
AiohttpSession represents a wrapper-class around `ClientSession` from [aiohttp](https://pypi.org/project/aiohttp/ "PyPi repository"){target=_blank}
|
||||||
|
|
||||||
Currently `AiohttpSession` is a default session used in `aiogram.Bot`
|
Currently `AiohttpSession` is a default session used in `aiogram.Bot`
|
||||||
|
|
||||||
|
|
@ -17,7 +17,7 @@ Bot('token', session=session)
|
||||||
|
|
||||||
## Proxy requests in AiohttpSession
|
## Proxy requests in AiohttpSession
|
||||||
|
|
||||||
In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks]('https://pypi.org/project/aiohttp-socks/')
|
In order to use AiohttpSession with proxy connector you have to install [aiohttp-socks](https://pypi.org/project/aiohttp-socks/ "PyPi repository"){target=_blank}
|
||||||
|
|
||||||
Binding session to bot:
|
Binding session to bot:
|
||||||
```python
|
```python
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ markdown_extensions:
|
||||||
- pymdownx.inlinehilite
|
- pymdownx.inlinehilite
|
||||||
- markdown_include.include:
|
- markdown_include.include:
|
||||||
base_path: docs
|
base_path: docs
|
||||||
|
- attr_list
|
||||||
|
|
||||||
nav:
|
nav:
|
||||||
- index.md
|
- index.md
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,22 @@ class TestAiohttpSession:
|
||||||
aiohttp_session = await session.create_session()
|
aiohttp_session = await session.create_session()
|
||||||
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
|
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_connector(self):
|
||||||
|
session = AiohttpSession()
|
||||||
|
assert session._should_reset_connector
|
||||||
|
await session.create_session()
|
||||||
|
assert session._should_reset_connector is False
|
||||||
|
await session.close()
|
||||||
|
assert session._should_reset_connector is False
|
||||||
|
|
||||||
|
assert session.proxy is None
|
||||||
|
session.proxy = "socks5://auth:auth@proxy.url/"
|
||||||
|
assert session._should_reset_connector
|
||||||
|
await session.create_session()
|
||||||
|
assert session._should_reset_connector is False
|
||||||
|
await session.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_close_session(self):
|
async def test_close_session(self):
|
||||||
session = AiohttpSession()
|
session = AiohttpSession()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
from typing import AsyncContextManager, AsyncGenerator
|
from typing import AsyncContextManager, AsyncGenerator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -35,13 +36,56 @@ class TestBaseSession:
|
||||||
session = CustomSession()
|
session = CustomSession()
|
||||||
assert session.api == PRODUCTION
|
assert session.api == PRODUCTION
|
||||||
|
|
||||||
|
def test_default_props(self):
|
||||||
|
session = CustomSession()
|
||||||
|
assert session.api == PRODUCTION
|
||||||
|
assert session.json_loads == json.loads
|
||||||
|
assert session.json_dumps == json.dumps
|
||||||
|
|
||||||
|
def custom_loads(*_):
|
||||||
|
return json.loads
|
||||||
|
|
||||||
|
def custom_dumps(*_):
|
||||||
|
return json.dumps
|
||||||
|
|
||||||
|
session.json_dumps = custom_dumps
|
||||||
|
assert session.json_dumps == custom_dumps == session._json_dumps
|
||||||
|
session.json_loads = custom_loads
|
||||||
|
assert session.json_loads == custom_loads == session._json_loads
|
||||||
|
|
||||||
|
different_session = CustomSession()
|
||||||
|
assert all(
|
||||||
|
not hasattr(different_session, attr) for attr in ("_json_loads", "_json_dumps", "_api")
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_timeout(self):
|
||||||
|
session = CustomSession()
|
||||||
|
assert session.timeout == session.default_timeout == CustomSession.default_timeout
|
||||||
|
|
||||||
|
session.default_timeout = float(65.0_0) # mypy will complain
|
||||||
|
assert session.timeout != session.default_timeout
|
||||||
|
|
||||||
|
CustomSession.default_timeout = float(68.0_0)
|
||||||
|
assert session.timeout == CustomSession.default_timeout
|
||||||
|
|
||||||
|
session.timeout = float(71.0_0)
|
||||||
|
assert session.timeout != session.default_timeout
|
||||||
|
del session.timeout
|
||||||
|
CustomSession.default_timeout = session.default_timeout + 100
|
||||||
|
assert (
|
||||||
|
session.timeout != BaseSession.default_timeout
|
||||||
|
and session.timeout == CustomSession.default_timeout
|
||||||
|
)
|
||||||
|
|
||||||
def test_init_custom_api(self):
|
def test_init_custom_api(self):
|
||||||
api = TelegramAPIServer(
|
api = TelegramAPIServer(
|
||||||
base="http://example.com/{token}/{method}",
|
base="http://example.com/{token}/{method}",
|
||||||
file="http://example.com/{token}/file/{path{",
|
file="http://example.com/{token}/file/{path}",
|
||||||
)
|
)
|
||||||
session = CustomSession(api=api)
|
session = CustomSession()
|
||||||
|
session.api = api
|
||||||
assert session.api == api
|
assert session.api == api
|
||||||
|
assert "example.com" in session.api.base
|
||||||
|
|
||||||
def test_prepare_value(self):
|
def test_prepare_value(self):
|
||||||
session = CustomSession()
|
session = CustomSession()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue