Merge remote-tracking branch 'origin/dev-3.x-download' into dev-3.x

This commit is contained in:
Alex Root Junior 2020-02-02 22:53:55 +02:00
commit a41bccddf9
5 changed files with 66 additions and 6 deletions

View file

@ -1,8 +1,8 @@
from __future__ import annotations
from typing import Callable, Optional, TypeVar, cast
from typing import AsyncGenerator, Callable, Optional, TypeVar, cast
from aiohttp import ClientSession, FormData
from aiohttp import ClientSession, ClientTimeout, FormData
from aiogram.api.methods import Request, TelegramMethod
@ -56,6 +56,16 @@ class AiohttpSession(BaseSession):
self.raise_for_status(response)
return cast(T, response.result)
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]:
session = await self.create_session()
client_timeout = ClientTimeout(total=timeout)
async with session.get(url, timeout=client_timeout) as resp:
async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk
async def __aenter__(self) -> AiohttpSession:
await self.create_session()
return self

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import abc
import datetime
import json
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, Union
from aiogram.utils.exceptions import TelegramAPIError
@ -44,6 +44,12 @@ class BaseSession(abc.ABC):
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
pass
@abc.abstractmethod
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: # pragma: no cover
yield b""
def prepare_value(self, value: Any) -> Union[str, int, bool]:
if isinstance(value, str):
return value

View file

@ -1,5 +1,5 @@
from collections import deque
from typing import TYPE_CHECKING, Deque, Optional, Type
from typing import TYPE_CHECKING, AsyncGenerator, Deque, Optional, Type
from aiogram import Bot
from aiogram.api.client.session.base import BaseSession
@ -29,6 +29,11 @@ class MockedSession(BaseSession):
self.raise_for_status(response)
return response.result # type: ignore
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: # pragma: no cover
yield b""
class MockedBot(Bot):
if TYPE_CHECKING:

View file

@ -1,4 +1,4 @@
from typing import AsyncContextManager
from typing import AsyncContextManager, AsyncGenerator
import aiohttp
import pytest
@ -107,6 +107,26 @@ class TestAiohttpSession:
assert patched_raise_for_status.called_once()
@pytest.mark.asyncio
async def test_stream_content(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
session = AiohttpSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=1
)
assert isinstance(stream, AsyncGenerator)
size = 0
async for chunk in stream:
assert isinstance(chunk, bytes)
chunk_size = len(chunk)
assert chunk_size == 1
size += chunk_size
assert size == 10
@pytest.mark.asyncio
async def test_context_manager(self):
session = AiohttpSession()

View file

@ -1,5 +1,5 @@
import datetime
from typing import AsyncContextManager
from typing import AsyncContextManager, AsyncGenerator
import pytest
@ -22,6 +22,14 @@ class CustomSession(BaseSession):
assert isinstance(token, str)
assert isinstance(method, TelegramMethod)
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: # pragma: no cover
assert isinstance(url, str)
assert isinstance(timeout, int)
assert isinstance(chunk_size, int)
yield b"\f" * 10
class TestBaseSession(DataMixin):
def test_init_api(self):
@ -100,6 +108,17 @@ class TestBaseSession(DataMixin):
assert await session.make_request("42:TEST", GetMe()) is None
@pytest.mark.asyncio
async def test_stream_content(self):
session = CustomSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=65536
)
assert isinstance(stream, AsyncGenerator)
async for chunk in stream:
assert isinstance(chunk, bytes)
@pytest.mark.asyncio
async def test_context_manager(self):
session = CustomSession()