mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 01:54:53 +00:00
Move Throttling manager to Dispatcher.
This commit is contained in:
parent
c75e33eaa3
commit
2ebd7c5de4
4 changed files with 125 additions and 155 deletions
|
|
@ -1,20 +1,19 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import time
|
||||
import typing
|
||||
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
||||
generate_default_filters
|
||||
from .handler import Handler
|
||||
from .storage import BaseStorage, DisabledStorage, FSMContext
|
||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
|
||||
from .webhook import BaseResponse
|
||||
from ..bot import Bot
|
||||
from ..types.message import ContentType
|
||||
from ..utils import context
|
||||
from ..utils.deprecated import deprecated
|
||||
from ..utils.exceptions import NetworkError, TelegramAPIError
|
||||
from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -22,6 +21,8 @@ MODE = 'MODE'
|
|||
LONG_POLLING = 'long-polling'
|
||||
UPDATE_OBJECT = 'update_object'
|
||||
|
||||
DEFAULT_RATE_LIMIT = .1
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
"""
|
||||
|
|
@ -33,7 +34,9 @@ class Dispatcher:
|
|||
"""
|
||||
|
||||
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
||||
run_tasks_by_default: bool = False):
|
||||
run_tasks_by_default: bool = False,
|
||||
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
|
||||
|
||||
if loop is None:
|
||||
loop = bot.loop
|
||||
if storage is None:
|
||||
|
|
@ -44,6 +47,9 @@ class Dispatcher:
|
|||
self.storage = storage
|
||||
self.run_tasks_by_default = run_tasks_by_default
|
||||
|
||||
self.throttling_rate_limit = throttling_rate_limit
|
||||
self.no_throttle_error = no_throttle_error
|
||||
|
||||
self.last_update_id = 0
|
||||
|
||||
self.updates_handler = Handler(self)
|
||||
|
|
@ -929,6 +935,89 @@ class Dispatcher:
|
|||
|
||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||
|
||||
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
|
||||
"""
|
||||
Execute throttling manager.
|
||||
Return True limit is not exceeded otherwise raise ThrottleError or return False
|
||||
|
||||
:param key: key in storage
|
||||
:param rate: limit (by default is equals with default rate limit)
|
||||
:param user: user id
|
||||
:param chat: chat id
|
||||
:param no_error: return boolean value instead of raising error
|
||||
:return: bool
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
print(self.storage)
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if no_error is None:
|
||||
no_error = self.no_throttle_error
|
||||
if rate is None:
|
||||
rate = self.throttling_rate_limit
|
||||
if user is None and chat is None:
|
||||
from . import ctx
|
||||
user = ctx.get_user()
|
||||
chat = ctx.get_chat()
|
||||
|
||||
# Detect current time
|
||||
now = time.time()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
|
||||
# Fix bucket
|
||||
if bucket is None:
|
||||
bucket = {key: {}}
|
||||
if key not in bucket:
|
||||
bucket[key] = {}
|
||||
data = bucket[key]
|
||||
|
||||
# Calculate
|
||||
called = data.get(LAST_CALL, now)
|
||||
delta = now - called
|
||||
result = delta >= rate or delta <= 0
|
||||
|
||||
# Save results
|
||||
data[RESULT] = result
|
||||
data[RATE_LIMIT] = rate
|
||||
data[LAST_CALL] = now
|
||||
data[DELTA] = delta
|
||||
if not result:
|
||||
data[EXCEEDED_COUNT] += 1
|
||||
else:
|
||||
data[EXCEEDED_COUNT] = 1
|
||||
bucket[key].update(data)
|
||||
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
|
||||
if not result and not no_error:
|
||||
# Raise if that is allowed
|
||||
raise Throttled(key=key, chat=chat, user=user, **data)
|
||||
return result
|
||||
|
||||
async def release_key(self, key, chat=None, user=None):
|
||||
"""
|
||||
Release blocked key
|
||||
|
||||
:param key:
|
||||
:param chat:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if user is None and chat is None:
|
||||
from . import ctx
|
||||
user = ctx.get_user()
|
||||
chat = ctx.get_chat()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
if bucket and key in bucket:
|
||||
del bucket['key']
|
||||
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
return True
|
||||
return False
|
||||
|
||||
def async_task(self, func):
|
||||
"""
|
||||
Execute handler as task and return None.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,14 @@
|
|||
import typing
|
||||
|
||||
# Leak bucket
|
||||
KEY = 'key'
|
||||
LAST_CALL = 'called_at'
|
||||
RATE_LIMIT = 'rate_limit'
|
||||
RESULT = 'result'
|
||||
EXCEEDED_COUNT = 'exceeded'
|
||||
DELTA = 'delta'
|
||||
THROTTLE_MANAGER = '$throttle_manager'
|
||||
|
||||
|
||||
class BaseStorage:
|
||||
"""
|
||||
|
|
@ -216,7 +225,7 @@ class BaseStorage:
|
|||
|
||||
:param chat:
|
||||
:param user:
|
||||
:param data:
|
||||
:param bucket:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -233,7 +242,7 @@ class BaseStorage:
|
|||
Chat or user is always required. If one of this is not presented,
|
||||
need set the missing value based on the presented
|
||||
|
||||
:param data:
|
||||
:param bucket:
|
||||
:param chat:
|
||||
:param user:
|
||||
:param kwargs:
|
||||
|
|
|
|||
|
|
@ -1,148 +0,0 @@
|
|||
import asyncio
|
||||
import time
|
||||
|
||||
from aiogram.dispatcher import BaseStorage, Dispatcher, ctx
|
||||
|
||||
DEFAULT_RATE_LIMIT = .1
|
||||
|
||||
KEY = 'key'
|
||||
LAST_CALL = 'called_at'
|
||||
RATE_LIMIT = 'rate_limit'
|
||||
RESULT = 'result'
|
||||
EXCEEDED_COUNT = 'exceeded'
|
||||
DELTA = 'delta'
|
||||
THROTTLE_MANAGER = '$throttle_manager'
|
||||
|
||||
|
||||
class ThrottleError(Exception):
|
||||
def __init__(self, **kwargs):
|
||||
self.key = kwargs.pop(KEY, '<None>')
|
||||
self.called_at = kwargs.pop(LAST_CALL, time.time())
|
||||
self.rate = kwargs.pop(RATE_LIMIT, None)
|
||||
self.result = kwargs.pop(RESULT, False)
|
||||
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
|
||||
self.delta = kwargs.pop(DELTA, 0)
|
||||
self.user = kwargs.pop('user', None)
|
||||
self.chat = kwargs.pop('chat', None)
|
||||
|
||||
def __str__(self):
|
||||
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
|
||||
f"exceeded: {self.exceeded_count}, " \
|
||||
f"time delta: {round(self.delta, 3)} s)"
|
||||
|
||||
|
||||
class Bucket:
|
||||
"""
|
||||
Throttling manager
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher, rate_limit=DEFAULT_RATE_LIMIT, storage: BaseStorage = None, no_error=False):
|
||||
"""
|
||||
Initialize throttle manager
|
||||
|
||||
:param dispatcher: instance of Dispatcher
|
||||
:param rate_limit: limit in seconds
|
||||
:param storage:
|
||||
:param no_error: return boolean value instead of raising error
|
||||
"""
|
||||
if storage is None:
|
||||
storage = dispatcher.storage
|
||||
if not storage.has_bucket():
|
||||
raise TypeError('This storage does not provide Bucket!')
|
||||
|
||||
self._dispatcher: Dispatcher = dispatcher
|
||||
self._loop: asyncio.BaseEventLoop = self._dispatcher.loop
|
||||
self._rate_limit = rate_limit
|
||||
self._storage = storage
|
||||
self._no_error = no_error
|
||||
|
||||
dispatcher.bot[THROTTLE_MANAGER] = self
|
||||
|
||||
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
|
||||
"""
|
||||
Execute throttling manager.
|
||||
Return True limit is not exceeded otherwise raise ThrottleError or return False
|
||||
|
||||
:param key: key in storage
|
||||
:param rate: limit (by default is equals with default rate limit)
|
||||
:param user: user id
|
||||
:param chat: chat id
|
||||
:param no_error: return boolean value instead of raising error
|
||||
:return: bool
|
||||
"""
|
||||
if no_error is None:
|
||||
no_error = self._no_error
|
||||
if rate is None:
|
||||
rate = self._rate_limit
|
||||
if user is None and chat is None:
|
||||
user = ctx.get_user()
|
||||
chat = ctx.get_chat()
|
||||
|
||||
# Detect current time
|
||||
now = time.time()
|
||||
|
||||
bucket = await self._storage.get_bucket(chat=chat, user=user)
|
||||
|
||||
# Fix bucket
|
||||
if bucket is None:
|
||||
bucket = {key: {}}
|
||||
if key not in bucket:
|
||||
bucket[key] = {}
|
||||
data = bucket[key]
|
||||
|
||||
# Calculate
|
||||
called = data.get(LAST_CALL, now)
|
||||
delta = now - called
|
||||
result = delta >= rate or delta <= 0
|
||||
|
||||
# Save results
|
||||
data[RESULT] = result
|
||||
data[RATE_LIMIT] = rate
|
||||
data[LAST_CALL] = now
|
||||
data[DELTA] = delta
|
||||
if not result:
|
||||
data[EXCEEDED_COUNT] += 1
|
||||
else:
|
||||
data[EXCEEDED_COUNT] = 1
|
||||
bucket[key].update(data)
|
||||
await self._storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
|
||||
if not result and not no_error:
|
||||
# Raise if that is allowed
|
||||
raise ThrottleError(key=key, chat=chat, user=user, **data)
|
||||
return result
|
||||
|
||||
async def release_key(self, key, chat=None, user=None):
|
||||
"""
|
||||
Release blocked key
|
||||
|
||||
:param key:
|
||||
:param chat:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
if user is None and chat is None:
|
||||
user = ctx.get_user()
|
||||
chat = ctx.get_chat()
|
||||
bucket = await self._storage.get_bucket(chat=chat, user=user)
|
||||
if bucket and key in bucket:
|
||||
del bucket['key']
|
||||
await self._storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def throttle(key, rate=None, no_error=None):
|
||||
"""
|
||||
Alias for Bucket.throttle(...)
|
||||
|
||||
:param key:
|
||||
:param rate:
|
||||
:param no_error:
|
||||
:return:
|
||||
"""
|
||||
bot = ctx.get_bot()
|
||||
bucket = bot.get(THROTTLE_MANAGER)
|
||||
if not bucket:
|
||||
raise RuntimeError('Can\'t be found Bucket!')
|
||||
return await bucket.throttle(key=key, rate=rate, no_error=no_error)
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
import time
|
||||
|
||||
_PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: ']
|
||||
|
||||
|
||||
|
|
@ -51,3 +53,21 @@ class MigrateToChat(TelegramAPIError):
|
|||
def __init__(self, chat_id):
|
||||
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
|
||||
self.migrate_to_chat_id = chat_id
|
||||
|
||||
|
||||
class Throttled(Exception):
|
||||
def __init__(self, **kwargs):
|
||||
from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT
|
||||
self.key = kwargs.pop(KEY, '<None>')
|
||||
self.called_at = kwargs.pop(LAST_CALL, time.time())
|
||||
self.rate = kwargs.pop(RATE_LIMIT, None)
|
||||
self.result = kwargs.pop(RESULT, False)
|
||||
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
|
||||
self.delta = kwargs.pop(DELTA, 0)
|
||||
self.user = kwargs.pop('user', None)
|
||||
self.chat = kwargs.pop('chat', None)
|
||||
|
||||
def __str__(self):
|
||||
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
|
||||
f"exceeded: {self.exceeded_count}, " \
|
||||
f"time delta: {round(self.delta, 3)} s)"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue