diff --git a/aiogram/utils/callback_data.py b/aiogram/utils/callback_data.py new file mode 100644 index 00000000..608d261d --- /dev/null +++ b/aiogram/utils/callback_data.py @@ -0,0 +1,136 @@ +""" +Callback data factory + +Usage: + Create instance of factory with prefix and element names: + >>> posts_query = CallbackData('post', 'post_id', 'action') + + Then you can generate callback data: + >>> posts_query.new('32feff9b-92fa-48d9-9d29-621dc713743a', action='view') + <<< post:32feff9b-92fa-48d9-9d29-621dc713743a:view + + Also you can generate filters: + >>> posts_query.filter(action='delete') + This filter can handle callback data by pattern: post:*:delete +""" +from __future__ import annotations + +import typing + +from aiogram import types +from aiogram.dispatcher.filters import Filter + + +class CallbackData: + """ + Callback data factory + """ + + def __init__(self, prefix, *parts, sep=':'): + if not isinstance(prefix, str): + raise TypeError(f"Prefix must be instance of str not {type(prefix).__name__}") + elif not prefix: + raise ValueError('Prefix can\'t be empty') + elif len(sep) != 1: + raise ValueError(f"Length of sep should be equals to 1") + elif sep in prefix: + raise ValueError(f"Symbol '{self.sep}' can't be used in prefix") + elif not parts: + raise TypeError('Parts is not passed!') + + self.prefix = prefix + self.sep = sep + + self._part_names = parts + + def new(self, *args, **kwargs) -> str: + """ + Generate callback data + + :param args: + :param kwargs: + :return: + """ + args = list(args) + + data = [self.prefix] + + for part in self._part_names: + value = kwargs.pop(part, None) + if not value: + try: + value = args.pop(0) + except IndexError: + raise ValueError(f"Value for '{part}' is not passed!") + + if not isinstance(value, str): + raise TypeError(f"Prefix must be instance of str not {type(value).__name__}") + elif not value: + raise ValueError(f"Value for part {part} can't be empty!'") + elif self.sep in value: + raise ValueError(f"Symbol bounded as separator can't be used in values of parts") + + data.append(value) + + if args or kwargs: + raise TypeError('Too many arguments is passed!') + + callback_data = self.sep.join(data) + if len(callback_data) > 64: + raise ValueError('Resulted callback data is too long!') + + return callback_data + + def parse(self, callback_data: str) -> typing.Dict[str, str]: + """ + Parse data from the callback data + + :param callback_data: + :return: + """ + prefix, *parts = callback_data.split(self.sep) + if prefix != self.prefix: + raise ValueError("Passed callback data can't be parsed with that prefix.") + elif len(parts) != len(self._part_names): + raise ValueError('Invalid parts count!') + + result = {'@': prefix} + result.update(zip(self._part_names, parts)) + return result + + def filter(self, **config) -> CallbackDataFilter: + """ + Generate filter + + :param config: + :return: + """ + for key in config.keys(): + if key not in self._part_names: + raise ValueError(f"Invalid field name '{key}'") + return CallbackDataFilter(self, config) + + +class CallbackDataFilter(Filter): + def __init__(self, factory: CallbackData, config: typing.Dict[str, str]): + self.config = config + self.factory = factory + + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]): + raise ValueError('That filter can\'t be used in filters factory!') + + async def check(self, query: types.CallbackQuery): + try: + data = self.factory.parse(query.data) + except ValueError: + return False + else: + for key, value in self.config.items(): + if isinstance(value, (list, tuple, set)): + if data.get(key) not in value: + return False + else: + if value != data.get(key): + return False + return {'callback_data': data} diff --git a/examples/callback_data_factory.py b/examples/callback_data_factory.py new file mode 100644 index 00000000..f2432ef0 --- /dev/null +++ b/examples/callback_data_factory.py @@ -0,0 +1,117 @@ +import asyncio +import logging +import random +import uuid + +from aiogram import Bot, Dispatcher, executor, md, types +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.contrib.middlewares.logging import LoggingMiddleware +from aiogram.utils.callback_data import CallbackData +from aiogram.utils.exceptions import MessageNotModified, Throttled + +logging.basicConfig(level=logging.INFO) + +API_TOKEN = 'BOT TOKEN HERE' + +loop = asyncio.get_event_loop() +bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.HTML) +storage = MemoryStorage() +dp = Dispatcher(bot, storage=storage) +dp.middleware.setup(LoggingMiddleware()) + +POSTS = { + str(uuid.uuid4()): { + 'title': f"Post {index}", + 'body': 'Lorem ipsum dolor sit amet, ' + 'consectetur adipiscing elit, ' + 'sed do eiusmod tempor incididunt ut ' + 'labore et dolore magna aliqua', + 'votes': random.randint(-2, 5) + } for index in range(5) +} + +posts_cb = CallbackData('post', 'id', 'action') # post:: + + +def get_keyboard() -> types.InlineKeyboardMarkup: + """ + Generate keyboard with list of posts + """ + markup = types.InlineKeyboardMarkup() + for post_id, post in POSTS.items(): + markup.add( + types.InlineKeyboardButton( + post['title'], + callback_data=posts_cb.new(id=post_id, action='view')) + ) + return markup + + +def format_post(post_id: str, post: dict) -> (str, types.InlineKeyboardMarkup): + text = f"{md.hbold(post['title'])}\n" \ + f"{md.quote_html(post['body'])}\n" \ + f"\n" \ + f"Votes: {post['votes']}" + + markup = types.InlineKeyboardMarkup() + markup.row( + types.InlineKeyboardButton('👍', callback_data=posts_cb.new(id=post_id, action='like')), + types.InlineKeyboardButton('👎', callback_data=posts_cb.new(id=post_id, action='unlike')), + ) + markup.add(types.InlineKeyboardButton('<< Back', callback_data=posts_cb.new(id='-', action='list'))) + return text, markup + + +@dp.message_handler(commands='start') +async def cmd_start(message: types.Message): + await message.reply('Posts', reply_markup=get_keyboard()) + + +@dp.callback_query_handler(posts_cb.filter(action='list')) +async def query_show_list(query: types.CallbackQuery): + await query.message.edit_text('Posts', reply_markup=get_keyboard()) + + +@dp.callback_query_handler(posts_cb.filter(action='view')) +async def query_view(query: types.CallbackQuery, callback_data: dict): + post_id = callback_data['id'] + + post = POSTS.get(post_id, None) + if not post: + return await query.answer('Unknown post!') + + text, markup = format_post(post_id, post) + await query.message.edit_text(text, reply_markup=markup) + + +@dp.callback_query_handler(posts_cb.filter(action=['like', 'unlike'])) +async def query_post_vote(query: types.CallbackQuery, callback_data: dict): + try: + await dp.throttle('vote', rate=1) + except Throttled: + return await query.answer('Too many requests.') + + post_id = callback_data['id'] + action = callback_data['action'] + + post = POSTS.get(post_id, None) + if not post: + return await query.answer('Unknown post!') + + if action == 'like': + post['votes'] += 1 + elif action == 'unlike': + post['votes'] -= 1 + + await query.answer('Voted.') + text, markup = format_post(post_id, post) + await query.message.edit_text(text, reply_markup=markup) + + +@dp.errors_handler(exception=MessageNotModified) +async def message_not_modified_handler(update, error): + return True + + +if __name__ == '__main__': + executor.start_polling(dp, loop=loop, skip_updates=True)