mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Reworked graceful shutdown (#1124)
* Reworked graceful shutdown * Remove special errors from polling process * Update dependencies * Coverage * Added changelog
This commit is contained in:
parent
a332e88bc3
commit
d0b7135ca6
7 changed files with 197 additions and 50 deletions
2
CHANGES/1124.misc.rst
Normal file
2
CHANGES/1124.misc.rst
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
Reworked graceful shutdown. Added method to stop polling.
|
||||||
|
Now polling started from dispatcher can be stopped by signals gracefully without errors (on Linux and Mac).
|
||||||
2
Makefile
2
Makefile
|
|
@ -23,7 +23,7 @@ clean:
|
||||||
rm -rf *.egg-info
|
rm -rf *.egg-info
|
||||||
rm -f report.html
|
rm -f report.html
|
||||||
rm -f .coverage
|
rm -f .coverage
|
||||||
rm -rf {build,dist,site,.cache,.mypy_cache,reports}
|
rm -rf {build,dist,site,.cache,.mypy_cache,.ruff_cache,reports}
|
||||||
|
|
||||||
# =================================================================================================
|
# =================================================================================================
|
||||||
# Code quality
|
# Code quality
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
import contextvars
|
||||||
|
import signal
|
||||||
import warnings
|
import warnings
|
||||||
from asyncio import CancelledError, Future, Lock
|
from asyncio import CancelledError, Event, Future, Lock
|
||||||
|
from contextlib import suppress
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from .. import loggers
|
from .. import loggers
|
||||||
|
|
@ -89,6 +91,8 @@ class Dispatcher(Router):
|
||||||
|
|
||||||
self.workflow_data: Dict[str, Any] = kwargs
|
self.workflow_data: Dict[str, Any] = kwargs
|
||||||
self._running_lock = Lock()
|
self._running_lock = Lock()
|
||||||
|
self._stop_signal = Event()
|
||||||
|
self._stopped_signal = Event()
|
||||||
|
|
||||||
def __getitem__(self, item: str) -> Any:
|
def __getitem__(self, item: str) -> Any:
|
||||||
return self.workflow_data[item]
|
return self.workflow_data[item]
|
||||||
|
|
@ -321,6 +325,11 @@ class Dispatcher(Router):
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
user: User = await bot.me()
|
||||||
|
loggers.dispatcher.info(
|
||||||
|
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||||
|
)
|
||||||
|
try:
|
||||||
async for update in self._listen_updates(
|
async for update in self._listen_updates(
|
||||||
bot,
|
bot,
|
||||||
polling_timeout=polling_timeout,
|
polling_timeout=polling_timeout,
|
||||||
|
|
@ -332,6 +341,10 @@ class Dispatcher(Router):
|
||||||
asyncio.create_task(handle_update)
|
asyncio.create_task(handle_update)
|
||||||
else:
|
else:
|
||||||
await handle_update
|
await handle_update
|
||||||
|
finally:
|
||||||
|
loggers.dispatcher.info(
|
||||||
|
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||||
|
)
|
||||||
|
|
||||||
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
@ -408,6 +421,24 @@ class Dispatcher(Router):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def stop_polling(self) -> None:
|
||||||
|
"""
|
||||||
|
Execute this method if you want to stop polling programmatically
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not self._running_lock.locked():
|
||||||
|
raise RuntimeError("Polling is not started")
|
||||||
|
self._stop_signal.set()
|
||||||
|
await self._stopped_signal.wait()
|
||||||
|
|
||||||
|
def _signal_stop_polling(self, sig: signal.Signals) -> None:
|
||||||
|
if not self._running_lock.locked():
|
||||||
|
return
|
||||||
|
|
||||||
|
loggers.dispatcher.warning("Received %s signal", sig.name)
|
||||||
|
self._stop_signal.set()
|
||||||
|
|
||||||
async def start_polling(
|
async def start_polling(
|
||||||
self,
|
self,
|
||||||
*bots: Bot,
|
*bots: Bot,
|
||||||
|
|
@ -415,32 +446,54 @@ class Dispatcher(Router):
|
||||||
handle_as_tasks: bool = True,
|
handle_as_tasks: bool = True,
|
||||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||||
allowed_updates: Optional[List[str]] = None,
|
allowed_updates: Optional[List[str]] = None,
|
||||||
|
handle_signals: bool = True,
|
||||||
|
close_bot_session: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Polling runner
|
Polling runner
|
||||||
|
|
||||||
:param bots:
|
:param bots: Bot instances (one or mre)
|
||||||
:param polling_timeout:
|
:param polling_timeout: Long-polling wait time
|
||||||
:param handle_as_tasks:
|
:param handle_as_tasks: Run task for each event and no wait result
|
||||||
:param kwargs:
|
:param backoff_config: backoff-retry config
|
||||||
:param backoff_config:
|
:param allowed_updates: List of the update types you want your bot to receive
|
||||||
:param allowed_updates:
|
:param handle_signals: handle signals (SIGINT/SIGTERM)
|
||||||
|
:param close_bot_session: close bot sessions on shutdown
|
||||||
|
:param kwargs: contextual data
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if not bots:
|
||||||
|
raise ValueError("At least one bot instance is required to start polling")
|
||||||
|
if "bot" in kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
"Keyword argument 'bot' is not acceptable, "
|
||||||
|
"the bot instance should be passed as positional argument"
|
||||||
|
)
|
||||||
|
|
||||||
async with self._running_lock: # Prevent to run this method twice at a once
|
async with self._running_lock: # Prevent to run this method twice at a once
|
||||||
|
self._stop_signal.clear()
|
||||||
|
self._stopped_signal.clear()
|
||||||
|
|
||||||
|
if handle_signals:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
with suppress(NotImplementedError): # pragma: no cover
|
||||||
|
# Signals handling is not supported on Windows
|
||||||
|
# It also can't be covered on Windows
|
||||||
|
loop.add_signal_handler(
|
||||||
|
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
|
||||||
|
)
|
||||||
|
loop.add_signal_handler(
|
||||||
|
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
|
||||||
|
)
|
||||||
|
|
||||||
workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]}
|
workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]}
|
||||||
workflow_data.update(kwargs)
|
workflow_data.update(kwargs)
|
||||||
await self.emit_startup(**workflow_data)
|
await self.emit_startup(**workflow_data)
|
||||||
loggers.dispatcher.info("Start polling")
|
loggers.dispatcher.info("Start polling")
|
||||||
try:
|
try:
|
||||||
coro_list = []
|
tasks: List[asyncio.Task[Any]] = [
|
||||||
for bot in bots:
|
asyncio.create_task(
|
||||||
user: User = await bot.me()
|
|
||||||
loggers.dispatcher.info(
|
|
||||||
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
|
||||||
)
|
|
||||||
coro_list.append(
|
|
||||||
self._polling(
|
self._polling(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
handle_as_tasks=handle_as_tasks,
|
handle_as_tasks=handle_as_tasks,
|
||||||
|
|
@ -450,36 +503,53 @@ class Dispatcher(Router):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await asyncio.gather(*coro_list)
|
for bot in bots
|
||||||
|
]
|
||||||
|
tasks.append(asyncio.create_task(self._stop_signal.wait()))
|
||||||
|
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
|
for task in pending:
|
||||||
|
# (mostly) Graceful shutdown unfinished tasks
|
||||||
|
task.cancel()
|
||||||
|
with suppress(CancelledError):
|
||||||
|
await task
|
||||||
|
# Wait finished tasks to propagate unhandled exceptions
|
||||||
|
await asyncio.gather(*done)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
loggers.dispatcher.info("Polling stopped")
|
loggers.dispatcher.info("Polling stopped")
|
||||||
try:
|
try:
|
||||||
await self.emit_shutdown(**workflow_data)
|
await self.emit_shutdown(**workflow_data)
|
||||||
finally:
|
finally:
|
||||||
for bot in bots: # Close sessions
|
if close_bot_session:
|
||||||
await bot.session.close()
|
await asyncio.gather(*(bot.session.close() for bot in bots))
|
||||||
|
self._stopped_signal.set()
|
||||||
|
|
||||||
def run_polling(
|
def run_polling(
|
||||||
self,
|
self,
|
||||||
*bots: Bot,
|
*bots: Bot,
|
||||||
polling_timeout: int = 30,
|
polling_timeout: int = 10,
|
||||||
handle_as_tasks: bool = True,
|
handle_as_tasks: bool = True,
|
||||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||||
allowed_updates: Optional[List[str]] = None,
|
allowed_updates: Optional[List[str]] = None,
|
||||||
|
handle_signals: bool = True,
|
||||||
|
close_bot_session: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Run many bots with polling
|
Run many bots with polling
|
||||||
|
|
||||||
:param bots: Bot instances
|
:param bots: Bot instances (one or mre)
|
||||||
:param polling_timeout: Poling timeout
|
:param polling_timeout: Long-polling wait time
|
||||||
:param backoff_config:
|
|
||||||
:param handle_as_tasks: Run task for each event and no wait result
|
:param handle_as_tasks: Run task for each event and no wait result
|
||||||
|
:param backoff_config: backoff-retry config
|
||||||
:param allowed_updates: List of the update types you want your bot to receive
|
:param allowed_updates: List of the update types you want your bot to receive
|
||||||
|
:param handle_signals: handle signals (SIGINT/SIGTERM)
|
||||||
|
:param close_bot_session: close bot sessions on shutdown
|
||||||
:param kwargs: contextual data
|
:param kwargs: contextual data
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
try:
|
with suppress(KeyboardInterrupt):
|
||||||
return asyncio.run(
|
return asyncio.run(
|
||||||
self.start_polling(
|
self.start_polling(
|
||||||
*bots,
|
*bots,
|
||||||
|
|
@ -488,8 +558,7 @@ class Dispatcher(Router):
|
||||||
handle_as_tasks=handle_as_tasks,
|
handle_as_tasks=handle_as_tasks,
|
||||||
backoff_config=backoff_config,
|
backoff_config=backoff_config,
|
||||||
allowed_updates=allowed_updates,
|
allowed_updates=allowed_updates,
|
||||||
|
handle_signals=handle_signals,
|
||||||
|
close_bot_session=close_bot_session,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, SystemExit): # pragma: no cover
|
|
||||||
# Allow to graceful shutdown
|
|
||||||
pass
|
|
||||||
|
|
|
||||||
|
|
@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
|
||||||
"""
|
"""
|
||||||
if isinstance(handler, dict) and "handler" in handler:
|
if isinstance(handler, dict) and "handler" in handler:
|
||||||
handler = handler["handler"]
|
handler = handler["handler"]
|
||||||
if not hasattr(handler, "flags"):
|
if hasattr(handler, "flags"):
|
||||||
return {}
|
|
||||||
return handler.flags
|
return handler.flags
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def get_flag(
|
def get_flag(
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError):
|
||||||
super().__init__(message=message)
|
super().__init__(message=message)
|
||||||
self.method = method
|
self.method = method
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
original_message = super().__str__()
|
||||||
|
return f"Telegram server says {original_message}"
|
||||||
|
|
||||||
|
|
||||||
class TelegramNetworkError(TelegramAPIError):
|
class TelegramNetworkError(TelegramAPIError):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ classifiers = [
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"magic-filter~=1.0.9",
|
"magic-filter~=1.0.9",
|
||||||
"aiohttp~=3.8.3",
|
"aiohttp~=3.8.4",
|
||||||
"pydantic~=1.10.4",
|
"pydantic~=1.10.4",
|
||||||
"aiofiles~=23.1.0",
|
"aiofiles~=23.1.0",
|
||||||
"certifi>=2022.9.24",
|
"certifi>=2022.9.24",
|
||||||
|
|
@ -56,7 +56,7 @@ fast = [
|
||||||
"uvloop>=0.17.0; (sys_platform == 'darwin' or sys_platform == 'linux') and platform_python_implementation != 'PyPy'",
|
"uvloop>=0.17.0; (sys_platform == 'darwin' or sys_platform == 'linux') and platform_python_implementation != 'PyPy'",
|
||||||
]
|
]
|
||||||
redis = [
|
redis = [
|
||||||
"redis~=4.3.4",
|
"redis~=4.5.1",
|
||||||
]
|
]
|
||||||
proxy = [
|
proxy = [
|
||||||
"aiohttp-socks~=0.7.1",
|
"aiohttp-socks~=0.7.1",
|
||||||
|
|
@ -65,11 +65,11 @@ i18n = [
|
||||||
"Babel~=2.11.0",
|
"Babel~=2.11.0",
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
"pytest~=7.1.3",
|
"pytest~=7.2.1",
|
||||||
"pytest-html~=3.1.1",
|
"pytest-html~=3.2.0",
|
||||||
"pytest-asyncio~=0.19.0",
|
"pytest-asyncio~=0.20.3",
|
||||||
"pytest-lazy-fixture~=0.6.3",
|
"pytest-lazy-fixture~=0.6.3",
|
||||||
"pytest-mock~=3.9.0",
|
"pytest-mock~=3.10.0",
|
||||||
"pytest-mypy~=0.10.0",
|
"pytest-mypy~=0.10.0",
|
||||||
"pytest-cov~=4.0.0",
|
"pytest-cov~=4.0.0",
|
||||||
"pytest-aiohttp~=1.0.4",
|
"pytest-aiohttp~=1.0.4",
|
||||||
|
|
@ -93,12 +93,12 @@ docs = [
|
||||||
dev = [
|
dev = [
|
||||||
"black~=23.1",
|
"black~=23.1",
|
||||||
"isort~=5.11",
|
"isort~=5.11",
|
||||||
"ruff~=0.0.245",
|
"ruff~=0.0.246",
|
||||||
"mypy~=0.981",
|
"mypy~=1.0.0",
|
||||||
"toml~=0.10.2",
|
"toml~=0.10.2",
|
||||||
"pre-commit~=2.20.0",
|
"pre-commit~=3.0.4",
|
||||||
"packaging~=23.0",
|
"packaging~=23.0",
|
||||||
"typing-extensions~=4.3.0",
|
"typing-extensions~=4.4.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
|
import signal
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
from asyncio import Event
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
@ -667,6 +669,17 @@ class TestDispatcher:
|
||||||
|
|
||||||
async def test_start_polling(self, bot: MockedBot):
|
async def test_start_polling(self, bot: MockedBot):
|
||||||
dispatcher = Dispatcher()
|
dispatcher = Dispatcher()
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="At least one bot instance is required to start polling"
|
||||||
|
):
|
||||||
|
await dispatcher.start_polling()
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Keyword argument 'bot' is not acceptable, "
|
||||||
|
"the bot instance should be passed as positional argument",
|
||||||
|
):
|
||||||
|
await dispatcher.start_polling(bot, bot=bot)
|
||||||
|
|
||||||
bot.add_result_for(
|
bot.add_result_for(
|
||||||
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
|
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
|
||||||
)
|
)
|
||||||
|
|
@ -690,6 +703,65 @@ class TestDispatcher:
|
||||||
mocked_process_update.assert_awaited()
|
mocked_process_update.assert_awaited()
|
||||||
mocked_emit_shutdown.assert_awaited()
|
mocked_emit_shutdown.assert_awaited()
|
||||||
|
|
||||||
|
async def test_stop_polling(self):
|
||||||
|
dispatcher = Dispatcher()
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await dispatcher.stop_polling()
|
||||||
|
|
||||||
|
assert not dispatcher._stop_signal.is_set()
|
||||||
|
assert not dispatcher._stopped_signal.is_set()
|
||||||
|
with patch("asyncio.locks.Event.wait", new_callable=AsyncMock) as mocked_wait:
|
||||||
|
async with dispatcher._running_lock:
|
||||||
|
await dispatcher.stop_polling()
|
||||||
|
assert dispatcher._stop_signal.is_set()
|
||||||
|
mocked_wait.assert_awaited()
|
||||||
|
|
||||||
|
async def test_signal_stop_polling(self):
|
||||||
|
dispatcher = Dispatcher()
|
||||||
|
with patch("asyncio.locks.Event.set") as mocked_set:
|
||||||
|
dispatcher._signal_stop_polling(signal.SIGINT)
|
||||||
|
mocked_set.assert_not_called()
|
||||||
|
|
||||||
|
async with dispatcher._running_lock:
|
||||||
|
dispatcher._signal_stop_polling(signal.SIGINT)
|
||||||
|
mocked_set.assert_called()
|
||||||
|
|
||||||
|
async def test_stop_polling_by_method(self, bot: MockedBot):
|
||||||
|
dispatcher = Dispatcher()
|
||||||
|
bot.add_result_for(
|
||||||
|
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
|
||||||
|
)
|
||||||
|
running = Event()
|
||||||
|
|
||||||
|
async def _mock_updates(*_):
|
||||||
|
running.set()
|
||||||
|
while True:
|
||||||
|
yield Update(update_id=42)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=AsyncMock
|
||||||
|
) as mocked_process_update, patch(
|
||||||
|
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates",
|
||||||
|
return_value=_mock_updates(),
|
||||||
|
):
|
||||||
|
task = asyncio.ensure_future(dispatcher.start_polling(bot))
|
||||||
|
await running.wait()
|
||||||
|
|
||||||
|
assert not dispatcher._stop_signal.is_set()
|
||||||
|
assert not dispatcher._stopped_signal.is_set()
|
||||||
|
|
||||||
|
await dispatcher.stop_polling()
|
||||||
|
assert dispatcher._stop_signal.is_set()
|
||||||
|
assert dispatcher._stopped_signal.is_set()
|
||||||
|
assert not task.exception()
|
||||||
|
|
||||||
|
mocked_process_update.assert_awaited()
|
||||||
|
|
||||||
|
@pytest.mark.skip("Stopping by signal should also be tested as the same as stopping by method")
|
||||||
|
async def test_stop_polling_by_signal(self, bot: MockedBot):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_run_polling(self, bot: MockedBot):
|
def test_run_polling(self, bot: MockedBot):
|
||||||
dispatcher = Dispatcher()
|
dispatcher = Dispatcher()
|
||||||
with patch(
|
with patch(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue