mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 01:54:53 +00:00
Merge remote-tracking branch 'origin/dev-3.x-download' into dev-3.x
This commit is contained in:
commit
a41bccddf9
5 changed files with 66 additions and 6 deletions
|
|
@ -1,8 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Callable, Optional, TypeVar, cast
|
from typing import AsyncGenerator, Callable, Optional, TypeVar, cast
|
||||||
|
|
||||||
from aiohttp import ClientSession, FormData
|
from aiohttp import ClientSession, ClientTimeout, FormData
|
||||||
|
|
||||||
from aiogram.api.methods import Request, TelegramMethod
|
from aiogram.api.methods import Request, TelegramMethod
|
||||||
|
|
||||||
|
|
@ -56,6 +56,16 @@ class AiohttpSession(BaseSession):
|
||||||
self.raise_for_status(response)
|
self.raise_for_status(response)
|
||||||
return cast(T, response.result)
|
return cast(T, response.result)
|
||||||
|
|
||||||
|
async def stream_content(
|
||||||
|
self, url: str, timeout: int, chunk_size: int
|
||||||
|
) -> AsyncGenerator[bytes, None]:
|
||||||
|
session = await self.create_session()
|
||||||
|
client_timeout = ClientTimeout(total=timeout)
|
||||||
|
|
||||||
|
async with session.get(url, timeout=client_timeout) as resp:
|
||||||
|
async for chunk in resp.content.iter_chunked(chunk_size):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def __aenter__(self) -> AiohttpSession:
|
async def __aenter__(self) -> AiohttpSession:
|
||||||
await self.create_session()
|
await self.create_session()
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable, Optional, TypeVar, Union
|
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
from aiogram.utils.exceptions import TelegramAPIError
|
from aiogram.utils.exceptions import TelegramAPIError
|
||||||
|
|
||||||
|
|
@ -44,6 +44,12 @@ class BaseSession(abc.ABC):
|
||||||
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
|
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def stream_content(
|
||||||
|
self, url: str, timeout: int, chunk_size: int
|
||||||
|
) -> AsyncGenerator[bytes, None]: # pragma: no cover
|
||||||
|
yield b""
|
||||||
|
|
||||||
def prepare_value(self, value: Any) -> Union[str, int, bool]:
|
def prepare_value(self, value: Any) -> Union[str, int, bool]:
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value
|
return value
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import TYPE_CHECKING, Deque, Optional, Type
|
from typing import TYPE_CHECKING, AsyncGenerator, Deque, Optional, Type
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.api.client.session.base import BaseSession
|
from aiogram.api.client.session.base import BaseSession
|
||||||
|
|
@ -29,6 +29,11 @@ class MockedSession(BaseSession):
|
||||||
self.raise_for_status(response)
|
self.raise_for_status(response)
|
||||||
return response.result # type: ignore
|
return response.result # type: ignore
|
||||||
|
|
||||||
|
async def stream_content(
|
||||||
|
self, url: str, timeout: int, chunk_size: int
|
||||||
|
) -> AsyncGenerator[bytes, None]: # pragma: no cover
|
||||||
|
yield b""
|
||||||
|
|
||||||
|
|
||||||
class MockedBot(Bot):
|
class MockedBot(Bot):
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import AsyncContextManager
|
from typing import AsyncContextManager, AsyncGenerator
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -107,6 +107,26 @@ class TestAiohttpSession:
|
||||||
|
|
||||||
assert patched_raise_for_status.called_once()
|
assert patched_raise_for_status.called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_content(self, aresponses: ResponsesMockServer):
|
||||||
|
aresponses.add(
|
||||||
|
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
session = AiohttpSession()
|
||||||
|
stream = session.stream_content(
|
||||||
|
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=1
|
||||||
|
)
|
||||||
|
assert isinstance(stream, AsyncGenerator)
|
||||||
|
|
||||||
|
size = 0
|
||||||
|
async for chunk in stream:
|
||||||
|
assert isinstance(chunk, bytes)
|
||||||
|
chunk_size = len(chunk)
|
||||||
|
assert chunk_size == 1
|
||||||
|
size += chunk_size
|
||||||
|
assert size == 10
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_context_manager(self):
|
async def test_context_manager(self):
|
||||||
session = AiohttpSession()
|
session = AiohttpSession()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import datetime
|
import datetime
|
||||||
from typing import AsyncContextManager
|
from typing import AsyncContextManager, AsyncGenerator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -22,6 +22,14 @@ class CustomSession(BaseSession):
|
||||||
assert isinstance(token, str)
|
assert isinstance(token, str)
|
||||||
assert isinstance(method, TelegramMethod)
|
assert isinstance(method, TelegramMethod)
|
||||||
|
|
||||||
|
async def stream_content(
|
||||||
|
self, url: str, timeout: int, chunk_size: int
|
||||||
|
) -> AsyncGenerator[bytes, None]: # pragma: no cover
|
||||||
|
assert isinstance(url, str)
|
||||||
|
assert isinstance(timeout, int)
|
||||||
|
assert isinstance(chunk_size, int)
|
||||||
|
yield b"\f" * 10
|
||||||
|
|
||||||
|
|
||||||
class TestBaseSession(DataMixin):
|
class TestBaseSession(DataMixin):
|
||||||
def test_init_api(self):
|
def test_init_api(self):
|
||||||
|
|
@ -100,6 +108,17 @@ class TestBaseSession(DataMixin):
|
||||||
|
|
||||||
assert await session.make_request("42:TEST", GetMe()) is None
|
assert await session.make_request("42:TEST", GetMe()) is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_content(self):
|
||||||
|
session = CustomSession()
|
||||||
|
stream = session.stream_content(
|
||||||
|
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=65536
|
||||||
|
)
|
||||||
|
assert isinstance(stream, AsyncGenerator)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
assert isinstance(chunk, bytes)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_context_manager(self):
|
async def test_context_manager(self):
|
||||||
session = CustomSession()
|
session = CustomSession()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue