diff --git a/CHANGES/1124.misc.rst b/CHANGES/1124.misc.rst new file mode 100644 index 00000000..4a6e6221 --- /dev/null +++ b/CHANGES/1124.misc.rst @@ -0,0 +1,2 @@ +Reworked graceful shutdown. Added method to stop polling. +Now polling started from dispatcher can be stopped by signals gracefully without errors (on Linux and Mac). diff --git a/Makefile b/Makefile index cbe5c333..5696616b 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ clean: rm -rf *.egg-info rm -f report.html rm -f .coverage - rm -rf {build,dist,site,.cache,.mypy_cache,reports} + rm -rf {build,dist,site,.cache,.mypy_cache,.ruff_cache,reports} # ================================================================================================= # Code quality diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 7f058ec2..63eb9a3a 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -2,8 +2,10 @@ from __future__ import annotations import asyncio import contextvars +import signal import warnings -from asyncio import CancelledError, Future, Lock +from asyncio import CancelledError, Event, Future, Lock +from contextlib import suppress from typing import Any, AsyncGenerator, Dict, List, Optional, Union from .. import loggers @@ -89,6 +91,8 @@ class Dispatcher(Router): self.workflow_data: Dict[str, Any] = kwargs self._running_lock = Lock() + self._stop_signal = Event() + self._stopped_signal = Event() def __getitem__(self, item: str) -> Any: return self.workflow_data[item] @@ -321,17 +325,26 @@ class Dispatcher(Router): :param kwargs: :return: """ - async for update in self._listen_updates( - bot, - polling_timeout=polling_timeout, - backoff_config=backoff_config, - allowed_updates=allowed_updates, - ): - handle_update = self._process_update(bot=bot, update=update, **kwargs) - if handle_as_tasks: - asyncio.create_task(handle_update) - else: - await handle_update + user: User = await bot.me() + loggers.dispatcher.info( + "Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name + ) + try: + async for update in self._listen_updates( + bot, + polling_timeout=polling_timeout, + backoff_config=backoff_config, + allowed_updates=allowed_updates, + ): + handle_update = self._process_update(bot=bot, update=update, **kwargs) + if handle_as_tasks: + asyncio.create_task(handle_update) + else: + await handle_update + finally: + loggers.dispatcher.info( + "Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name + ) async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: """ @@ -408,6 +421,24 @@ class Dispatcher(Router): return None + async def stop_polling(self) -> None: + """ + Execute this method if you want to stop polling programmatically + + :return: + """ + if not self._running_lock.locked(): + raise RuntimeError("Polling is not started") + self._stop_signal.set() + await self._stopped_signal.wait() + + def _signal_stop_polling(self, sig: signal.Signals) -> None: + if not self._running_lock.locked(): + return + + loggers.dispatcher.warning("Received %s signal", sig.name) + self._stop_signal.set() + async def start_polling( self, *bots: Bot, @@ -415,32 +446,54 @@ class Dispatcher(Router): handle_as_tasks: bool = True, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, allowed_updates: Optional[List[str]] = None, + handle_signals: bool = True, + close_bot_session: bool = True, **kwargs: Any, ) -> None: """ Polling runner - :param bots: - :param polling_timeout: - :param handle_as_tasks: - :param kwargs: - :param backoff_config: - :param allowed_updates: + :param bots: Bot instances (one or mre) + :param polling_timeout: Long-polling wait time + :param handle_as_tasks: Run task for each event and no wait result + :param backoff_config: backoff-retry config + :param allowed_updates: List of the update types you want your bot to receive + :param handle_signals: handle signals (SIGINT/SIGTERM) + :param close_bot_session: close bot sessions on shutdown + :param kwargs: contextual data :return: """ + if not bots: + raise ValueError("At least one bot instance is required to start polling") + if "bot" in kwargs: + raise ValueError( + "Keyword argument 'bot' is not acceptable, " + "the bot instance should be passed as positional argument" + ) + async with self._running_lock: # Prevent to run this method twice at a once + self._stop_signal.clear() + self._stopped_signal.clear() + + if handle_signals: + loop = asyncio.get_running_loop() + with suppress(NotImplementedError): # pragma: no cover + # Signals handling is not supported on Windows + # It also can't be covered on Windows + loop.add_signal_handler( + signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM + ) + loop.add_signal_handler( + signal.SIGINT, self._signal_stop_polling, signal.SIGINT + ) + workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]} workflow_data.update(kwargs) await self.emit_startup(**workflow_data) loggers.dispatcher.info("Start polling") try: - coro_list = [] - for bot in bots: - user: User = await bot.me() - loggers.dispatcher.info( - "Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name - ) - coro_list.append( + tasks: List[asyncio.Task[Any]] = [ + asyncio.create_task( self._polling( bot=bot, handle_as_tasks=handle_as_tasks, @@ -450,36 +503,53 @@ class Dispatcher(Router): **kwargs, ) ) - await asyncio.gather(*coro_list) + for bot in bots + ] + tasks.append(asyncio.create_task(self._stop_signal.wait())) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + for task in pending: + # (mostly) Graceful shutdown unfinished tasks + task.cancel() + with suppress(CancelledError): + await task + # Wait finished tasks to propagate unhandled exceptions + await asyncio.gather(*done) + finally: loggers.dispatcher.info("Polling stopped") try: await self.emit_shutdown(**workflow_data) finally: - for bot in bots: # Close sessions - await bot.session.close() + if close_bot_session: + await asyncio.gather(*(bot.session.close() for bot in bots)) + self._stopped_signal.set() def run_polling( self, *bots: Bot, - polling_timeout: int = 30, + polling_timeout: int = 10, handle_as_tasks: bool = True, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, allowed_updates: Optional[List[str]] = None, + handle_signals: bool = True, + close_bot_session: bool = True, **kwargs: Any, ) -> None: """ Run many bots with polling - :param bots: Bot instances - :param polling_timeout: Poling timeout - :param backoff_config: + :param bots: Bot instances (one or mre) + :param polling_timeout: Long-polling wait time :param handle_as_tasks: Run task for each event and no wait result + :param backoff_config: backoff-retry config :param allowed_updates: List of the update types you want your bot to receive + :param handle_signals: handle signals (SIGINT/SIGTERM) + :param close_bot_session: close bot sessions on shutdown :param kwargs: contextual data :return: """ - try: + with suppress(KeyboardInterrupt): return asyncio.run( self.start_polling( *bots, @@ -488,8 +558,7 @@ class Dispatcher(Router): handle_as_tasks=handle_as_tasks, backoff_config=backoff_config, allowed_updates=allowed_updates, + handle_signals=handle_signals, + close_bot_session=close_bot_session, ) ) - except (KeyboardInterrupt, SystemExit): # pragma: no cover - # Allow to graceful shutdown - pass diff --git a/aiogram/dispatcher/flags.py b/aiogram/dispatcher/flags.py index 82cdbc04..aad8a29f 100644 --- a/aiogram/dispatcher/flags.py +++ b/aiogram/dispatcher/flags.py @@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str, """ if isinstance(handler, dict) and "handler" in handler: handler = handler["handler"] - if not hasattr(handler, "flags"): - return {} - return handler.flags + if hasattr(handler, "flags"): + return handler.flags + return {} def get_flag( diff --git a/aiogram/exceptions.py b/aiogram/exceptions.py index dcf433c6..1c1e59fb 100644 --- a/aiogram/exceptions.py +++ b/aiogram/exceptions.py @@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError): super().__init__(message=message) self.method = method + def __str__(self) -> str: + original_message = super().__str__() + return f"Telegram server says {original_message}" + class TelegramNetworkError(TelegramAPIError): pass diff --git a/pyproject.toml b/pyproject.toml index 0c2d3a5e..45809b52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ ] dependencies = [ "magic-filter~=1.0.9", - "aiohttp~=3.8.3", + "aiohttp~=3.8.4", "pydantic~=1.10.4", "aiofiles~=23.1.0", "certifi>=2022.9.24", @@ -56,7 +56,7 @@ fast = [ "uvloop>=0.17.0; (sys_platform == 'darwin' or sys_platform == 'linux') and platform_python_implementation != 'PyPy'", ] redis = [ - "redis~=4.3.4", + "redis~=4.5.1", ] proxy = [ "aiohttp-socks~=0.7.1", @@ -65,11 +65,11 @@ i18n = [ "Babel~=2.11.0", ] test = [ - "pytest~=7.1.3", - "pytest-html~=3.1.1", - "pytest-asyncio~=0.19.0", + "pytest~=7.2.1", + "pytest-html~=3.2.0", + "pytest-asyncio~=0.20.3", "pytest-lazy-fixture~=0.6.3", - "pytest-mock~=3.9.0", + "pytest-mock~=3.10.0", "pytest-mypy~=0.10.0", "pytest-cov~=4.0.0", "pytest-aiohttp~=1.0.4", @@ -93,12 +93,12 @@ docs = [ dev = [ "black~=23.1", "isort~=5.11", - "ruff~=0.0.245", - "mypy~=0.981", + "ruff~=0.0.246", + "mypy~=1.0.0", "toml~=0.10.2", - "pre-commit~=2.20.0", + "pre-commit~=3.0.4", "packaging~=23.0", - "typing-extensions~=4.3.0", + "typing-extensions~=4.4.0", ] [project.urls] diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index 6d6fa1ab..509bff35 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -1,7 +1,9 @@ import asyncio import datetime +import signal import time import warnings +from asyncio import Event from collections import Counter from typing import Any from unittest.mock import AsyncMock, patch @@ -667,6 +669,17 @@ class TestDispatcher: async def test_start_polling(self, bot: MockedBot): dispatcher = Dispatcher() + with pytest.raises( + ValueError, match="At least one bot instance is required to start polling" + ): + await dispatcher.start_polling() + with pytest.raises( + ValueError, + match="Keyword argument 'bot' is not acceptable, " + "the bot instance should be passed as positional argument", + ): + await dispatcher.start_polling(bot, bot=bot) + bot.add_result_for( GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot") ) @@ -690,6 +703,65 @@ class TestDispatcher: mocked_process_update.assert_awaited() mocked_emit_shutdown.assert_awaited() + async def test_stop_polling(self): + dispatcher = Dispatcher() + with pytest.raises(RuntimeError): + await dispatcher.stop_polling() + + assert not dispatcher._stop_signal.is_set() + assert not dispatcher._stopped_signal.is_set() + with patch("asyncio.locks.Event.wait", new_callable=AsyncMock) as mocked_wait: + async with dispatcher._running_lock: + await dispatcher.stop_polling() + assert dispatcher._stop_signal.is_set() + mocked_wait.assert_awaited() + + async def test_signal_stop_polling(self): + dispatcher = Dispatcher() + with patch("asyncio.locks.Event.set") as mocked_set: + dispatcher._signal_stop_polling(signal.SIGINT) + mocked_set.assert_not_called() + + async with dispatcher._running_lock: + dispatcher._signal_stop_polling(signal.SIGINT) + mocked_set.assert_called() + + async def test_stop_polling_by_method(self, bot: MockedBot): + dispatcher = Dispatcher() + bot.add_result_for( + GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot") + ) + running = Event() + + async def _mock_updates(*_): + running.set() + while True: + yield Update(update_id=42) + await asyncio.sleep(1) + + with patch( + "aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=AsyncMock + ) as mocked_process_update, patch( + "aiogram.dispatcher.dispatcher.Dispatcher._listen_updates", + return_value=_mock_updates(), + ): + task = asyncio.ensure_future(dispatcher.start_polling(bot)) + await running.wait() + + assert not dispatcher._stop_signal.is_set() + assert not dispatcher._stopped_signal.is_set() + + await dispatcher.stop_polling() + assert dispatcher._stop_signal.is_set() + assert dispatcher._stopped_signal.is_set() + assert not task.exception() + + mocked_process_update.assert_awaited() + + @pytest.mark.skip("Stopping by signal should also be tested as the same as stopping by method") + async def test_stop_polling_by_signal(self, bot: MockedBot): + pass + def test_run_polling(self, bot: MockedBot): dispatcher = Dispatcher() with patch(