mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Implement callback data factory.
This commit is contained in:
parent
2e2d9ea433
commit
2fc16b79ff
2 changed files with 253 additions and 0 deletions
136
aiogram/utils/callback_data.py
Normal file
136
aiogram/utils/callback_data.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""
|
||||||
|
Callback data factory
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
Create instance of factory with prefix and element names:
|
||||||
|
>>> posts_query = CallbackData('post', 'post_id', 'action')
|
||||||
|
|
||||||
|
Then you can generate callback data:
|
||||||
|
>>> posts_query.new('32feff9b-92fa-48d9-9d29-621dc713743a', action='view')
|
||||||
|
<<< post:32feff9b-92fa-48d9-9d29-621dc713743a:view
|
||||||
|
|
||||||
|
Also you can generate filters:
|
||||||
|
>>> posts_query.filter(action='delete')
|
||||||
|
This filter can handle callback data by pattern: post:*:delete
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from aiogram import types
|
||||||
|
from aiogram.dispatcher.filters import Filter
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackData:
|
||||||
|
"""
|
||||||
|
Callback data factory
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefix, *parts, sep=':'):
|
||||||
|
if not isinstance(prefix, str):
|
||||||
|
raise TypeError(f"Prefix must be instance of str not {type(prefix).__name__}")
|
||||||
|
elif not prefix:
|
||||||
|
raise ValueError('Prefix can\'t be empty')
|
||||||
|
elif len(sep) != 1:
|
||||||
|
raise ValueError(f"Length of sep should be equals to 1")
|
||||||
|
elif sep in prefix:
|
||||||
|
raise ValueError(f"Symbol '{self.sep}' can't be used in prefix")
|
||||||
|
elif not parts:
|
||||||
|
raise TypeError('Parts is not passed!')
|
||||||
|
|
||||||
|
self.prefix = prefix
|
||||||
|
self.sep = sep
|
||||||
|
|
||||||
|
self._part_names = parts
|
||||||
|
|
||||||
|
def new(self, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Generate callback data
|
||||||
|
|
||||||
|
:param args:
|
||||||
|
:param kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
args = list(args)
|
||||||
|
|
||||||
|
data = [self.prefix]
|
||||||
|
|
||||||
|
for part in self._part_names:
|
||||||
|
value = kwargs.pop(part, None)
|
||||||
|
if not value:
|
||||||
|
try:
|
||||||
|
value = args.pop(0)
|
||||||
|
except IndexError:
|
||||||
|
raise ValueError(f"Value for '{part}' is not passed!")
|
||||||
|
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise TypeError(f"Prefix must be instance of str not {type(value).__name__}")
|
||||||
|
elif not value:
|
||||||
|
raise ValueError(f"Value for part {part} can't be empty!'")
|
||||||
|
elif self.sep in value:
|
||||||
|
raise ValueError(f"Symbol bounded as separator can't be used in values of parts")
|
||||||
|
|
||||||
|
data.append(value)
|
||||||
|
|
||||||
|
if args or kwargs:
|
||||||
|
raise TypeError('Too many arguments is passed!')
|
||||||
|
|
||||||
|
callback_data = self.sep.join(data)
|
||||||
|
if len(callback_data) > 64:
|
||||||
|
raise ValueError('Resulted callback data is too long!')
|
||||||
|
|
||||||
|
return callback_data
|
||||||
|
|
||||||
|
def parse(self, callback_data: str) -> typing.Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Parse data from the callback data
|
||||||
|
|
||||||
|
:param callback_data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
prefix, *parts = callback_data.split(self.sep)
|
||||||
|
if prefix != self.prefix:
|
||||||
|
raise ValueError("Passed callback data can't be parsed with that prefix.")
|
||||||
|
elif len(parts) != len(self._part_names):
|
||||||
|
raise ValueError('Invalid parts count!')
|
||||||
|
|
||||||
|
result = {'@': prefix}
|
||||||
|
result.update(zip(self._part_names, parts))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def filter(self, **config) -> CallbackDataFilter:
|
||||||
|
"""
|
||||||
|
Generate filter
|
||||||
|
|
||||||
|
:param config:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for key in config.keys():
|
||||||
|
if key not in self._part_names:
|
||||||
|
raise ValueError(f"Invalid field name '{key}'")
|
||||||
|
return CallbackDataFilter(self, config)
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackDataFilter(Filter):
|
||||||
|
def __init__(self, factory: CallbackData, config: typing.Dict[str, str]):
|
||||||
|
self.config = config
|
||||||
|
self.factory = factory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
||||||
|
raise ValueError('That filter can\'t be used in filters factory!')
|
||||||
|
|
||||||
|
async def check(self, query: types.CallbackQuery):
|
||||||
|
try:
|
||||||
|
data = self.factory.parse(query.data)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
for key, value in self.config.items():
|
||||||
|
if isinstance(value, (list, tuple, set)):
|
||||||
|
if data.get(key) not in value:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if value != data.get(key):
|
||||||
|
return False
|
||||||
|
return {'callback_data': data}
|
||||||
117
examples/callback_data_factory.py
Normal file
117
examples/callback_data_factory.py
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from aiogram import Bot, Dispatcher, executor, md, types
|
||||||
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
|
from aiogram.contrib.middlewares.logging import LoggingMiddleware
|
||||||
|
from aiogram.utils.callback_data import CallbackData
|
||||||
|
from aiogram.utils.exceptions import MessageNotModified, Throttled
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.HTML)
|
||||||
|
storage = MemoryStorage()
|
||||||
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
dp.middleware.setup(LoggingMiddleware())
|
||||||
|
|
||||||
|
POSTS = {
|
||||||
|
str(uuid.uuid4()): {
|
||||||
|
'title': f"Post {index}",
|
||||||
|
'body': 'Lorem ipsum dolor sit amet, '
|
||||||
|
'consectetur adipiscing elit, '
|
||||||
|
'sed do eiusmod tempor incididunt ut '
|
||||||
|
'labore et dolore magna aliqua',
|
||||||
|
'votes': random.randint(-2, 5)
|
||||||
|
} for index in range(5)
|
||||||
|
}
|
||||||
|
|
||||||
|
posts_cb = CallbackData('post', 'id', 'action') # post:<id>:<action>
|
||||||
|
|
||||||
|
|
||||||
|
def get_keyboard() -> types.InlineKeyboardMarkup:
|
||||||
|
"""
|
||||||
|
Generate keyboard with list of posts
|
||||||
|
"""
|
||||||
|
markup = types.InlineKeyboardMarkup()
|
||||||
|
for post_id, post in POSTS.items():
|
||||||
|
markup.add(
|
||||||
|
types.InlineKeyboardButton(
|
||||||
|
post['title'],
|
||||||
|
callback_data=posts_cb.new(id=post_id, action='view'))
|
||||||
|
)
|
||||||
|
return markup
|
||||||
|
|
||||||
|
|
||||||
|
def format_post(post_id: str, post: dict) -> (str, types.InlineKeyboardMarkup):
|
||||||
|
text = f"{md.hbold(post['title'])}\n" \
|
||||||
|
f"{md.quote_html(post['body'])}\n" \
|
||||||
|
f"\n" \
|
||||||
|
f"Votes: {post['votes']}"
|
||||||
|
|
||||||
|
markup = types.InlineKeyboardMarkup()
|
||||||
|
markup.row(
|
||||||
|
types.InlineKeyboardButton('👍', callback_data=posts_cb.new(id=post_id, action='like')),
|
||||||
|
types.InlineKeyboardButton('👎', callback_data=posts_cb.new(id=post_id, action='unlike')),
|
||||||
|
)
|
||||||
|
markup.add(types.InlineKeyboardButton('<< Back', callback_data=posts_cb.new(id='-', action='list')))
|
||||||
|
return text, markup
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands='start')
|
||||||
|
async def cmd_start(message: types.Message):
|
||||||
|
await message.reply('Posts', reply_markup=get_keyboard())
|
||||||
|
|
||||||
|
|
||||||
|
@dp.callback_query_handler(posts_cb.filter(action='list'))
|
||||||
|
async def query_show_list(query: types.CallbackQuery):
|
||||||
|
await query.message.edit_text('Posts', reply_markup=get_keyboard())
|
||||||
|
|
||||||
|
|
||||||
|
@dp.callback_query_handler(posts_cb.filter(action='view'))
|
||||||
|
async def query_view(query: types.CallbackQuery, callback_data: dict):
|
||||||
|
post_id = callback_data['id']
|
||||||
|
|
||||||
|
post = POSTS.get(post_id, None)
|
||||||
|
if not post:
|
||||||
|
return await query.answer('Unknown post!')
|
||||||
|
|
||||||
|
text, markup = format_post(post_id, post)
|
||||||
|
await query.message.edit_text(text, reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
|
@dp.callback_query_handler(posts_cb.filter(action=['like', 'unlike']))
|
||||||
|
async def query_post_vote(query: types.CallbackQuery, callback_data: dict):
|
||||||
|
try:
|
||||||
|
await dp.throttle('vote', rate=1)
|
||||||
|
except Throttled:
|
||||||
|
return await query.answer('Too many requests.')
|
||||||
|
|
||||||
|
post_id = callback_data['id']
|
||||||
|
action = callback_data['action']
|
||||||
|
|
||||||
|
post = POSTS.get(post_id, None)
|
||||||
|
if not post:
|
||||||
|
return await query.answer('Unknown post!')
|
||||||
|
|
||||||
|
if action == 'like':
|
||||||
|
post['votes'] += 1
|
||||||
|
elif action == 'unlike':
|
||||||
|
post['votes'] -= 1
|
||||||
|
|
||||||
|
await query.answer('Voted.')
|
||||||
|
text, markup = format_post(post_id, post)
|
||||||
|
await query.message.edit_text(text, reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
|
@dp.errors_handler(exception=MessageNotModified)
|
||||||
|
async def message_not_modified_handler(update, error):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue