mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
Move Throttling manager to Dispatcher.
This commit is contained in:
parent
c75e33eaa3
commit
2ebd7c5de4
4 changed files with 125 additions and 155 deletions
|
|
@ -1,20 +1,19 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import typing
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import typing
|
||||||
|
|
||||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
||||||
generate_default_filters
|
generate_default_filters
|
||||||
from .handler import Handler
|
from .handler import Handler
|
||||||
from .storage import BaseStorage, DisabledStorage, FSMContext
|
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
|
||||||
from .webhook import BaseResponse
|
from .webhook import BaseResponse
|
||||||
from ..bot import Bot
|
from ..bot import Bot
|
||||||
from ..types.message import ContentType
|
from ..types.message import ContentType
|
||||||
from ..utils import context
|
from ..utils import context
|
||||||
from ..utils.deprecated import deprecated
|
from ..utils.deprecated import deprecated
|
||||||
from ..utils.exceptions import NetworkError, TelegramAPIError
|
from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -22,6 +21,8 @@ MODE = 'MODE'
|
||||||
LONG_POLLING = 'long-polling'
|
LONG_POLLING = 'long-polling'
|
||||||
UPDATE_OBJECT = 'update_object'
|
UPDATE_OBJECT = 'update_object'
|
||||||
|
|
||||||
|
DEFAULT_RATE_LIMIT = .1
|
||||||
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
@ -33,7 +34,9 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
||||||
run_tasks_by_default: bool = False):
|
run_tasks_by_default: bool = False,
|
||||||
|
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
|
||||||
|
|
||||||
if loop is None:
|
if loop is None:
|
||||||
loop = bot.loop
|
loop = bot.loop
|
||||||
if storage is None:
|
if storage is None:
|
||||||
|
|
@ -44,6 +47,9 @@ class Dispatcher:
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.run_tasks_by_default = run_tasks_by_default
|
self.run_tasks_by_default = run_tasks_by_default
|
||||||
|
|
||||||
|
self.throttling_rate_limit = throttling_rate_limit
|
||||||
|
self.no_throttle_error = no_throttle_error
|
||||||
|
|
||||||
self.last_update_id = 0
|
self.last_update_id = 0
|
||||||
|
|
||||||
self.updates_handler = Handler(self)
|
self.updates_handler = Handler(self)
|
||||||
|
|
@ -929,6 +935,89 @@ class Dispatcher:
|
||||||
|
|
||||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||||
|
|
||||||
|
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
|
||||||
|
"""
|
||||||
|
Execute throttling manager.
|
||||||
|
Return True limit is not exceeded otherwise raise ThrottleError or return False
|
||||||
|
|
||||||
|
:param key: key in storage
|
||||||
|
:param rate: limit (by default is equals with default rate limit)
|
||||||
|
:param user: user id
|
||||||
|
:param chat: chat id
|
||||||
|
:param no_error: return boolean value instead of raising error
|
||||||
|
:return: bool
|
||||||
|
"""
|
||||||
|
if not self.storage.has_bucket():
|
||||||
|
print(self.storage)
|
||||||
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
|
if no_error is None:
|
||||||
|
no_error = self.no_throttle_error
|
||||||
|
if rate is None:
|
||||||
|
rate = self.throttling_rate_limit
|
||||||
|
if user is None and chat is None:
|
||||||
|
from . import ctx
|
||||||
|
user = ctx.get_user()
|
||||||
|
chat = ctx.get_chat()
|
||||||
|
|
||||||
|
# Detect current time
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||||
|
|
||||||
|
# Fix bucket
|
||||||
|
if bucket is None:
|
||||||
|
bucket = {key: {}}
|
||||||
|
if key not in bucket:
|
||||||
|
bucket[key] = {}
|
||||||
|
data = bucket[key]
|
||||||
|
|
||||||
|
# Calculate
|
||||||
|
called = data.get(LAST_CALL, now)
|
||||||
|
delta = now - called
|
||||||
|
result = delta >= rate or delta <= 0
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
data[RESULT] = result
|
||||||
|
data[RATE_LIMIT] = rate
|
||||||
|
data[LAST_CALL] = now
|
||||||
|
data[DELTA] = delta
|
||||||
|
if not result:
|
||||||
|
data[EXCEEDED_COUNT] += 1
|
||||||
|
else:
|
||||||
|
data[EXCEEDED_COUNT] = 1
|
||||||
|
bucket[key].update(data)
|
||||||
|
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||||
|
|
||||||
|
if not result and not no_error:
|
||||||
|
# Raise if that is allowed
|
||||||
|
raise Throttled(key=key, chat=chat, user=user, **data)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def release_key(self, key, chat=None, user=None):
|
||||||
|
"""
|
||||||
|
Release blocked key
|
||||||
|
|
||||||
|
:param key:
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not self.storage.has_bucket():
|
||||||
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
|
if user is None and chat is None:
|
||||||
|
from . import ctx
|
||||||
|
user = ctx.get_user()
|
||||||
|
chat = ctx.get_chat()
|
||||||
|
|
||||||
|
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||||
|
if bucket and key in bucket:
|
||||||
|
del bucket['key']
|
||||||
|
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def async_task(self, func):
|
def async_task(self, func):
|
||||||
"""
|
"""
|
||||||
Execute handler as task and return None.
|
Execute handler as task and return None.
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,14 @@
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
# Leak bucket
|
||||||
|
KEY = 'key'
|
||||||
|
LAST_CALL = 'called_at'
|
||||||
|
RATE_LIMIT = 'rate_limit'
|
||||||
|
RESULT = 'result'
|
||||||
|
EXCEEDED_COUNT = 'exceeded'
|
||||||
|
DELTA = 'delta'
|
||||||
|
THROTTLE_MANAGER = '$throttle_manager'
|
||||||
|
|
||||||
|
|
||||||
class BaseStorage:
|
class BaseStorage:
|
||||||
"""
|
"""
|
||||||
|
|
@ -216,7 +225,7 @@ class BaseStorage:
|
||||||
|
|
||||||
:param chat:
|
:param chat:
|
||||||
:param user:
|
:param user:
|
||||||
:param data:
|
:param bucket:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -233,7 +242,7 @@ class BaseStorage:
|
||||||
Chat or user is always required. If one of this is not presented,
|
Chat or user is always required. If one of this is not presented,
|
||||||
need set the missing value based on the presented
|
need set the missing value based on the presented
|
||||||
|
|
||||||
:param data:
|
:param bucket:
|
||||||
:param chat:
|
:param chat:
|
||||||
:param user:
|
:param user:
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
|
|
|
||||||
|
|
@ -1,148 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
|
|
||||||
from aiogram.dispatcher import BaseStorage, Dispatcher, ctx
|
|
||||||
|
|
||||||
DEFAULT_RATE_LIMIT = .1
|
|
||||||
|
|
||||||
KEY = 'key'
|
|
||||||
LAST_CALL = 'called_at'
|
|
||||||
RATE_LIMIT = 'rate_limit'
|
|
||||||
RESULT = 'result'
|
|
||||||
EXCEEDED_COUNT = 'exceeded'
|
|
||||||
DELTA = 'delta'
|
|
||||||
THROTTLE_MANAGER = '$throttle_manager'
|
|
||||||
|
|
||||||
|
|
||||||
class ThrottleError(Exception):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.key = kwargs.pop(KEY, '<None>')
|
|
||||||
self.called_at = kwargs.pop(LAST_CALL, time.time())
|
|
||||||
self.rate = kwargs.pop(RATE_LIMIT, None)
|
|
||||||
self.result = kwargs.pop(RESULT, False)
|
|
||||||
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
|
|
||||||
self.delta = kwargs.pop(DELTA, 0)
|
|
||||||
self.user = kwargs.pop('user', None)
|
|
||||||
self.chat = kwargs.pop('chat', None)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
|
|
||||||
f"exceeded: {self.exceeded_count}, " \
|
|
||||||
f"time delta: {round(self.delta, 3)} s)"
|
|
||||||
|
|
||||||
|
|
||||||
class Bucket:
|
|
||||||
"""
|
|
||||||
Throttling manager
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dispatcher, rate_limit=DEFAULT_RATE_LIMIT, storage: BaseStorage = None, no_error=False):
|
|
||||||
"""
|
|
||||||
Initialize throttle manager
|
|
||||||
|
|
||||||
:param dispatcher: instance of Dispatcher
|
|
||||||
:param rate_limit: limit in seconds
|
|
||||||
:param storage:
|
|
||||||
:param no_error: return boolean value instead of raising error
|
|
||||||
"""
|
|
||||||
if storage is None:
|
|
||||||
storage = dispatcher.storage
|
|
||||||
if not storage.has_bucket():
|
|
||||||
raise TypeError('This storage does not provide Bucket!')
|
|
||||||
|
|
||||||
self._dispatcher: Dispatcher = dispatcher
|
|
||||||
self._loop: asyncio.BaseEventLoop = self._dispatcher.loop
|
|
||||||
self._rate_limit = rate_limit
|
|
||||||
self._storage = storage
|
|
||||||
self._no_error = no_error
|
|
||||||
|
|
||||||
dispatcher.bot[THROTTLE_MANAGER] = self
|
|
||||||
|
|
||||||
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
|
|
||||||
"""
|
|
||||||
Execute throttling manager.
|
|
||||||
Return True limit is not exceeded otherwise raise ThrottleError or return False
|
|
||||||
|
|
||||||
:param key: key in storage
|
|
||||||
:param rate: limit (by default is equals with default rate limit)
|
|
||||||
:param user: user id
|
|
||||||
:param chat: chat id
|
|
||||||
:param no_error: return boolean value instead of raising error
|
|
||||||
:return: bool
|
|
||||||
"""
|
|
||||||
if no_error is None:
|
|
||||||
no_error = self._no_error
|
|
||||||
if rate is None:
|
|
||||||
rate = self._rate_limit
|
|
||||||
if user is None and chat is None:
|
|
||||||
user = ctx.get_user()
|
|
||||||
chat = ctx.get_chat()
|
|
||||||
|
|
||||||
# Detect current time
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
bucket = await self._storage.get_bucket(chat=chat, user=user)
|
|
||||||
|
|
||||||
# Fix bucket
|
|
||||||
if bucket is None:
|
|
||||||
bucket = {key: {}}
|
|
||||||
if key not in bucket:
|
|
||||||
bucket[key] = {}
|
|
||||||
data = bucket[key]
|
|
||||||
|
|
||||||
# Calculate
|
|
||||||
called = data.get(LAST_CALL, now)
|
|
||||||
delta = now - called
|
|
||||||
result = delta >= rate or delta <= 0
|
|
||||||
|
|
||||||
# Save results
|
|
||||||
data[RESULT] = result
|
|
||||||
data[RATE_LIMIT] = rate
|
|
||||||
data[LAST_CALL] = now
|
|
||||||
data[DELTA] = delta
|
|
||||||
if not result:
|
|
||||||
data[EXCEEDED_COUNT] += 1
|
|
||||||
else:
|
|
||||||
data[EXCEEDED_COUNT] = 1
|
|
||||||
bucket[key].update(data)
|
|
||||||
await self._storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
|
||||||
|
|
||||||
if not result and not no_error:
|
|
||||||
# Raise if that is allowed
|
|
||||||
raise ThrottleError(key=key, chat=chat, user=user, **data)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def release_key(self, key, chat=None, user=None):
|
|
||||||
"""
|
|
||||||
Release blocked key
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param chat:
|
|
||||||
:param user:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if user is None and chat is None:
|
|
||||||
user = ctx.get_user()
|
|
||||||
chat = ctx.get_chat()
|
|
||||||
bucket = await self._storage.get_bucket(chat=chat, user=user)
|
|
||||||
if bucket and key in bucket:
|
|
||||||
del bucket['key']
|
|
||||||
await self._storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def throttle(key, rate=None, no_error=None):
|
|
||||||
"""
|
|
||||||
Alias for Bucket.throttle(...)
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param rate:
|
|
||||||
:param no_error:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
bot = ctx.get_bot()
|
|
||||||
bucket = bot.get(THROTTLE_MANAGER)
|
|
||||||
if not bucket:
|
|
||||||
raise RuntimeError('Can\'t be found Bucket!')
|
|
||||||
return await bucket.throttle(key=key, rate=rate, no_error=no_error)
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import time
|
||||||
|
|
||||||
_PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: ']
|
_PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: ']
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -51,3 +53,21 @@ class MigrateToChat(TelegramAPIError):
|
||||||
def __init__(self, chat_id):
|
def __init__(self, chat_id):
|
||||||
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
|
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
|
||||||
self.migrate_to_chat_id = chat_id
|
self.migrate_to_chat_id = chat_id
|
||||||
|
|
||||||
|
|
||||||
|
class Throttled(Exception):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT
|
||||||
|
self.key = kwargs.pop(KEY, '<None>')
|
||||||
|
self.called_at = kwargs.pop(LAST_CALL, time.time())
|
||||||
|
self.rate = kwargs.pop(RATE_LIMIT, None)
|
||||||
|
self.result = kwargs.pop(RESULT, False)
|
||||||
|
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
|
||||||
|
self.delta = kwargs.pop(DELTA, 0)
|
||||||
|
self.user = kwargs.pop('user', None)
|
||||||
|
self.chat = kwargs.pop('chat', None)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
|
||||||
|
f"exceeded: {self.exceeded_count}, " \
|
||||||
|
f"time delta: {round(self.delta, 3)} s)"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue