diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 76fb3b4d..b4179421 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -1,20 +1,19 @@ import asyncio import functools import logging -import typing - import time +import typing from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \ generate_default_filters from .handler import Handler -from .storage import BaseStorage, DisabledStorage, FSMContext +from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT from .webhook import BaseResponse from ..bot import Bot from ..types.message import ContentType from ..utils import context from ..utils.deprecated import deprecated -from ..utils.exceptions import NetworkError, TelegramAPIError +from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled log = logging.getLogger(__name__) @@ -22,6 +21,8 @@ MODE = 'MODE' LONG_POLLING = 'long-polling' UPDATE_OBJECT = 'update_object' +DEFAULT_RATE_LIMIT = .1 + class Dispatcher: """ @@ -33,7 +34,9 @@ class Dispatcher: """ def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None, - run_tasks_by_default: bool = False): + run_tasks_by_default: bool = False, + throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False): + if loop is None: loop = bot.loop if storage is None: @@ -44,6 +47,9 @@ class Dispatcher: self.storage = storage self.run_tasks_by_default = run_tasks_by_default + self.throttling_rate_limit = throttling_rate_limit + self.no_throttle_error = no_throttle_error + self.last_update_id = 0 self.updates_handler = Handler(self) @@ -929,6 +935,89 @@ class Dispatcher: return FSMContext(storage=self.storage, chat=chat, user=user) + async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool: + """ + Execute throttling manager. + Return True limit is not exceeded otherwise raise ThrottleError or return False + + :param key: key in storage + :param rate: limit (by default is equals with default rate limit) + :param user: user id + :param chat: chat id + :param no_error: return boolean value instead of raising error + :return: bool + """ + if not self.storage.has_bucket(): + print(self.storage) + raise RuntimeError('This storage does not provide Leaky Bucket') + + if no_error is None: + no_error = self.no_throttle_error + if rate is None: + rate = self.throttling_rate_limit + if user is None and chat is None: + from . import ctx + user = ctx.get_user() + chat = ctx.get_chat() + + # Detect current time + now = time.time() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + + # Fix bucket + if bucket is None: + bucket = {key: {}} + if key not in bucket: + bucket[key] = {} + data = bucket[key] + + # Calculate + called = data.get(LAST_CALL, now) + delta = now - called + result = delta >= rate or delta <= 0 + + # Save results + data[RESULT] = result + data[RATE_LIMIT] = rate + data[LAST_CALL] = now + data[DELTA] = delta + if not result: + data[EXCEEDED_COUNT] += 1 + else: + data[EXCEEDED_COUNT] = 1 + bucket[key].update(data) + await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + + if not result and not no_error: + # Raise if that is allowed + raise Throttled(key=key, chat=chat, user=user, **data) + return result + + async def release_key(self, key, chat=None, user=None): + """ + Release blocked key + + :param key: + :param chat: + :param user: + :return: + """ + if not self.storage.has_bucket(): + raise RuntimeError('This storage does not provide Leaky Bucket') + + if user is None and chat is None: + from . import ctx + user = ctx.get_user() + chat = ctx.get_chat() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + if bucket and key in bucket: + del bucket['key'] + await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + return True + return False + def async_task(self, func): """ Execute handler as task and return None. diff --git a/aiogram/dispatcher/storage.py b/aiogram/dispatcher/storage.py index d7d150a9..622e4182 100644 --- a/aiogram/dispatcher/storage.py +++ b/aiogram/dispatcher/storage.py @@ -1,5 +1,14 @@ import typing +# Leak bucket +KEY = 'key' +LAST_CALL = 'called_at' +RATE_LIMIT = 'rate_limit' +RESULT = 'result' +EXCEEDED_COUNT = 'exceeded' +DELTA = 'delta' +THROTTLE_MANAGER = '$throttle_manager' + class BaseStorage: """ @@ -216,7 +225,7 @@ class BaseStorage: :param chat: :param user: - :param data: + :param bucket: """ raise NotImplementedError @@ -233,7 +242,7 @@ class BaseStorage: Chat or user is always required. If one of this is not presented, need set the missing value based on the presented - :param data: + :param bucket: :param chat: :param user: :param kwargs: diff --git a/aiogram/dispatcher/throttle.py b/aiogram/dispatcher/throttle.py deleted file mode 100644 index d655c9e2..00000000 --- a/aiogram/dispatcher/throttle.py +++ /dev/null @@ -1,148 +0,0 @@ -import asyncio -import time - -from aiogram.dispatcher import BaseStorage, Dispatcher, ctx - -DEFAULT_RATE_LIMIT = .1 - -KEY = 'key' -LAST_CALL = 'called_at' -RATE_LIMIT = 'rate_limit' -RESULT = 'result' -EXCEEDED_COUNT = 'exceeded' -DELTA = 'delta' -THROTTLE_MANAGER = '$throttle_manager' - - -class ThrottleError(Exception): - def __init__(self, **kwargs): - self.key = kwargs.pop(KEY, '') - self.called_at = kwargs.pop(LAST_CALL, time.time()) - self.rate = kwargs.pop(RATE_LIMIT, None) - self.result = kwargs.pop(RESULT, False) - self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0) - self.delta = kwargs.pop(DELTA, 0) - self.user = kwargs.pop('user', None) - self.chat = kwargs.pop('chat', None) - - def __str__(self): - return f"Rate limit exceeded! (Limit: {self.rate} s, " \ - f"exceeded: {self.exceeded_count}, " \ - f"time delta: {round(self.delta, 3)} s)" - - -class Bucket: - """ - Throttling manager - """ - - def __init__(self, dispatcher, rate_limit=DEFAULT_RATE_LIMIT, storage: BaseStorage = None, no_error=False): - """ - Initialize throttle manager - - :param dispatcher: instance of Dispatcher - :param rate_limit: limit in seconds - :param storage: - :param no_error: return boolean value instead of raising error - """ - if storage is None: - storage = dispatcher.storage - if not storage.has_bucket(): - raise TypeError('This storage does not provide Bucket!') - - self._dispatcher: Dispatcher = dispatcher - self._loop: asyncio.BaseEventLoop = self._dispatcher.loop - self._rate_limit = rate_limit - self._storage = storage - self._no_error = no_error - - dispatcher.bot[THROTTLE_MANAGER] = self - - async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool: - """ - Execute throttling manager. - Return True limit is not exceeded otherwise raise ThrottleError or return False - - :param key: key in storage - :param rate: limit (by default is equals with default rate limit) - :param user: user id - :param chat: chat id - :param no_error: return boolean value instead of raising error - :return: bool - """ - if no_error is None: - no_error = self._no_error - if rate is None: - rate = self._rate_limit - if user is None and chat is None: - user = ctx.get_user() - chat = ctx.get_chat() - - # Detect current time - now = time.time() - - bucket = await self._storage.get_bucket(chat=chat, user=user) - - # Fix bucket - if bucket is None: - bucket = {key: {}} - if key not in bucket: - bucket[key] = {} - data = bucket[key] - - # Calculate - called = data.get(LAST_CALL, now) - delta = now - called - result = delta >= rate or delta <= 0 - - # Save results - data[RESULT] = result - data[RATE_LIMIT] = rate - data[LAST_CALL] = now - data[DELTA] = delta - if not result: - data[EXCEEDED_COUNT] += 1 - else: - data[EXCEEDED_COUNT] = 1 - bucket[key].update(data) - await self._storage.set_bucket(chat=chat, user=user, bucket=bucket) - - if not result and not no_error: - # Raise if that is allowed - raise ThrottleError(key=key, chat=chat, user=user, **data) - return result - - async def release_key(self, key, chat=None, user=None): - """ - Release blocked key - - :param key: - :param chat: - :param user: - :return: - """ - if user is None and chat is None: - user = ctx.get_user() - chat = ctx.get_chat() - bucket = await self._storage.get_bucket(chat=chat, user=user) - if bucket and key in bucket: - del bucket['key'] - await self._storage.set_bucket(chat=chat, user=user, bucket=bucket) - return True - return False - - -async def throttle(key, rate=None, no_error=None): - """ - Alias for Bucket.throttle(...) - - :param key: - :param rate: - :param no_error: - :return: - """ - bot = ctx.get_bot() - bucket = bot.get(THROTTLE_MANAGER) - if not bucket: - raise RuntimeError('Can\'t be found Bucket!') - return await bucket.throttle(key=key, rate=rate, no_error=no_error) diff --git a/aiogram/utils/exceptions.py b/aiogram/utils/exceptions.py index edbf5bcd..d0a40c0d 100644 --- a/aiogram/utils/exceptions.py +++ b/aiogram/utils/exceptions.py @@ -1,3 +1,5 @@ +import time + _PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: '] @@ -51,3 +53,21 @@ class MigrateToChat(TelegramAPIError): def __init__(self, chat_id): super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.") self.migrate_to_chat_id = chat_id + + +class Throttled(Exception): + def __init__(self, **kwargs): + from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT + self.key = kwargs.pop(KEY, '') + self.called_at = kwargs.pop(LAST_CALL, time.time()) + self.rate = kwargs.pop(RATE_LIMIT, None) + self.result = kwargs.pop(RESULT, False) + self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0) + self.delta = kwargs.pop(DELTA, 0) + self.user = kwargs.pop('user', None) + self.chat = kwargs.pop('chat', None) + + def __str__(self): + return f"Rate limit exceeded! (Limit: {self.rate} s, " \ + f"exceeded: {self.exceeded_count}, " \ + f"time delta: {round(self.delta, 3)} s)"