Reworked graceful shutdown (#1124)

* Reworked graceful shutdown

* Remove special errors from polling process

* Update dependencies

* Coverage

* Added changelog
This commit is contained in:
Alex Root Junior 2023-02-18 15:46:28 +02:00 committed by GitHub
parent a332e88bc3
commit d0b7135ca6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 50 deletions

2
CHANGES/1124.misc.rst Normal file
View file

@ -0,0 +1,2 @@
Reworked graceful shutdown. Added method to stop polling.
Now polling started from dispatcher can be stopped by signals gracefully without errors (on Linux and Mac).

View file

@ -23,7 +23,7 @@ clean:
rm -rf *.egg-info
rm -f report.html
rm -f .coverage
rm -rf {build,dist,site,.cache,.mypy_cache,reports}
rm -rf {build,dist,site,.cache,.mypy_cache,.ruff_cache,reports}
# =================================================================================================
# Code quality

View file

@ -2,8 +2,10 @@ from __future__ import annotations
import asyncio
import contextvars
import signal
import warnings
from asyncio import CancelledError, Future, Lock
from asyncio import CancelledError, Event, Future, Lock
from contextlib import suppress
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from .. import loggers
@ -89,6 +91,8 @@ class Dispatcher(Router):
self.workflow_data: Dict[str, Any] = kwargs
self._running_lock = Lock()
self._stop_signal = Event()
self._stopped_signal = Event()
def __getitem__(self, item: str) -> Any:
return self.workflow_data[item]
@ -321,6 +325,11 @@ class Dispatcher(Router):
:param kwargs:
:return:
"""
user: User = await bot.me()
loggers.dispatcher.info(
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
try:
async for update in self._listen_updates(
bot,
polling_timeout=polling_timeout,
@ -332,6 +341,10 @@ class Dispatcher(Router):
asyncio.create_task(handle_update)
else:
await handle_update
finally:
loggers.dispatcher.info(
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
"""
@ -408,6 +421,24 @@ class Dispatcher(Router):
return None
async def stop_polling(self) -> None:
"""
Execute this method if you want to stop polling programmatically
:return:
"""
if not self._running_lock.locked():
raise RuntimeError("Polling is not started")
self._stop_signal.set()
await self._stopped_signal.wait()
def _signal_stop_polling(self, sig: signal.Signals) -> None:
if not self._running_lock.locked():
return
loggers.dispatcher.warning("Received %s signal", sig.name)
self._stop_signal.set()
async def start_polling(
self,
*bots: Bot,
@ -415,32 +446,54 @@ class Dispatcher(Router):
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
handle_signals: bool = True,
close_bot_session: bool = True,
**kwargs: Any,
) -> None:
"""
Polling runner
:param bots:
:param polling_timeout:
:param handle_as_tasks:
:param kwargs:
:param backoff_config:
:param allowed_updates:
:param bots: Bot instances (one or mre)
:param polling_timeout: Long-polling wait time
:param handle_as_tasks: Run task for each event and no wait result
:param backoff_config: backoff-retry config
:param allowed_updates: List of the update types you want your bot to receive
:param handle_signals: handle signals (SIGINT/SIGTERM)
:param close_bot_session: close bot sessions on shutdown
:param kwargs: contextual data
:return:
"""
if not bots:
raise ValueError("At least one bot instance is required to start polling")
if "bot" in kwargs:
raise ValueError(
"Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument"
)
async with self._running_lock: # Prevent to run this method twice at a once
self._stop_signal.clear()
self._stopped_signal.clear()
if handle_signals:
loop = asyncio.get_running_loop()
with suppress(NotImplementedError): # pragma: no cover
# Signals handling is not supported on Windows
# It also can't be covered on Windows
loop.add_signal_handler(
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
)
loop.add_signal_handler(
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
)
workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]}
workflow_data.update(kwargs)
await self.emit_startup(**workflow_data)
loggers.dispatcher.info("Start polling")
try:
coro_list = []
for bot in bots:
user: User = await bot.me()
loggers.dispatcher.info(
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
coro_list.append(
tasks: List[asyncio.Task[Any]] = [
asyncio.create_task(
self._polling(
bot=bot,
handle_as_tasks=handle_as_tasks,
@ -450,36 +503,53 @@ class Dispatcher(Router):
**kwargs,
)
)
await asyncio.gather(*coro_list)
for bot in bots
]
tasks.append(asyncio.create_task(self._stop_signal.wait()))
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
# (mostly) Graceful shutdown unfinished tasks
task.cancel()
with suppress(CancelledError):
await task
# Wait finished tasks to propagate unhandled exceptions
await asyncio.gather(*done)
finally:
loggers.dispatcher.info("Polling stopped")
try:
await self.emit_shutdown(**workflow_data)
finally:
for bot in bots: # Close sessions
await bot.session.close()
if close_bot_session:
await asyncio.gather(*(bot.session.close() for bot in bots))
self._stopped_signal.set()
def run_polling(
self,
*bots: Bot,
polling_timeout: int = 30,
polling_timeout: int = 10,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
handle_signals: bool = True,
close_bot_session: bool = True,
**kwargs: Any,
) -> None:
"""
Run many bots with polling
:param bots: Bot instances
:param polling_timeout: Poling timeout
:param backoff_config:
:param bots: Bot instances (one or mre)
:param polling_timeout: Long-polling wait time
:param handle_as_tasks: Run task for each event and no wait result
:param backoff_config: backoff-retry config
:param allowed_updates: List of the update types you want your bot to receive
:param handle_signals: handle signals (SIGINT/SIGTERM)
:param close_bot_session: close bot sessions on shutdown
:param kwargs: contextual data
:return:
"""
try:
with suppress(KeyboardInterrupt):
return asyncio.run(
self.start_polling(
*bots,
@ -488,8 +558,7 @@ class Dispatcher(Router):
handle_as_tasks=handle_as_tasks,
backoff_config=backoff_config,
allowed_updates=allowed_updates,
handle_signals=handle_signals,
close_bot_session=close_bot_session,
)
)
except (KeyboardInterrupt, SystemExit): # pragma: no cover
# Allow to graceful shutdown
pass

View file

@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
"""
if isinstance(handler, dict) and "handler" in handler:
handler = handler["handler"]
if not hasattr(handler, "flags"):
return {}
if hasattr(handler, "flags"):
return handler.flags
return {}
def get_flag(

View file

@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError):
super().__init__(message=message)
self.method = method
def __str__(self) -> str:
original_message = super().__str__()
return f"Telegram server says {original_message}"
class TelegramNetworkError(TelegramAPIError):
pass

View file

@ -41,7 +41,7 @@ classifiers = [
]
dependencies = [
"magic-filter~=1.0.9",
"aiohttp~=3.8.3",
"aiohttp~=3.8.4",
"pydantic~=1.10.4",
"aiofiles~=23.1.0",
"certifi>=2022.9.24",
@ -56,7 +56,7 @@ fast = [
"uvloop>=0.17.0; (sys_platform == 'darwin' or sys_platform == 'linux') and platform_python_implementation != 'PyPy'",
]
redis = [
"redis~=4.3.4",
"redis~=4.5.1",
]
proxy = [
"aiohttp-socks~=0.7.1",
@ -65,11 +65,11 @@ i18n = [
"Babel~=2.11.0",
]
test = [
"pytest~=7.1.3",
"pytest-html~=3.1.1",
"pytest-asyncio~=0.19.0",
"pytest~=7.2.1",
"pytest-html~=3.2.0",
"pytest-asyncio~=0.20.3",
"pytest-lazy-fixture~=0.6.3",
"pytest-mock~=3.9.0",
"pytest-mock~=3.10.0",
"pytest-mypy~=0.10.0",
"pytest-cov~=4.0.0",
"pytest-aiohttp~=1.0.4",
@ -93,12 +93,12 @@ docs = [
dev = [
"black~=23.1",
"isort~=5.11",
"ruff~=0.0.245",
"mypy~=0.981",
"ruff~=0.0.246",
"mypy~=1.0.0",
"toml~=0.10.2",
"pre-commit~=2.20.0",
"pre-commit~=3.0.4",
"packaging~=23.0",
"typing-extensions~=4.3.0",
"typing-extensions~=4.4.0",
]
[project.urls]

View file

@ -1,7 +1,9 @@
import asyncio
import datetime
import signal
import time
import warnings
from asyncio import Event
from collections import Counter
from typing import Any
from unittest.mock import AsyncMock, patch
@ -667,6 +669,17 @@ class TestDispatcher:
async def test_start_polling(self, bot: MockedBot):
dispatcher = Dispatcher()
with pytest.raises(
ValueError, match="At least one bot instance is required to start polling"
):
await dispatcher.start_polling()
with pytest.raises(
ValueError,
match="Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument",
):
await dispatcher.start_polling(bot, bot=bot)
bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
)
@ -690,6 +703,65 @@ class TestDispatcher:
mocked_process_update.assert_awaited()
mocked_emit_shutdown.assert_awaited()
async def test_stop_polling(self):
dispatcher = Dispatcher()
with pytest.raises(RuntimeError):
await dispatcher.stop_polling()
assert not dispatcher._stop_signal.is_set()
assert not dispatcher._stopped_signal.is_set()
with patch("asyncio.locks.Event.wait", new_callable=AsyncMock) as mocked_wait:
async with dispatcher._running_lock:
await dispatcher.stop_polling()
assert dispatcher._stop_signal.is_set()
mocked_wait.assert_awaited()
async def test_signal_stop_polling(self):
dispatcher = Dispatcher()
with patch("asyncio.locks.Event.set") as mocked_set:
dispatcher._signal_stop_polling(signal.SIGINT)
mocked_set.assert_not_called()
async with dispatcher._running_lock:
dispatcher._signal_stop_polling(signal.SIGINT)
mocked_set.assert_called()
async def test_stop_polling_by_method(self, bot: MockedBot):
dispatcher = Dispatcher()
bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
)
running = Event()
async def _mock_updates(*_):
running.set()
while True:
yield Update(update_id=42)
await asyncio.sleep(1)
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=AsyncMock
) as mocked_process_update, patch(
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates",
return_value=_mock_updates(),
):
task = asyncio.ensure_future(dispatcher.start_polling(bot))
await running.wait()
assert not dispatcher._stop_signal.is_set()
assert not dispatcher._stopped_signal.is_set()
await dispatcher.stop_polling()
assert dispatcher._stop_signal.is_set()
assert dispatcher._stopped_signal.is_set()
assert not task.exception()
mocked_process_update.assert_awaited()
@pytest.mark.skip("Stopping by signal should also be tested as the same as stopping by method")
async def test_stop_polling_by_signal(self, bot: MockedBot):
pass
def test_run_polling(self, bot: MockedBot):
dispatcher = Dispatcher()
with patch(