[3.x] Check status code when downloading file (#1079)

* Check status code when downloading file and raise an error if someting bad happends

* Style fixes

* Add doc

* Use "towncrier create <issue>.<type>" for creating file
This commit is contained in:
Dmitry Anfimov 2023-02-12 06:56:11 +06:00 committed by GitHub
parent 94e11ce8e9
commit 184ee1fbf8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 45 additions and 7 deletions

1
CHANGES/816.bugfix Normal file
View file

@ -0,0 +1 @@
Check status code when downloading file

View file

@ -304,7 +304,12 @@ class Bot(ContextInstanceMixin["Bot"]):
close_stream = True
else:
url = self.session.api.file_url(self.__token, file_path)
stream = self.session.stream_content(url=url, timeout=timeout, chunk_size=chunk_size)
stream = self.session.stream_content(
url=url,
timeout=timeout,
chunk_size=chunk_size,
raise_for_status=True,
)
try:
if isinstance(destination, (str, pathlib.Path)):

View file

@ -162,11 +162,11 @@ class AiohttpSession(BaseSession):
return cast(TelegramType, response.result)
async def stream_content(
self, url: str, timeout: int, chunk_size: int
self, url: str, timeout: int, chunk_size: int, raise_for_status: bool
) -> AsyncGenerator[bytes, None]:
session = await self.create_session()
async with session.get(url, timeout=timeout) as resp:
async with session.get(url, timeout=timeout, raise_for_status=raise_for_status) as resp:
async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk

View file

@ -153,7 +153,7 @@ class BaseSession(abc.ABC):
@abc.abstractmethod
async def stream_content(
self, url: str, timeout: int, chunk_size: int
self, url: str, timeout: int, chunk_size: int, raise_for_status: bool
) -> AsyncGenerator[bytes, None]: # pragma: no cover
"""
Stream reader

View file

@ -135,7 +135,10 @@ class URLInputFile(InputFile):
bot = Bot.get_current(no_error=False)
stream = bot.session.stream_content(
url=self.url, timeout=self.timeout, chunk_size=self.chunk_size
url=self.url,
timeout=self.timeout,
chunk_size=self.chunk_size,
raise_for_status=True,
)
async for chunk in stream:

View file

@ -37,7 +37,11 @@ class MockedSession(BaseSession):
return response.result # type: ignore
async def stream_content(
self, url: str, timeout: int, chunk_size: int
self,
url: str,
timeout: int,
chunk_size: int,
raise_for_status: bool,
) -> AsyncGenerator[bytes, None]: # pragma: no cover
yield b""

View file

@ -190,7 +190,10 @@ class TestAiohttpSession:
session = AiohttpSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png", timeout=5, chunk_size=1
"https://www.python.org/static/img/python-logo.png",
timeout=5,
chunk_size=1,
raise_for_status=True,
)
assert isinstance(stream, AsyncGenerator)
@ -202,6 +205,28 @@ class TestAiohttpSession:
size += chunk_size
assert size == 10
async def test_stream_content_404(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY,
aresponses.ANY,
"get",
aresponses.Response(
status=404,
body=b"File not found",
),
)
session = AiohttpSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png",
timeout=5,
chunk_size=1,
raise_for_status=True,
)
with pytest.raises(ClientError):
async for _ in stream:
...
async def test_context_manager(self):
session = AiohttpSession()
assert isinstance(session, AsyncContextManager)