Reworked graceful shutdown (#1124)

* Reworked graceful shutdown

* Remove special errors from polling process

* Update dependencies

* Coverage

* Added changelog
This commit is contained in:
Alex Root Junior 2023-02-18 15:46:28 +02:00 committed by GitHub
parent a332e88bc3
commit d0b7135ca6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 50 deletions

2
CHANGES/1124.misc.rst Normal file
View file

@ -0,0 +1,2 @@
Reworked graceful shutdown. Added method to stop polling.
Now polling started from dispatcher can be stopped by signals gracefully without errors (on Linux and Mac).

View file

@ -23,7 +23,7 @@ clean:
rm -rf *.egg-info rm -rf *.egg-info
rm -f report.html rm -f report.html
rm -f .coverage rm -f .coverage
rm -rf {build,dist,site,.cache,.mypy_cache,reports} rm -rf {build,dist,site,.cache,.mypy_cache,.ruff_cache,reports}
# ================================================================================================= # =================================================================================================
# Code quality # Code quality

View file

@ -2,8 +2,10 @@ from __future__ import annotations
import asyncio import asyncio
import contextvars import contextvars
import signal
import warnings import warnings
from asyncio import CancelledError, Future, Lock from asyncio import CancelledError, Event, Future, Lock
from contextlib import suppress
from typing import Any, AsyncGenerator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from .. import loggers from .. import loggers
@ -89,6 +91,8 @@ class Dispatcher(Router):
self.workflow_data: Dict[str, Any] = kwargs self.workflow_data: Dict[str, Any] = kwargs
self._running_lock = Lock() self._running_lock = Lock()
self._stop_signal = Event()
self._stopped_signal = Event()
def __getitem__(self, item: str) -> Any: def __getitem__(self, item: str) -> Any:
return self.workflow_data[item] return self.workflow_data[item]
@ -321,6 +325,11 @@ class Dispatcher(Router):
:param kwargs: :param kwargs:
:return: :return:
""" """
user: User = await bot.me()
loggers.dispatcher.info(
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
try:
async for update in self._listen_updates( async for update in self._listen_updates(
bot, bot,
polling_timeout=polling_timeout, polling_timeout=polling_timeout,
@ -332,6 +341,10 @@ class Dispatcher(Router):
asyncio.create_task(handle_update) asyncio.create_task(handle_update)
else: else:
await handle_update await handle_update
finally:
loggers.dispatcher.info(
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
""" """
@ -408,6 +421,24 @@ class Dispatcher(Router):
return None return None
async def stop_polling(self) -> None:
"""
Execute this method if you want to stop polling programmatically
:return:
"""
if not self._running_lock.locked():
raise RuntimeError("Polling is not started")
self._stop_signal.set()
await self._stopped_signal.wait()
def _signal_stop_polling(self, sig: signal.Signals) -> None:
if not self._running_lock.locked():
return
loggers.dispatcher.warning("Received %s signal", sig.name)
self._stop_signal.set()
async def start_polling( async def start_polling(
self, self,
*bots: Bot, *bots: Bot,
@ -415,32 +446,54 @@ class Dispatcher(Router):
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None, allowed_updates: Optional[List[str]] = None,
handle_signals: bool = True,
close_bot_session: bool = True,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
Polling runner Polling runner
:param bots: :param bots: Bot instances (one or mre)
:param polling_timeout: :param polling_timeout: Long-polling wait time
:param handle_as_tasks: :param handle_as_tasks: Run task for each event and no wait result
:param kwargs: :param backoff_config: backoff-retry config
:param backoff_config: :param allowed_updates: List of the update types you want your bot to receive
:param allowed_updates: :param handle_signals: handle signals (SIGINT/SIGTERM)
:param close_bot_session: close bot sessions on shutdown
:param kwargs: contextual data
:return: :return:
""" """
if not bots:
raise ValueError("At least one bot instance is required to start polling")
if "bot" in kwargs:
raise ValueError(
"Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument"
)
async with self._running_lock: # Prevent to run this method twice at a once async with self._running_lock: # Prevent to run this method twice at a once
self._stop_signal.clear()
self._stopped_signal.clear()
if handle_signals:
loop = asyncio.get_running_loop()
with suppress(NotImplementedError): # pragma: no cover
# Signals handling is not supported on Windows
# It also can't be covered on Windows
loop.add_signal_handler(
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
)
loop.add_signal_handler(
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
)
workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]} workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]}
workflow_data.update(kwargs) workflow_data.update(kwargs)
await self.emit_startup(**workflow_data) await self.emit_startup(**workflow_data)
loggers.dispatcher.info("Start polling") loggers.dispatcher.info("Start polling")
try: try:
coro_list = [] tasks: List[asyncio.Task[Any]] = [
for bot in bots: asyncio.create_task(
user: User = await bot.me()
loggers.dispatcher.info(
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
coro_list.append(
self._polling( self._polling(
bot=bot, bot=bot,
handle_as_tasks=handle_as_tasks, handle_as_tasks=handle_as_tasks,
@ -450,36 +503,53 @@ class Dispatcher(Router):
**kwargs, **kwargs,
) )
) )
await asyncio.gather(*coro_list) for bot in bots
]
tasks.append(asyncio.create_task(self._stop_signal.wait()))
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
# (mostly) Graceful shutdown unfinished tasks
task.cancel()
with suppress(CancelledError):
await task
# Wait finished tasks to propagate unhandled exceptions
await asyncio.gather(*done)
finally: finally:
loggers.dispatcher.info("Polling stopped") loggers.dispatcher.info("Polling stopped")
try: try:
await self.emit_shutdown(**workflow_data) await self.emit_shutdown(**workflow_data)
finally: finally:
for bot in bots: # Close sessions if close_bot_session:
await bot.session.close() await asyncio.gather(*(bot.session.close() for bot in bots))
self._stopped_signal.set()
def run_polling( def run_polling(
self, self,
*bots: Bot, *bots: Bot,
polling_timeout: int = 30, polling_timeout: int = 10,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None, allowed_updates: Optional[List[str]] = None,
handle_signals: bool = True,
close_bot_session: bool = True,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
Run many bots with polling Run many bots with polling
:param bots: Bot instances :param bots: Bot instances (one or mre)
:param polling_timeout: Poling timeout :param polling_timeout: Long-polling wait time
:param backoff_config:
:param handle_as_tasks: Run task for each event and no wait result :param handle_as_tasks: Run task for each event and no wait result
:param backoff_config: backoff-retry config
:param allowed_updates: List of the update types you want your bot to receive :param allowed_updates: List of the update types you want your bot to receive
:param handle_signals: handle signals (SIGINT/SIGTERM)
:param close_bot_session: close bot sessions on shutdown
:param kwargs: contextual data :param kwargs: contextual data
:return: :return:
""" """
try: with suppress(KeyboardInterrupt):
return asyncio.run( return asyncio.run(
self.start_polling( self.start_polling(
*bots, *bots,
@ -488,8 +558,7 @@ class Dispatcher(Router):
handle_as_tasks=handle_as_tasks, handle_as_tasks=handle_as_tasks,
backoff_config=backoff_config, backoff_config=backoff_config,
allowed_updates=allowed_updates, allowed_updates=allowed_updates,
handle_signals=handle_signals,
close_bot_session=close_bot_session,
) )
) )
except (KeyboardInterrupt, SystemExit): # pragma: no cover
# Allow to graceful shutdown
pass

View file

@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
""" """
if isinstance(handler, dict) and "handler" in handler: if isinstance(handler, dict) and "handler" in handler:
handler = handler["handler"] handler = handler["handler"]
if not hasattr(handler, "flags"): if hasattr(handler, "flags"):
return {}
return handler.flags return handler.flags
return {}
def get_flag( def get_flag(

View file

@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError):
super().__init__(message=message) super().__init__(message=message)
self.method = method self.method = method
def __str__(self) -> str:
original_message = super().__str__()
return f"Telegram server says {original_message}"
class TelegramNetworkError(TelegramAPIError): class TelegramNetworkError(TelegramAPIError):
pass pass

View file

@ -41,7 +41,7 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"magic-filter~=1.0.9", "magic-filter~=1.0.9",
"aiohttp~=3.8.3", "aiohttp~=3.8.4",
"pydantic~=1.10.4", "pydantic~=1.10.4",
"aiofiles~=23.1.0", "aiofiles~=23.1.0",
"certifi>=2022.9.24", "certifi>=2022.9.24",
@ -56,7 +56,7 @@ fast = [
"uvloop>=0.17.0; (sys_platform == 'darwin' or sys_platform == 'linux') and platform_python_implementation != 'PyPy'", "uvloop>=0.17.0; (sys_platform == 'darwin' or sys_platform == 'linux') and platform_python_implementation != 'PyPy'",
] ]
redis = [ redis = [
"redis~=4.3.4", "redis~=4.5.1",
] ]
proxy = [ proxy = [
"aiohttp-socks~=0.7.1", "aiohttp-socks~=0.7.1",
@ -65,11 +65,11 @@ i18n = [
"Babel~=2.11.0", "Babel~=2.11.0",
] ]
test = [ test = [
"pytest~=7.1.3", "pytest~=7.2.1",
"pytest-html~=3.1.1", "pytest-html~=3.2.0",
"pytest-asyncio~=0.19.0", "pytest-asyncio~=0.20.3",
"pytest-lazy-fixture~=0.6.3", "pytest-lazy-fixture~=0.6.3",
"pytest-mock~=3.9.0", "pytest-mock~=3.10.0",
"pytest-mypy~=0.10.0", "pytest-mypy~=0.10.0",
"pytest-cov~=4.0.0", "pytest-cov~=4.0.0",
"pytest-aiohttp~=1.0.4", "pytest-aiohttp~=1.0.4",
@ -93,12 +93,12 @@ docs = [
dev = [ dev = [
"black~=23.1", "black~=23.1",
"isort~=5.11", "isort~=5.11",
"ruff~=0.0.245", "ruff~=0.0.246",
"mypy~=0.981", "mypy~=1.0.0",
"toml~=0.10.2", "toml~=0.10.2",
"pre-commit~=2.20.0", "pre-commit~=3.0.4",
"packaging~=23.0", "packaging~=23.0",
"typing-extensions~=4.3.0", "typing-extensions~=4.4.0",
] ]
[project.urls] [project.urls]

View file

@ -1,7 +1,9 @@
import asyncio import asyncio
import datetime import datetime
import signal
import time import time
import warnings import warnings
from asyncio import Event
from collections import Counter from collections import Counter
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@ -667,6 +669,17 @@ class TestDispatcher:
async def test_start_polling(self, bot: MockedBot): async def test_start_polling(self, bot: MockedBot):
dispatcher = Dispatcher() dispatcher = Dispatcher()
with pytest.raises(
ValueError, match="At least one bot instance is required to start polling"
):
await dispatcher.start_polling()
with pytest.raises(
ValueError,
match="Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument",
):
await dispatcher.start_polling(bot, bot=bot)
bot.add_result_for( bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot") GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
) )
@ -690,6 +703,65 @@ class TestDispatcher:
mocked_process_update.assert_awaited() mocked_process_update.assert_awaited()
mocked_emit_shutdown.assert_awaited() mocked_emit_shutdown.assert_awaited()
async def test_stop_polling(self):
dispatcher = Dispatcher()
with pytest.raises(RuntimeError):
await dispatcher.stop_polling()
assert not dispatcher._stop_signal.is_set()
assert not dispatcher._stopped_signal.is_set()
with patch("asyncio.locks.Event.wait", new_callable=AsyncMock) as mocked_wait:
async with dispatcher._running_lock:
await dispatcher.stop_polling()
assert dispatcher._stop_signal.is_set()
mocked_wait.assert_awaited()
async def test_signal_stop_polling(self):
dispatcher = Dispatcher()
with patch("asyncio.locks.Event.set") as mocked_set:
dispatcher._signal_stop_polling(signal.SIGINT)
mocked_set.assert_not_called()
async with dispatcher._running_lock:
dispatcher._signal_stop_polling(signal.SIGINT)
mocked_set.assert_called()
async def test_stop_polling_by_method(self, bot: MockedBot):
dispatcher = Dispatcher()
bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
)
running = Event()
async def _mock_updates(*_):
running.set()
while True:
yield Update(update_id=42)
await asyncio.sleep(1)
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=AsyncMock
) as mocked_process_update, patch(
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates",
return_value=_mock_updates(),
):
task = asyncio.ensure_future(dispatcher.start_polling(bot))
await running.wait()
assert not dispatcher._stop_signal.is_set()
assert not dispatcher._stopped_signal.is_set()
await dispatcher.stop_polling()
assert dispatcher._stop_signal.is_set()
assert dispatcher._stopped_signal.is_set()
assert not task.exception()
mocked_process_update.assert_awaited()
@pytest.mark.skip("Stopping by signal should also be tested as the same as stopping by method")
async def test_stop_polling_by_signal(self, bot: MockedBot):
pass
def test_run_polling(self, bot: MockedBot): def test_run_polling(self, bot: MockedBot):
dispatcher = Dispatcher() dispatcher = Dispatcher()
with patch( with patch(