mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Reworked graceful shutdown (#1124)
* Reworked graceful shutdown * Remove special errors from polling process * Update dependencies * Coverage * Added changelog
This commit is contained in:
parent
a332e88bc3
commit
d0b7135ca6
7 changed files with 197 additions and 50 deletions
2
CHANGES/1124.misc.rst
Normal file
2
CHANGES/1124.misc.rst
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
Reworked graceful shutdown. Added method to stop polling.
|
||||
Now polling started from dispatcher can be stopped by signals gracefully without errors (on Linux and Mac).
|
||||
2
Makefile
2
Makefile
|
|
@ -23,7 +23,7 @@ clean:
|
|||
rm -rf *.egg-info
|
||||
rm -f report.html
|
||||
rm -f .coverage
|
||||
rm -rf {build,dist,site,.cache,.mypy_cache,reports}
|
||||
rm -rf {build,dist,site,.cache,.mypy_cache,.ruff_cache,reports}
|
||||
|
||||
# =================================================================================================
|
||||
# Code quality
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import signal
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Lock
|
||||
from asyncio import CancelledError, Event, Future, Lock
|
||||
from contextlib import suppress
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from .. import loggers
|
||||
|
|
@ -89,6 +91,8 @@ class Dispatcher(Router):
|
|||
|
||||
self.workflow_data: Dict[str, Any] = kwargs
|
||||
self._running_lock = Lock()
|
||||
self._stop_signal = Event()
|
||||
self._stopped_signal = Event()
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
return self.workflow_data[item]
|
||||
|
|
@ -321,6 +325,11 @@ class Dispatcher(Router):
|
|||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
user: User = await bot.me()
|
||||
loggers.dispatcher.info(
|
||||
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||
)
|
||||
try:
|
||||
async for update in self._listen_updates(
|
||||
bot,
|
||||
polling_timeout=polling_timeout,
|
||||
|
|
@ -332,6 +341,10 @@ class Dispatcher(Router):
|
|||
asyncio.create_task(handle_update)
|
||||
else:
|
||||
await handle_update
|
||||
finally:
|
||||
loggers.dispatcher.info(
|
||||
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||
)
|
||||
|
||||
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
||||
"""
|
||||
|
|
@ -408,6 +421,24 @@ class Dispatcher(Router):
|
|||
|
||||
return None
|
||||
|
||||
async def stop_polling(self) -> None:
|
||||
"""
|
||||
Execute this method if you want to stop polling programmatically
|
||||
|
||||
:return:
|
||||
"""
|
||||
if not self._running_lock.locked():
|
||||
raise RuntimeError("Polling is not started")
|
||||
self._stop_signal.set()
|
||||
await self._stopped_signal.wait()
|
||||
|
||||
def _signal_stop_polling(self, sig: signal.Signals) -> None:
|
||||
if not self._running_lock.locked():
|
||||
return
|
||||
|
||||
loggers.dispatcher.warning("Received %s signal", sig.name)
|
||||
self._stop_signal.set()
|
||||
|
||||
async def start_polling(
|
||||
self,
|
||||
*bots: Bot,
|
||||
|
|
@ -415,32 +446,54 @@ class Dispatcher(Router):
|
|||
handle_as_tasks: bool = True,
|
||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||
allowed_updates: Optional[List[str]] = None,
|
||||
handle_signals: bool = True,
|
||||
close_bot_session: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Polling runner
|
||||
|
||||
:param bots:
|
||||
:param polling_timeout:
|
||||
:param handle_as_tasks:
|
||||
:param kwargs:
|
||||
:param backoff_config:
|
||||
:param allowed_updates:
|
||||
:param bots: Bot instances (one or mre)
|
||||
:param polling_timeout: Long-polling wait time
|
||||
:param handle_as_tasks: Run task for each event and no wait result
|
||||
:param backoff_config: backoff-retry config
|
||||
:param allowed_updates: List of the update types you want your bot to receive
|
||||
:param handle_signals: handle signals (SIGINT/SIGTERM)
|
||||
:param close_bot_session: close bot sessions on shutdown
|
||||
:param kwargs: contextual data
|
||||
:return:
|
||||
"""
|
||||
if not bots:
|
||||
raise ValueError("At least one bot instance is required to start polling")
|
||||
if "bot" in kwargs:
|
||||
raise ValueError(
|
||||
"Keyword argument 'bot' is not acceptable, "
|
||||
"the bot instance should be passed as positional argument"
|
||||
)
|
||||
|
||||
async with self._running_lock: # Prevent to run this method twice at a once
|
||||
self._stop_signal.clear()
|
||||
self._stopped_signal.clear()
|
||||
|
||||
if handle_signals:
|
||||
loop = asyncio.get_running_loop()
|
||||
with suppress(NotImplementedError): # pragma: no cover
|
||||
# Signals handling is not supported on Windows
|
||||
# It also can't be covered on Windows
|
||||
loop.add_signal_handler(
|
||||
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
|
||||
)
|
||||
loop.add_signal_handler(
|
||||
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
|
||||
)
|
||||
|
||||
workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]}
|
||||
workflow_data.update(kwargs)
|
||||
await self.emit_startup(**workflow_data)
|
||||
loggers.dispatcher.info("Start polling")
|
||||
try:
|
||||
coro_list = []
|
||||
for bot in bots:
|
||||
user: User = await bot.me()
|
||||
loggers.dispatcher.info(
|
||||
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||
)
|
||||
coro_list.append(
|
||||
tasks: List[asyncio.Task[Any]] = [
|
||||
asyncio.create_task(
|
||||
self._polling(
|
||||
bot=bot,
|
||||
handle_as_tasks=handle_as_tasks,
|
||||
|
|
@ -450,36 +503,53 @@ class Dispatcher(Router):
|
|||
**kwargs,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*coro_list)
|
||||
for bot in bots
|
||||
]
|
||||
tasks.append(asyncio.create_task(self._stop_signal.wait()))
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
for task in pending:
|
||||
# (mostly) Graceful shutdown unfinished tasks
|
||||
task.cancel()
|
||||
with suppress(CancelledError):
|
||||
await task
|
||||
# Wait finished tasks to propagate unhandled exceptions
|
||||
await asyncio.gather(*done)
|
||||
|
||||
finally:
|
||||
loggers.dispatcher.info("Polling stopped")
|
||||
try:
|
||||
await self.emit_shutdown(**workflow_data)
|
||||
finally:
|
||||
for bot in bots: # Close sessions
|
||||
await bot.session.close()
|
||||
if close_bot_session:
|
||||
await asyncio.gather(*(bot.session.close() for bot in bots))
|
||||
self._stopped_signal.set()
|
||||
|
||||
def run_polling(
|
||||
self,
|
||||
*bots: Bot,
|
||||
polling_timeout: int = 30,
|
||||
polling_timeout: int = 10,
|
||||
handle_as_tasks: bool = True,
|
||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||
allowed_updates: Optional[List[str]] = None,
|
||||
handle_signals: bool = True,
|
||||
close_bot_session: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Run many bots with polling
|
||||
|
||||
:param bots: Bot instances
|
||||
:param polling_timeout: Poling timeout
|
||||
:param backoff_config:
|
||||
:param bots: Bot instances (one or mre)
|
||||
:param polling_timeout: Long-polling wait time
|
||||
:param handle_as_tasks: Run task for each event and no wait result
|
||||
:param backoff_config: backoff-retry config
|
||||
:param allowed_updates: List of the update types you want your bot to receive
|
||||
:param handle_signals: handle signals (SIGINT/SIGTERM)
|
||||
:param close_bot_session: close bot sessions on shutdown
|
||||
:param kwargs: contextual data
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
with suppress(KeyboardInterrupt):
|
||||
return asyncio.run(
|
||||
self.start_polling(
|
||||
*bots,
|
||||
|
|
@ -488,8 +558,7 @@ class Dispatcher(Router):
|
|||
handle_as_tasks=handle_as_tasks,
|
||||
backoff_config=backoff_config,
|
||||
allowed_updates=allowed_updates,
|
||||
handle_signals=handle_signals,
|
||||
close_bot_session=close_bot_session,
|
||||
)
|
||||
)
|
||||
except (KeyboardInterrupt, SystemExit): # pragma: no cover
|
||||
# Allow to graceful shutdown
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
|
|||
"""
|
||||
if isinstance(handler, dict) and "handler" in handler:
|
||||
handler = handler["handler"]
|
||||
if not hasattr(handler, "flags"):
|
||||
return {}
|
||||
if hasattr(handler, "flags"):
|
||||
return handler.flags
|
||||
return {}
|
||||
|
||||
|
||||
def get_flag(
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError):
|
|||
super().__init__(message=message)
|
||||
self.method = method
|
||||
|
||||
def __str__(self) -> str:
|
||||
original_message = super().__str__()
|
||||
return f"Telegram server says {original_message}"
|
||||
|
||||
|
||||
class TelegramNetworkError(TelegramAPIError):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ classifiers = [
|
|||
]
|
||||
dependencies = [
|
||||
"magic-filter~=1.0.9",
|
||||
"aiohttp~=3.8.3",
|
||||
"aiohttp~=3.8.4",
|
||||
"pydantic~=1.10.4",
|
||||
"aiofiles~=23.1.0",
|
||||
"certifi>=2022.9.24",
|
||||
|
|
@ -56,7 +56,7 @@ fast = [
|
|||
"uvloop>=0.17.0; (sys_platform == 'darwin' or sys_platform == 'linux') and platform_python_implementation != 'PyPy'",
|
||||
]
|
||||
redis = [
|
||||
"redis~=4.3.4",
|
||||
"redis~=4.5.1",
|
||||
]
|
||||
proxy = [
|
||||
"aiohttp-socks~=0.7.1",
|
||||
|
|
@ -65,11 +65,11 @@ i18n = [
|
|||
"Babel~=2.11.0",
|
||||
]
|
||||
test = [
|
||||
"pytest~=7.1.3",
|
||||
"pytest-html~=3.1.1",
|
||||
"pytest-asyncio~=0.19.0",
|
||||
"pytest~=7.2.1",
|
||||
"pytest-html~=3.2.0",
|
||||
"pytest-asyncio~=0.20.3",
|
||||
"pytest-lazy-fixture~=0.6.3",
|
||||
"pytest-mock~=3.9.0",
|
||||
"pytest-mock~=3.10.0",
|
||||
"pytest-mypy~=0.10.0",
|
||||
"pytest-cov~=4.0.0",
|
||||
"pytest-aiohttp~=1.0.4",
|
||||
|
|
@ -93,12 +93,12 @@ docs = [
|
|||
dev = [
|
||||
"black~=23.1",
|
||||
"isort~=5.11",
|
||||
"ruff~=0.0.245",
|
||||
"mypy~=0.981",
|
||||
"ruff~=0.0.246",
|
||||
"mypy~=1.0.0",
|
||||
"toml~=0.10.2",
|
||||
"pre-commit~=2.20.0",
|
||||
"pre-commit~=3.0.4",
|
||||
"packaging~=23.0",
|
||||
"typing-extensions~=4.3.0",
|
||||
"typing-extensions~=4.4.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import signal
|
||||
import time
|
||||
import warnings
|
||||
from asyncio import Event
|
||||
from collections import Counter
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
|
@ -667,6 +669,17 @@ class TestDispatcher:
|
|||
|
||||
async def test_start_polling(self, bot: MockedBot):
|
||||
dispatcher = Dispatcher()
|
||||
with pytest.raises(
|
||||
ValueError, match="At least one bot instance is required to start polling"
|
||||
):
|
||||
await dispatcher.start_polling()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Keyword argument 'bot' is not acceptable, "
|
||||
"the bot instance should be passed as positional argument",
|
||||
):
|
||||
await dispatcher.start_polling(bot, bot=bot)
|
||||
|
||||
bot.add_result_for(
|
||||
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
|
||||
)
|
||||
|
|
@ -690,6 +703,65 @@ class TestDispatcher:
|
|||
mocked_process_update.assert_awaited()
|
||||
mocked_emit_shutdown.assert_awaited()
|
||||
|
||||
async def test_stop_polling(self):
|
||||
dispatcher = Dispatcher()
|
||||
with pytest.raises(RuntimeError):
|
||||
await dispatcher.stop_polling()
|
||||
|
||||
assert not dispatcher._stop_signal.is_set()
|
||||
assert not dispatcher._stopped_signal.is_set()
|
||||
with patch("asyncio.locks.Event.wait", new_callable=AsyncMock) as mocked_wait:
|
||||
async with dispatcher._running_lock:
|
||||
await dispatcher.stop_polling()
|
||||
assert dispatcher._stop_signal.is_set()
|
||||
mocked_wait.assert_awaited()
|
||||
|
||||
async def test_signal_stop_polling(self):
|
||||
dispatcher = Dispatcher()
|
||||
with patch("asyncio.locks.Event.set") as mocked_set:
|
||||
dispatcher._signal_stop_polling(signal.SIGINT)
|
||||
mocked_set.assert_not_called()
|
||||
|
||||
async with dispatcher._running_lock:
|
||||
dispatcher._signal_stop_polling(signal.SIGINT)
|
||||
mocked_set.assert_called()
|
||||
|
||||
async def test_stop_polling_by_method(self, bot: MockedBot):
|
||||
dispatcher = Dispatcher()
|
||||
bot.add_result_for(
|
||||
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
|
||||
)
|
||||
running = Event()
|
||||
|
||||
async def _mock_updates(*_):
|
||||
running.set()
|
||||
while True:
|
||||
yield Update(update_id=42)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
with patch(
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=AsyncMock
|
||||
) as mocked_process_update, patch(
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates",
|
||||
return_value=_mock_updates(),
|
||||
):
|
||||
task = asyncio.ensure_future(dispatcher.start_polling(bot))
|
||||
await running.wait()
|
||||
|
||||
assert not dispatcher._stop_signal.is_set()
|
||||
assert not dispatcher._stopped_signal.is_set()
|
||||
|
||||
await dispatcher.stop_polling()
|
||||
assert dispatcher._stop_signal.is_set()
|
||||
assert dispatcher._stopped_signal.is_set()
|
||||
assert not task.exception()
|
||||
|
||||
mocked_process_update.assert_awaited()
|
||||
|
||||
@pytest.mark.skip("Stopping by signal should also be tested as the same as stopping by method")
|
||||
async def test_stop_polling_by_signal(self, bot: MockedBot):
|
||||
pass
|
||||
|
||||
def test_run_polling(self, bot: MockedBot):
|
||||
dispatcher = Dispatcher()
|
||||
with patch(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue