Make endless long-polling

This commit is contained in:
Alex Root Junior 2021-06-19 01:16:51 +03:00
parent 5296724a0f
commit ac1f0efde8
7 changed files with 245 additions and 22 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -14,11 +15,12 @@ from typing import (
cast, cast,
) )
from aiohttp import BasicAuth, ClientSession, FormData, TCPConnector from aiohttp import BasicAuth, ClientError, ClientSession, FormData, TCPConnector
from aiogram.methods import Request, TelegramMethod from aiogram.methods import Request, TelegramMethod
from ...methods.base import TelegramType from ...methods.base import TelegramType
from ...utils.exceptions.network import NetworkError
from .base import UNSET, BaseSession from .base import UNSET, BaseSession
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
@ -139,11 +141,15 @@ class AiohttpSession(BaseSession):
url = self.api.api_url(token=bot.token, method=request.method) url = self.api.api_url(token=bot.token, method=request.method)
form = self.build_form_data(request) form = self.build_form_data(request)
async with session.post( try:
url, data=form, timeout=self.timeout if timeout is None else timeout async with session.post(
) as resp: url, data=form, timeout=self.timeout if timeout is None else timeout
raw_result = await resp.text() ) as resp:
raw_result = await resp.text()
except asyncio.TimeoutError:
raise NetworkError(method=call, message="Request timeout error")
except ClientError as e:
raise NetworkError(method=call, message=f"{type(e).__name__}: {e}")
response = self.check_response(method=call, status_code=resp.status, content=raw_result) response = self.check_response(method=call, status_code=resp.status, content=raw_result)
return cast(TelegramType, response.result) return cast(TelegramType, response.result)

View file

@ -10,7 +10,10 @@ from .. import loggers
from ..client.bot import Bot from ..client.bot import Bot
from ..methods import GetUpdates, TelegramMethod from ..methods import GetUpdates, TelegramMethod
from ..types import TelegramObject, Update, User from ..types import TelegramObject, Update, User
from ..utils.backoff import Backoff, BackoffConfig
from ..utils.exceptions.base import TelegramAPIError from ..utils.exceptions.base import TelegramAPIError
from ..utils.exceptions.network import NetworkError
from ..utils.exceptions.server import ServerError
from .event.bases import UNHANDLED, SkipHandler from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver from .event.telegram import TelegramEventObserver
from .fsm.middleware import FSMContextMiddleware from .fsm.middleware import FSMContextMiddleware
@ -21,6 +24,8 @@ from .middlewares.error import ErrorsMiddleware
from .middlewares.user_context import UserContextMiddleware from .middlewares.user_context import UserContextMiddleware
from .router import Router from .router import Router
DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1)
class Dispatcher(Router): class Dispatcher(Router):
""" """
@ -63,7 +68,7 @@ class Dispatcher(Router):
@property @property
def parent_router(self) -> None: def parent_router(self) -> None:
""" """
Dispatcher has no parent router Dispatcher has no parent router and can't be included to any other routers or dispatchers
:return: :return:
""" """
@ -82,6 +87,7 @@ class Dispatcher(Router):
async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
""" """
Main entry point for incoming updates Main entry point for incoming updates
Response of this method can be used as Webhook response
:param bot: :param bot:
:param update: :param update:
@ -90,7 +96,7 @@ class Dispatcher(Router):
handled = False handled = False
start_time = loop.time() start_time = loop.time()
Bot.set_current(bot) token = Bot.set_current(bot)
try: try:
response = await self.update.trigger(update, bot=bot, **kwargs) response = await self.update.trigger(update, bot=bot, **kwargs)
handled = response is not UNHANDLED handled = response is not UNHANDLED
@ -105,6 +111,7 @@ class Dispatcher(Router):
duration, duration,
bot.id, bot.id,
) )
Bot.reset_current(token)
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any: async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
""" """
@ -119,20 +126,50 @@ class Dispatcher(Router):
@classmethod @classmethod
async def _listen_updates( async def _listen_updates(
cls, bot: Bot, polling_timeout: int = 30 cls,
bot: Bot,
polling_timeout: int = 30,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
) -> AsyncGenerator[Update, None]: ) -> AsyncGenerator[Update, None]:
""" """
Infinity updates reader Endless updates reader with correctly handling any server-side or connection errors.
So you may not worry that the polling will stop working.
""" """
backoff = Backoff(config=backoff_config)
get_updates = GetUpdates(timeout=polling_timeout) get_updates = GetUpdates(timeout=polling_timeout)
kwargs = {} kwargs = {}
if bot.session.timeout: if bot.session.timeout:
# Request timeout can be lower than session timeout ant that's OK.
# To prevent false-positive TimeoutError we should wait longer than polling timeout
kwargs["request_timeout"] = int(bot.session.timeout + polling_timeout) kwargs["request_timeout"] = int(bot.session.timeout + polling_timeout)
while True: while True:
# TODO: Skip restarting telegram error try:
updates = await bot(get_updates, **kwargs) updates = await bot(get_updates, **kwargs)
except (NetworkError, ServerError) as e:
# In cases when Telegram Bot API was inaccessible don't need to stop polling process
# because some of developers can't make auto-restarting of the script
loggers.dispatcher.error("Failed to fetch updates - %s: %s", type(e).__name__, e)
# And also backoff timeout is best practice to retry any network activity
loggers.dispatcher.warning(
"Sleep for %f seconds and try again... (tryings = %d, bot id = %d)",
backoff.next_delay,
backoff.counter,
bot.id,
)
await backoff.asleep()
continue
# In case when network connection was fixed let's reset the backoff
# to initial value and then process updates
backoff.reset()
for update in updates: for update in updates:
yield update yield update
# The getUpdates method returns the earliest 100 unconfirmed updates.
# To confirm an update, use the offset parameter when calling getUpdates
# All updates with update_id less than or equal to offset will be marked as confirmed on the server
# and will no longer be returned.
get_updates.offset = update.update_id + 1 get_updates.offset = update.update_id + 1
async def _listen_update(self, update: Update, **kwargs: Any) -> Any: async def _listen_update(self, update: Update, **kwargs: Any) -> Any:
@ -255,7 +292,12 @@ class Dispatcher(Router):
return True # because update was processed but unsuccessful return True # because update was processed but unsuccessful
async def _polling( async def _polling(
self, bot: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any self,
bot: Bot,
polling_timeout: int = 30,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
**kwargs: Any,
) -> None: ) -> None:
""" """
Internal polling process Internal polling process
@ -264,7 +306,9 @@ class Dispatcher(Router):
:param kwargs: :param kwargs:
:return: :return:
""" """
async for update in self._listen_updates(bot, polling_timeout=polling_timeout): async for update in self._listen_updates(
bot, polling_timeout=polling_timeout, backoff_config=backoff_config
):
handle_update = self._process_update(bot=bot, update=update, **kwargs) handle_update = self._process_update(bot=bot, update=update, **kwargs)
if handle_as_tasks: if handle_as_tasks:
asyncio.create_task(handle_update) asyncio.create_task(handle_update)
@ -348,7 +392,12 @@ class Dispatcher(Router):
return None return None
async def start_polling( async def start_polling(
self, *bots: Bot, polling_timeout: int = 10, handle_as_tasks: bool = True, **kwargs: Any self,
*bots: Bot,
polling_timeout: int = 10,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
**kwargs: Any,
) -> None: ) -> None:
""" """
Polling runner Polling runner
@ -357,6 +406,7 @@ class Dispatcher(Router):
:param polling_timeout: :param polling_timeout:
:param handle_as_tasks: :param handle_as_tasks:
:param kwargs: :param kwargs:
:param backoff_config:
:return: :return:
""" """
async with self._running_lock: # Prevent to run this method twice at a once async with self._running_lock: # Prevent to run this method twice at a once
@ -376,6 +426,7 @@ class Dispatcher(Router):
bot=bot, bot=bot,
handle_as_tasks=handle_as_tasks, handle_as_tasks=handle_as_tasks,
polling_timeout=polling_timeout, polling_timeout=polling_timeout,
backoff_config=backoff_config,
**kwargs, **kwargs,
) )
) )
@ -387,13 +438,19 @@ class Dispatcher(Router):
await self.emit_shutdown(**workflow_data) await self.emit_shutdown(**workflow_data)
def run_polling( def run_polling(
self, *bots: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any self,
*bots: Bot,
polling_timeout: int = 30,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
**kwargs: Any,
) -> None: ) -> None:
""" """
Run many bots with polling Run many bots with polling
:param bots: Bot instances :param bots: Bot instances
:param polling_timeout: Poling timeout :param polling_timeout: Poling timeout
:param backoff_config:
:param handle_as_tasks: Run task for each event and no wait result :param handle_as_tasks: Run task for each event and no wait result
:param kwargs: contextual data :param kwargs: contextual data
:return: :return:
@ -405,6 +462,7 @@ class Dispatcher(Router):
**kwargs, **kwargs,
polling_timeout=polling_timeout, polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks, handle_as_tasks=handle_as_tasks,
backoff_config=backoff_config,
) )
) )
except (KeyboardInterrupt, SystemExit): # pragma: no cover except (KeyboardInterrupt, SystemExit): # pragma: no cover

View file

@ -118,9 +118,7 @@ class Router:
:param router: :param router:
""" """
if not isinstance(router, Router): if not isinstance(router, Router):
raise ValueError( raise ValueError(f"router should be instance of Router not {type(router).__name__!r}")
f"router should be instance of Router not {type(router).__class__.__name__}"
)
if self._parent_router: if self._parent_router:
raise RuntimeError(f"Router is already attached to {self._parent_router!r}") raise RuntimeError(f"Router is already attached to {self._parent_router!r}")
if self == router: if self == router:
@ -133,7 +131,7 @@ class Router:
if not self.use_builtin_filters and parent.use_builtin_filters: if not self.use_builtin_filters and parent.use_builtin_filters:
warnings.warn( warnings.warn(
f"{self.__class__.__name__}(use_builtin_filters=False) has no effect" f"{type(self).__name__}(use_builtin_filters=False) has no effect"
f" for router {self} in due to builtin filters is already registered" f" for router {self} in due to builtin filters is already registered"
f" in parent router", f" in parent router",
CodeHasNoEffect, CodeHasNoEffect,

80
aiogram/utils/backoff.py Normal file
View file

@ -0,0 +1,80 @@
import asyncio
import time
from dataclasses import dataclass
from random import normalvariate
@dataclass(frozen=True)
class BackoffConfig:
min_delay: float
max_delay: float
factor: float
jitter: float
def __post_init__(self):
if self.max_delay <= self.min_delay:
raise ValueError("`max_delay` should be greater than `min_delay`")
if self.factor <= 1:
raise ValueError("`factor` should be greater than 1")
class Backoff:
def __init__(self, config: BackoffConfig) -> None:
self.config = config
self._next_delay = config.min_delay
self._current_delay = 0.0
self._counter = 0
def __iter__(self):
return self
@property
def min_delay(self) -> float:
return self.config.min_delay
@property
def max_delay(self) -> float:
return self.config.max_delay
@property
def factor(self) -> float:
return self.config.factor
@property
def jitter(self) -> float:
return self.config.jitter
@property
def next_delay(self) -> float:
return self._next_delay
@property
def current_delay(self) -> float:
return self._current_delay
@property
def counter(self) -> int:
return self._counter
def sleep(self) -> None:
time.sleep(next(self))
async def asleep(self) -> None:
await asyncio.sleep(next(self))
def _calculate_next(self, value: float) -> float:
return normalvariate(min(value * self.factor, self.max_delay), self.jitter)
def __next__(self) -> float:
self._current_delay = self._next_delay
self._next_delay = self._calculate_next(self._next_delay)
self._counter += 1
return self._current_delay
def reset(self) -> None:
self._current_delay = 0.0
self._counter = 0
self._next_delay = self.min_delay
def __str__(self) -> str:
return f"Backoff(tryings={self._counter}, current_delay={self._current_delay}, next_delay={self._next_delay})"

View file

@ -1,5 +1,5 @@
from aiogram.utils.exceptions.base import DetailedTelegramAPIError from aiogram.utils.exceptions.base import TelegramAPIError
class NetworkError(DetailedTelegramAPIError): class NetworkError(TelegramAPIError):
pass pass

View file

@ -0,0 +1,5 @@
from aiogram.utils.exceptions.base import TelegramAPIError
class ServerError(TelegramAPIError):
pass

View file

@ -0,0 +1,76 @@
import pytest
from aiogram.utils.backoff import Backoff, BackoffConfig
BACKOFF_CONFIG = BackoffConfig(min_delay=0.1, max_delay=1.0, factor=2.0, jitter=0.0)
class TestBackoffConfig:
@pytest.mark.parametrize(
"kwargs",
[
dict(min_delay=1.0, max_delay=1.0, factor=2.0, jitter=0.1), # equals min and max
dict(min_delay=1.0, max_delay=1.0, factor=1.0, jitter=0.1), # factor == 1
dict(min_delay=1.0, max_delay=2.0, factor=0.5, jitter=0.1), # factor < 1
dict(min_delay=2.0, max_delay=1.0, factor=2.0, jitter=0.1), # min > max
],
)
def test_incorrect_post_init(self, kwargs):
with pytest.raises(ValueError):
BackoffConfig(**kwargs)
@pytest.mark.parametrize(
"kwargs",
[dict(min_delay=1.0, max_delay=2.0, factor=1.2, jitter=0.1)],
)
def test_correct_post_init(self, kwargs):
assert BackoffConfig(**kwargs)
class TestBackoff:
def test_aliases(self):
backoff = Backoff(config=BACKOFF_CONFIG)
assert backoff.min_delay == BACKOFF_CONFIG.min_delay
assert backoff.max_delay == BACKOFF_CONFIG.max_delay
assert backoff.factor == BACKOFF_CONFIG.factor
assert backoff.jitter == BACKOFF_CONFIG.jitter
def test_calculation(self):
backoff = Backoff(config=BACKOFF_CONFIG)
index = 0
iterable = iter(backoff)
assert iterable == backoff
assert backoff.current_delay == 0.0
assert backoff.next_delay == 0.1
while (val := next(backoff)) < 1:
index += 1
assert val in {0.1, 0.2, 0.4, 0.8}
assert next(backoff) == 1
assert next(backoff) == 1
assert index == 4
assert backoff.current_delay == 1
assert backoff.next_delay == 1
assert backoff.counter == 7 # 4+1 in while loop + 2 after loop
assert str(backoff) == "Backoff(tryings=7, current_delay=1.0, next_delay=1.0)"
backoff.reset()
assert backoff.current_delay == 0.0
assert backoff.next_delay == 0.1
assert backoff.counter == 0
def test_sleep(self):
backoff = Backoff(config=BACKOFF_CONFIG)
backoff.sleep()
assert backoff.counter == 1
@pytest.mark.asyncio
async def test_asleep(self):
backoff = Backoff(config=BACKOFF_CONFIG)
await backoff.asleep()
assert backoff.counter == 1