mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-08 17:13:56 +00:00
Fix Dispatcher.throttle(...) and rename user & chat arguments to user_id & chat_id
This commit is contained in:
parent
7f43bf8a65
commit
3d4bdcc498
1 changed files with 33 additions and 26 deletions
|
|
@ -8,6 +8,7 @@ import typing
|
|||
import aiohttp
|
||||
from aiohttp.helpers import sentinel
|
||||
|
||||
from aiogram.utils.deprecated import renamed_argument
|
||||
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
|
||||
RegexpCommandsFilter, StateFilter, Text, IDFilter, AdminFilter
|
||||
from .handler import Handler
|
||||
|
|
@ -914,15 +915,17 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
|
||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||
|
||||
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
|
||||
@renamed_argument(old_name='user', new_name='user_id', until_version='3.0', stacklevel=4)
|
||||
@renamed_argument(old_name='chat', new_name='chat_id', until_version='3.0', stacklevel=4)
|
||||
async def throttle(self, key, *, rate=None, user_id=None, chat_id=None, no_error=None) -> bool:
|
||||
"""
|
||||
Execute throttling manager.
|
||||
Returns True if limit has not exceeded otherwise raises ThrottleError or returns False
|
||||
|
||||
:param key: key in storage
|
||||
:param rate: limit (by default is equal to default rate limit)
|
||||
:param user: user id
|
||||
:param chat: chat id
|
||||
:param user_id: user id
|
||||
:param chat_id: chat id
|
||||
:param no_error: return boolean value instead of raising error
|
||||
:return: bool
|
||||
"""
|
||||
|
|
@ -933,14 +936,14 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
no_error = self.no_throttle_error
|
||||
if rate is None:
|
||||
rate = self.throttling_rate_limit
|
||||
if user is None and chat is None:
|
||||
user = types.User.get_current()
|
||||
chat = types.Chat.get_current()
|
||||
if user_id is None and chat_id is None:
|
||||
user_id = types.User.get_current().id
|
||||
chat_id = types.Chat.get_current().id
|
||||
|
||||
# Detect current time
|
||||
now = time.time()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
bucket = await self.storage.get_bucket(chat=chat_id, user=user_id)
|
||||
|
||||
# Fix bucket
|
||||
if bucket is None:
|
||||
|
|
@ -964,53 +967,57 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
else:
|
||||
data[EXCEEDED_COUNT] = 1
|
||||
bucket[key].update(data)
|
||||
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
await self.storage.set_bucket(chat=chat_id, user=user_id, bucket=bucket)
|
||||
|
||||
if not result and not no_error:
|
||||
# Raise if it is allowed
|
||||
raise Throttled(key=key, chat=chat, user=user, **data)
|
||||
raise Throttled(key=key, chat=chat_id, user=user_id, **data)
|
||||
return result
|
||||
|
||||
async def check_key(self, key, chat=None, user=None):
|
||||
@renamed_argument('user', 'user_id', '3.0')
|
||||
@renamed_argument('chat', 'chat_id', '3.0')
|
||||
async def check_key(self, key, chat_id=None, user_id=None):
|
||||
"""
|
||||
Get information about key in bucket
|
||||
|
||||
:param key:
|
||||
:param chat:
|
||||
:param user:
|
||||
:param chat_id:
|
||||
:param user_id:
|
||||
:return:
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if user is None and chat is None:
|
||||
user = types.User.get_current()
|
||||
chat = types.Chat.get_current()
|
||||
if user_id is None and chat_id is None:
|
||||
user_id = types.User.get_current()
|
||||
chat_id = types.Chat.get_current()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
bucket = await self.storage.get_bucket(chat=chat_id, user=user_id)
|
||||
data = bucket.get(key, {})
|
||||
return Throttled(key=key, chat=chat, user=user, **data)
|
||||
return Throttled(key=key, chat=chat_id, user=user_id, **data)
|
||||
|
||||
async def release_key(self, key, chat=None, user=None):
|
||||
@renamed_argument('user', 'user_id', '3.0')
|
||||
@renamed_argument('chat', 'chat_id', '3.0')
|
||||
async def release_key(self, key, chat_id=None, user_id=None):
|
||||
"""
|
||||
Release blocked key
|
||||
|
||||
:param key:
|
||||
:param chat:
|
||||
:param user:
|
||||
:param chat_id:
|
||||
:param user_id:
|
||||
:return:
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if user is None and chat is None:
|
||||
user = types.User.get_current()
|
||||
chat = types.Chat.get_current()
|
||||
if user_id is None and chat_id is None:
|
||||
user_id = types.User.get_current()
|
||||
chat_id = types.Chat.get_current()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
bucket = await self.storage.get_bucket(chat=chat_id, user=user_id)
|
||||
if bucket and key in bucket:
|
||||
del bucket['key']
|
||||
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
await self.storage.set_bucket(chat=chat_id, user=user_id, bucket=bucket)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
@ -1086,7 +1093,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
async def wrapped(*args, **kwargs):
|
||||
is_not_throttled = await self.throttle(key if key is not None else func.__name__,
|
||||
rate=rate,
|
||||
user=user_id, chat=chat_id,
|
||||
user_id=user_id, chat_id=chat_id,
|
||||
no_error=True)
|
||||
if is_not_throttled:
|
||||
return await func(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue