Implement stream_content in AiohttpSession and add tests

This commit is contained in:
gabbhack 2020-01-22 22:55:34 +05:00
parent 7ab0db7991
commit 26708154b0
2 changed files with 33 additions and 3 deletions

View file

@ -1,8 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, Optional, TypeVar, cast from typing import AsyncGenerator, Callable, Optional, TypeVar, cast
from aiohttp import ClientSession, FormData from aiohttp import ClientSession, ClientTimeout, FormData
from aiogram.api.methods import Request, TelegramMethod from aiogram.api.methods import Request, TelegramMethod
@ -56,6 +56,16 @@ class AiohttpSession(BaseSession):
self.raise_for_status(response) self.raise_for_status(response)
return cast(T, response.result) return cast(T, response.result)
async def stream_content(
self, url: str, timeout: int, chunk_size: int
) -> AsyncGenerator[bytes, None]:
session = await self.create_session()
client_timeout = ClientTimeout(total=timeout)
async with session.get(url, timeout=client_timeout) as resp:
async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk
async def __aenter__(self) -> AiohttpSession: async def __aenter__(self) -> AiohttpSession:
await self.create_session() await self.create_session()
return self return self

View file

@ -1,4 +1,4 @@
from typing import AsyncContextManager from typing import AsyncContextManager, AsyncGenerator
import aiohttp import aiohttp
import pytest import pytest
@ -107,6 +107,26 @@ class TestAiohttpSession:
assert patched_raise_for_status.called_once() assert patched_raise_for_status.called_once()
@pytest.mark.asyncio
async def test_stream_content(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
session = AiohttpSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=1
)
assert isinstance(stream, AsyncGenerator)
size = 0
async for chunk in stream:
assert isinstance(chunk, bytes)
chunk_size = len(chunk)
assert chunk_size == 1
size += chunk_size
assert size == 10
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_context_manager(self): async def test_context_manager(self):
session = AiohttpSession() session = AiohttpSession()