Merge remote-tracking branch 'origin/dev-3.x-download' into dev-3.x

This commit is contained in:
Alex Root Junior 2020-02-02 22:53:55 +02:00
commit a41bccddf9
5 changed files with 66 additions and 6 deletions

View file

@ -1,8 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, Optional, TypeVar, cast from typing import AsyncGenerator, Callable, Optional, TypeVar, cast
from aiohttp import ClientSession, FormData from aiohttp import ClientSession, ClientTimeout, FormData
from aiogram.api.methods import Request, TelegramMethod from aiogram.api.methods import Request, TelegramMethod
@ -56,6 +56,16 @@ class AiohttpSession(BaseSession):
self.raise_for_status(response) self.raise_for_status(response)
return cast(T, response.result) return cast(T, response.result)
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]:
session = await self.create_session()
client_timeout = ClientTimeout(total=timeout)
async with session.get(url, timeout=client_timeout) as resp:
async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk
async def __aenter__(self) -> AiohttpSession: async def __aenter__(self) -> AiohttpSession:
await self.create_session() await self.create_session()
return self return self

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import abc import abc
import datetime import datetime
import json import json
from typing import Any, Callable, Optional, TypeVar, Union from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, Union
from aiogram.utils.exceptions import TelegramAPIError from aiogram.utils.exceptions import TelegramAPIError
@ -44,6 +44,12 @@ class BaseSession(abc.ABC):
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
pass pass
@abc.abstractmethod
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: # pragma: no cover
yield b""
def prepare_value(self, value: Any) -> Union[str, int, bool]: def prepare_value(self, value: Any) -> Union[str, int, bool]:
if isinstance(value, str): if isinstance(value, str):
return value return value

View file

@ -1,5 +1,5 @@
from collections import deque from collections import deque
from typing import TYPE_CHECKING, Deque, Optional, Type from typing import TYPE_CHECKING, AsyncGenerator, Deque, Optional, Type
from aiogram import Bot from aiogram import Bot
from aiogram.api.client.session.base import BaseSession from aiogram.api.client.session.base import BaseSession
@ -29,6 +29,11 @@ class MockedSession(BaseSession):
self.raise_for_status(response) self.raise_for_status(response)
return response.result # type: ignore return response.result # type: ignore
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: # pragma: no cover
yield b""
class MockedBot(Bot): class MockedBot(Bot):
if TYPE_CHECKING: if TYPE_CHECKING:

View file

@ -1,4 +1,4 @@
from typing import AsyncContextManager from typing import AsyncContextManager, AsyncGenerator
import aiohttp import aiohttp
import pytest import pytest
@ -107,6 +107,26 @@ class TestAiohttpSession:
assert patched_raise_for_status.called_once() assert patched_raise_for_status.called_once()
@pytest.mark.asyncio
async def test_stream_content(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
session = AiohttpSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=1
)
assert isinstance(stream, AsyncGenerator)
size = 0
async for chunk in stream:
assert isinstance(chunk, bytes)
chunk_size = len(chunk)
assert chunk_size == 1
size += chunk_size
assert size == 10
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_context_manager(self): async def test_context_manager(self):
session = AiohttpSession() session = AiohttpSession()

View file

@ -1,5 +1,5 @@
import datetime import datetime
from typing import AsyncContextManager from typing import AsyncContextManager, AsyncGenerator
import pytest import pytest
@ -22,6 +22,14 @@ class CustomSession(BaseSession):
assert isinstance(token, str) assert isinstance(token, str)
assert isinstance(method, TelegramMethod) assert isinstance(method, TelegramMethod)
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]: # pragma: no cover
assert isinstance(url, str)
assert isinstance(timeout, int)
assert isinstance(chunk_size, int)
yield b"\f" * 10
class TestBaseSession(DataMixin): class TestBaseSession(DataMixin):
def test_init_api(self): def test_init_api(self):
@ -100,6 +108,17 @@ class TestBaseSession(DataMixin):
assert await session.make_request("42:TEST", GetMe()) is None assert await session.make_request("42:TEST", GetMe()) is None
@pytest.mark.asyncio
async def test_stream_content(self):
session = CustomSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=65536
)
assert isinstance(stream, AsyncGenerator)
async for chunk in stream:
assert isinstance(chunk, bytes)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_context_manager(self): async def test_context_manager(self):
session = CustomSession() session = CustomSession()