diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 95a6d8d9..a5bf5b9f 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -1055,3 +1055,64 @@ class Dispatcher(DataMixin, ContextInstanceMixin): if run_task: return self.async_task(callback) return callback + + def throttled(self, on_throttled: typing.Optional[typing.Callable] = None, + key=None, rate=None, + user_id=None, chat_id=None): + """ + Meta-decorator for throttling. + Invokes on_throttled if the handler was throttled. + + Example: + + .. code-block:: python3 + + async def handler_throttled(message: types.Message, **kwargs): + await message.answer("Throttled!") + + @dp.throttled(handler_throttled) + async def some_handler(message: types.Message): + await message.answer("Didn't throttled!") + + :param on_throttled: the callable object that should be either a function or return a coroutine + :param key: key in storage + :param rate: limit (by default is equal to default rate limit) + :param user_id: user id + :param chat_id: chat id + :return: decorator + """ + def decorator(func): + @functools.wraps(func) + async def wrapped(*args, **kwargs): + is_not_throttled = await self.throttle(key if key is not None else func.__name__, + rate=rate, + user=user_id, chat=chat_id, + no_error=True) + if is_not_throttled: + return await func(*args, **kwargs) + else: + kwargs.update( + { + 'rate': rate, + 'key': key, + 'user_id': user_id, + 'chat_id': chat_id + } + ) # update kwargs with parameters which were given to throttled + + if on_throttled: + if asyncio.iscoroutinefunction(on_throttled): + await on_throttled(*args, **kwargs) + else: + kwargs.update( + { + 'loop': asyncio.get_running_loop() + } + ) + partial_func = functools.partial(on_throttled, *args, **kwargs) + asyncio.get_running_loop().run_in_executor(None, + partial_func + ) + return wrapped + + return decorator diff --git a/examples/throtling_example.py b/examples/throtling_example.py deleted file mode 100644 index 18563472..00000000 --- a/examples/throtling_example.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -Example for throttling manager. - -You can use that for flood controlling. -""" - -import logging - -from aiogram import Bot, types -from aiogram.contrib.fsm_storage.memory import MemoryStorage -from aiogram.dispatcher import Dispatcher -from aiogram.utils.exceptions import Throttled -from aiogram.utils.executor import start_polling - - -API_TOKEN = 'BOT_TOKEN_HERE' - -logging.basicConfig(level=logging.INFO) - -bot = Bot(token=API_TOKEN) - -# Throttling manager does not work without Leaky Bucket. -# You need to use a storage. For example use simple in-memory storage. -storage = MemoryStorage() -dp = Dispatcher(bot, storage=storage) - - -@dp.message_handler(commands=['start', 'help']) -async def send_welcome(message: types.Message): - try: - # Execute throttling manager with rate-limit equal to 2 seconds for key "start" - await dp.throttle('start', rate=2) - except Throttled: - # If request is throttled, the `Throttled` exception will be raised - await message.reply('Too many requests!') - else: - # Otherwise do something - await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") - - -if __name__ == '__main__': - start_polling(dp, skip_updates=True) diff --git a/examples/throttling_example.py b/examples/throttling_example.py new file mode 100644 index 00000000..f9ad1c67 --- /dev/null +++ b/examples/throttling_example.py @@ -0,0 +1,71 @@ +""" +Example for throttling manager. + +You can use that for flood controlling. +""" + +import logging + +from aiogram import Bot, types +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.dispatcher import Dispatcher +from aiogram.utils.exceptions import Throttled +from aiogram.utils.executor import start_polling + + +API_TOKEN = 'BOT_TOKEN_HERE' + +logging.basicConfig(level=logging.INFO) + +bot = Bot(token=API_TOKEN) + +# Throttling manager does not work without Leaky Bucket. +# You need to use a storage. For example use simple in-memory storage. +storage = MemoryStorage() +dp = Dispatcher(bot, storage=storage) + + +@dp.message_handler(commands=['start']) +async def send_welcome(message: types.Message): + try: + # Execute throttling manager with rate-limit equal to 2 seconds for key "start" + await dp.throttle('start', rate=2) + except Throttled: + # If request is throttled, the `Throttled` exception will be raised + await message.reply('Too many requests!') + else: + # Otherwise do something + await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") + + +@dp.message_handler(commands=['hi']) +@dp.throttled(lambda msg, loop, *args, **kwargs: loop.create_task(bot.send_message(msg.from_user.id, "Throttled")), + rate=5) +# loop is added to the function to run coroutines from it +async def say_hi(message: types.Message): + await message.answer("Hi") + + +# the on_throttled object can be either a regular function or coroutine +async def hello_throttled(*args, **kwargs): + # args will be the same as in the original handler + # kwargs will be updated with parameters given to .throttled (rate, key, user_id, chat_id) + print(f"hello_throttled was called with args={args} and kwargs={kwargs}") + message = args[0] # as message was the first argument in the original handler + await message.answer("Throttled") + + +@dp.message_handler(commands=['hello']) +@dp.throttled(hello_throttled, rate=4) +async def say_hello(message: types.Message): + await message.answer("Hello!") + + +@dp.message_handler(commands=['help']) +@dp.throttled(rate=5) +# nothing will happen if the handler will be throttled +async def help_handler(message: types.Message): + await message.answer('Help!') + +if __name__ == '__main__': + start_polling(dp, skip_updates=True)