This commit is contained in:
Виталий 2026-04-05 18:52:48 +03:00 committed by GitHub
commit 17539c777a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 533 additions and 4 deletions

1
CHANGES/1697.feature.rst Normal file
View file

@ -0,0 +1 @@
Added ``MediaGroupAggregatorMiddleware`` to aggregate media groups into a single event with an album list in the handler data.

View file

@ -0,0 +1,246 @@
import asyncio
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Message, TelegramObject
if TYPE_CHECKING:
from redis.asyncio.client import Redis
# Default delay (seconds) between the last received message of a media group
# and the moment the aggregated album is dispatched to the handler.
DELAY_SEC = 1.0
# Default TTL (seconds) for the per-group processing lock.
LOCK_TTL_SEC = 30
# Default TTL (seconds) for stored media-group data.
TTL_SEC = 600
class BaseMediaGroupAggregator(ABC):
    """Abstract storage backend used to collect messages of one media group."""

    @abstractmethod
    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """Store *media* in the group and return the stored message count."""

    @abstractmethod
    async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
        """Try to take the group's processing lock; return True on success."""

    @abstractmethod
    async def release_lock(self, media_group_id: str, lock_id: str) -> None:
        """Release the lock, but only if it is still held under *lock_id*."""

    @abstractmethod
    async def get_group(self, media_group_id: str) -> list[Message]:
        """Return all messages collected for the group so far."""

    @abstractmethod
    async def delete_group(self, media_group_id: str) -> None:
        """Drop all stored data of the group."""

    @abstractmethod
    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """Return the timestamp of the group's most recent message, or None."""

    @abstractmethod
    async def get_current_time(self) -> float:
        """Return "now" on the same clock used for the stored timestamps."""

    @staticmethod
    def deduplicate_messages(messages: list[Message]) -> list[Message]:
        """Drop repeated messages, keeping the first occurrence of each id."""
        unique: dict[int, Message] = {}
        for msg in messages:
            # setdefault keeps the earliest message seen for each id.
            unique.setdefault(msg.message_id, msg)
        return list(unique.values())
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
    """
    Aggregates media groups in Redis.

    Group data and locks live in shared keys, so several bot processes can
    cooperate on the same media group.
    """

    redis: "Redis"

    # Compare-and-delete: the lock key is removed only while it still holds
    # the caller's lock id, so one holder cannot release another's lock.
    _RELEASE_LOCK_SCRIPT = (
        'if redis.call("get", KEYS[1]) == ARGV[1] then '
        'return redis.call("del", KEYS[1]) '
        "else return 0 end"
    )

    def __init__(
        self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
    ) -> None:
        """
        :param redis: async Redis client used for all storage operations
        :param ttl_sec: ttl for media group data in seconds
        :param lock_ttl_sec: ttl for lock in seconds. Value should be large enough to prevent the
            lock from expiring before the handler finishes, but small enough to expire before
            Telegram retries a failed delivery.
        """
        self.redis = redis
        self.ttl_sec = ttl_sec
        self.lock_ttl_sec = lock_ttl_sec

    @staticmethod
    def get_group_key(media_group_id: str) -> str:
        """Key of the list holding the group's serialized messages."""
        return f"media_group:{media_group_id}:album"

    @staticmethod
    def get_last_message_time_key(media_group_id: str) -> str:
        """Key holding the server-clock timestamp of the latest message."""
        return f"media_group:{media_group_id}:last_message_time"

    @staticmethod
    def get_group_lock_key(media_group_id: str) -> str:
        """Key used as the group's processing lock."""
        return f"media_group:{media_group_id}:lock"

    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """Append *media* to the group, refresh both TTLs, return new length."""
        group_key = self.get_group_key(media_group_id)
        time_key = self.get_last_message_time_key(media_group_id)
        now = await self.get_current_time()
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.set(time_key, now, ex=self.ttl_sec)
            pipe.rpush(group_key, media.model_dump_json())
            pipe.expire(group_key, self.ttl_sec)
            replies = await pipe.execute()
        # RPUSH is the second queued command; it replies with the list length.
        return cast(int, replies[1])

    async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
        """Atomically take the group lock (SET NX) with a safety TTL."""
        acquired = await self.redis.set(
            self.get_group_lock_key(media_group_id), lock_id, nx=True, ex=self.lock_ttl_sec
        )
        return cast(bool, acquired)

    async def release_lock(self, media_group_id: str, lock_id: str) -> None:
        """Delete the lock key only if it still belongs to *lock_id*."""
        lock_key = self.get_group_lock_key(media_group_id)
        await cast(
            Awaitable[int],
            self.redis.eval(self._RELEASE_LOCK_SCRIPT, 1, lock_key, lock_id),
        )

    async def get_group(self, media_group_id: str) -> list[Message]:
        """Load and deduplicate every message stored for the group."""
        raw_items = await cast(
            Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
        )
        messages = [Message.model_validate_json(item) for item in raw_items]
        return self.deduplicate_messages(messages)

    async def delete_group(self, media_group_id: str) -> None:
        """Remove the group's message list and its timestamp key."""
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.delete(self.get_group_key(media_group_id))
            pipe.delete(self.get_last_message_time_key(media_group_id))
            await pipe.execute()

    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """Return the stored timestamp, or None when the group is unknown."""
        stored = await self.redis.get(self.get_last_message_time_key(media_group_id))
        return None if stored is None else float(stored)

    async def get_current_time(self) -> float:
        """Use the Redis server clock so all workers share one time source."""
        seconds, microseconds = cast(tuple[int, int], await self.redis.time())
        return seconds + microseconds / 1e6
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
    """In-process media group storage; the default backend for a single worker."""

    def __init__(self, ttl_sec: int = TTL_SEC) -> None:
        # group id -> collected messages
        self.groups: dict[str, list[Message]] = defaultdict(list)
        # group id -> monotonic timestamp of its latest message. Entries are
        # re-inserted on every update, so iteration order is oldest-first.
        self.last_message_timers: dict[str, float] = {}
        # group id -> id of the lock holder
        self.locks: dict[str, str] = {}
        self.ttl_sec = ttl_sec

    def remove_expired_objects(self) -> None:
        """Evict every group whose newest message is older than the TTL."""
        now = time.monotonic()
        expired_ids = []
        for group_id, touched_at in self.last_message_timers.items():
            if now - touched_at <= self.ttl_sec:
                # Timers are kept in ascending order (dicts preserve insertion
                # order), so everything after the first live group is live too.
                break
            expired_ids.append(group_id)
        for group_id in expired_ids:
            self.groups.pop(group_id, None)
            self.last_message_timers.pop(group_id, None)
            self.locks.pop(group_id, None)

    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """Store *media* (skipping duplicates) and return the group size."""
        self.remove_expired_objects()
        group = self.groups[media_group_id]
        if all(existing.message_id != media.message_id for existing in group):
            group.append(media)
        # Pop before re-inserting so the timer dict stays sorted by time,
        # which remove_expired_objects relies on for its early break.
        self.last_message_timers.pop(media_group_id, None)
        self.last_message_timers[media_group_id] = time.monotonic()
        return len(group)

    async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
        """Take the group lock unless someone already holds it."""
        if media_group_id in self.locks:
            return False
        self.locks[media_group_id] = lock_id
        return True

    async def release_lock(self, media_group_id: str, lock_id: str) -> None:
        """Release the lock only if *lock_id* is the current holder."""
        if self.locks.get(media_group_id) == lock_id:
            del self.locks[media_group_id]

    async def get_group(self, media_group_id: str) -> list[Message]:
        """Return the group's messages, or an empty list for unknown groups."""
        return self.groups.get(media_group_id, [])

    async def delete_group(self, media_group_id: str) -> None:
        """Forget the group's messages and its timestamp."""
        self.groups.pop(media_group_id, None)
        self.last_message_timers.pop(media_group_id, None)

    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """Return the group's last-update time, or None when unknown."""
        return self.last_message_timers.get(media_group_id)

    async def get_current_time(self) -> float:
        """Monotonic clock — the same one used by the stored timers."""
        return time.monotonic()
class MediaGroupAggregatorMiddleware(BaseMiddleware):
    """Aggregate messages of one media group and call the handler only once.

    The handler receives the first (lowest ``message_id``) message of the
    group as the event, and the whole sorted album in ``data["album"]``.
    Updates for the remaining messages of the same group are suppressed.
    """

    def __init__(
        self,
        media_group_aggregator: BaseMediaGroupAggregator | None = None,
        delay: float = DELAY_SEC,
    ) -> None:
        """
        :param media_group_aggregator: storage backend; an in-memory one is
            created when omitted
        :param delay: delay between last received message in media group and processing it
        """
        self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator()
        self.delay = delay

    async def __call__(
        self,
        handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
        event: TelegramObject,
        data: dict[str, Any],
    ) -> Any:
        # Messages outside a media group pass straight through.
        if not isinstance(event, Message) or not event.media_group_id:
            return await handler(event, data)
        await self.media_group_aggregator.add_into_group(event.media_group_id, event)
        # Only one concurrent update per group wins the lock and becomes the
        # collector; the others have stored their message above and stop here.
        lock_id = str(uuid.uuid4())
        if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id):
            return None
        try:
            while True:
                last_message_time = await self.media_group_aggregator.get_last_message_time(
                    event.media_group_id
                )
                # The group data may have expired or been deleted while waiting.
                if not last_message_time:
                    return None
                # Time left until the group is considered complete; every new
                # message pushes last_message_time forward and extends the wait.
                delta = self.delay - (
                    await self.media_group_aggregator.get_current_time() - last_message_time
                )
                if delta <= 0:
                    album = await self.media_group_aggregator.get_group(event.media_group_id)
                    if not album:
                        return None
                    # Re-attach the Bot instance (messages may have been
                    # round-tripped through storage) and order the album by id.
                    album = sorted(
                        (msg.as_(cast(Bot, data.get("bot"))) for msg in album),
                        key=lambda msg: msg.message_id,
                    )
                    data.update(album=album)
                    result = await handler(album[0], data)
                    # Delete only after the handler succeeded, so a failed
                    # delivery can be retried with the group data intact.
                    await self.media_group_aggregator.delete_group(event.media_group_id)
                    return result
                await asyncio.sleep(delta)
        finally:
            # Always free the lock, even when the handler raised.
            await self.media_group_aggregator.release_lock(event.media_group_id, lock_id)

View file

@ -1,8 +1,13 @@
===================
Media group builder
===================
===========
Media group
===========
This module provides a builder for media groups, it can be used to build media groups
This module provides tools for media groups.
Building media groups
=====================
Media group builder can be used to build media groups
for :class:`aiogram.types.input_media_photo.InputMediaPhoto`, :class:`aiogram.types.input_media_video.InputMediaVideo`,
:class:`aiogram.types.input_media_document.InputMediaDocument` and :class:`aiogram.types.input_media_audio.InputMediaAudio`.
@ -39,8 +44,58 @@ it will be used as ``caption`` for first media in group.
await bot.send_media_group(chat_id=chat_id, media=media_group.build())
Handling media groups
=====================
By default each media in the group is processed separately.
You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware`
to process media groups as one. If you do, only one message from the group will be processed, and updates for
other messages with the same media group ID will be suppressed. There are two options to store media groups:
- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default
- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - supports distributed environments
You can also use :class:`aiogram.filters.magic_data.MagicData` with ``F.album``
to filter media groups.
Usage
-----
.. code-block:: python
from aiogram import F
from aiogram.types import Message
# register middleware
from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware
from aiogram.filters import MagicData
router.message.outer_middleware(MediaGroupAggregatorMiddleware())
# use middleware
@router.message(
MagicData(F.album.len() <= 5),
F.caption == "album_caption" # other filters will be applied to the first message in the group
)
async def start(message: Message, album: list[Message]):
# message is the first media in this group
# album is the list of all messages with the same media group ID, including the current message
await message.answer(
f"You sent {len(album)} media in the group. "
f"Media group ID: {message.media_group_id}. "
f"Album messages: {', '.join(str(m.message_id) for m in album)}"
)
References
==========
.. autoclass:: aiogram.utils.media_group.MediaGroupBuilder
:members:
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware
:members:
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator
:members:
.. autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator
:members:
:special-members: __init__

View file

@ -0,0 +1,227 @@
import asyncio
import time
from datetime import datetime
from typing import Any, Awaitable, Callable
from unittest import mock
import pytest
from redis.asyncio.client import Redis
from aiogram.dispatcher.middlewares.media_group import (
BaseMediaGroupAggregator,
MediaGroupAggregatorMiddleware,
MemoryMediaGroupAggregator,
RedisMediaGroupAggregator,
)
from aiogram.types import Chat, Message
def _get_message(message_id: int, **kwargs):
    """Build a minimal message in a private chat for use in tests."""
    return Message(
        message_id=message_id,
        date=datetime.now(),
        chat=Chat(id=1, type="private", title="Test"),
        **kwargs,
    )
async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
    """Invoke *func* and return its result once scheduled work reaches asyncio.sleep.

    asyncio.sleep is temporarily replaced with a zero-delay stand-in that
    signals an event, so the caller resumes as soon as the code under test
    first tries to sleep.
    """
    sleep_reached = asyncio.Event()
    original_sleep = asyncio.sleep

    async def tracking_sleep(*_args, **_kwargs):
        sleep_reached.set()
        # Yield control without any real delay.
        await original_sleep(0)

    with mock.patch("asyncio.sleep", tracking_sleep):
        result = func(*args, **kwargs)
        await sleep_reached.wait()
    return result
class TestMediaGroupAggregatorMiddleware:
    """Behavioral tests for the middleware with the default in-memory backend."""

    def get_middleware(self):
        # Short delay keeps the aggregation wait negligible in tests.
        return MediaGroupAggregatorMiddleware(delay=0.1)

    async def test_skip_non_media_group(self):
        """A message without media_group_id must pass straight to the handler."""
        is_called = False

        async def next_handler(*args, **kwargs):
            nonlocal is_called
            is_called = True

        await self.get_middleware()(next_handler, _get_message(1), {})
        assert is_called

    async def test_called_once_for_album(self):
        """Two messages of one group produce exactly one handler call."""
        middleware = self.get_middleware()
        counter = 0
        album = None

        async def next_handler(_, data: dict[str, Any]):
            nonlocal counter, album
            counter += 1
            album = data.get("album")

        await asyncio.gather(
            middleware(next_handler, _get_message(1, media_group_id="42"), {}),
            middleware(next_handler, _get_message(2, media_group_id="42"), {}),
        )
        assert album is not None
        assert len(album) == 2
        assert counter == 1

    async def test_bot_object_saved(self, bot):
        """The bot from the middleware data is re-attached to every album message."""
        middleware = self.get_middleware()
        event = album = None

        async def next_handler(message: Message, data: dict[str, Any]):
            nonlocal event, album
            event = message
            album = data.get("album")

        await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot})
        assert event.bot is bot
        assert all(msg.bot is bot for msg in album)

    async def test_propagate_first_media_in_album(self):
        """The handler receives the message with the lowest id, regardless of arrival order."""
        middleware = self.get_middleware()
        first_message = None

        async def next_handler(message: Message, _):
            nonlocal first_message
            first_message = message

        # Start the collector with message 2 and pause it at its first sleep,
        # then deliver message 1 before the collector finishes waiting.
        task1 = await wait_until_func_call_sleep(
            asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {})
        )
        await middleware(next_handler, _get_message(1, media_group_id="42"), {})
        await task1
        assert isinstance(first_message, Message)
        assert first_message.message_id == 1

    @pytest.mark.parametrize("deleted_object", ["album", "last_message_time"])
    async def test_skip_propagating_if_data_deleted(self, deleted_object):
        """If the stored data vanishes mid-wait, the handler is never called."""
        middleware = self.get_middleware()
        counter = 0

        async def next_handler(*args, **kwargs):
            nonlocal counter
            counter += 1

        task1 = await wait_until_func_call_sleep(
            asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
        )
        # Simulate expiry by removing the backing data directly.
        if deleted_object == "album":
            middleware.media_group_aggregator.groups.pop("42")
        else:
            middleware.media_group_aggregator.last_message_timers.pop("42")
        await task1
        assert counter == 0

    async def test_different_albums_non_interfere(self):
        """Distinct media group ids are aggregated independently."""
        middleware = self.get_middleware()
        counter = 0
        albums = []

        async def next_handler(_, data: dict[str, Any]):
            nonlocal counter, albums
            counter += 1
            albums.append(data.get("album"))

        await asyncio.gather(
            middleware(next_handler, _get_message(1, media_group_id="1"), {}),
            middleware(next_handler, _get_message(2, media_group_id="2"), {}),
        )
        assert counter == 2
        assert len(albums) == 2

    async def test_retry_handling(self):
        """A failed handler leaves the group intact, so a retry sees the full album."""
        middleware = self.get_middleware()
        album = None

        async def failed_handler(*args, **kwargs):
            raise RuntimeError("Failed")

        async def working_handler(_, data: dict[str, Any]):
            nonlocal album
            album = data.get("album")

        first_message = _get_message(1, media_group_id="42")
        second_message = _get_message(2, media_group_id="42")
        with pytest.raises(RuntimeError):
            await asyncio.gather(
                middleware(failed_handler, first_message, {}),
                middleware(failed_handler, second_message, {}),
            )
        # The group was not deleted on failure, so the retry gets both messages.
        await middleware(working_handler, first_message, {})
        assert len(album) == 2
def test_message_deduplication():
    """Duplicates are dropped while the original order is preserved."""
    first, second = _get_message(1), _get_message(2)
    expected = [first, second]
    assert BaseMediaGroupAggregator.deduplicate_messages([first, second]) == expected
    assert BaseMediaGroupAggregator.deduplicate_messages([first, second, second]) == expected
    assert BaseMediaGroupAggregator.deduplicate_messages([first, second, first]) == expected
@pytest.fixture(params=["memory", "redis"], scope="function")
async def aggregator(request):
    """Yield each aggregator backend; clean up Redis keys afterwards."""
    if request.param == "memory":
        yield MemoryMediaGroupAggregator()
        return
    redis = Redis.from_url(request.getfixturevalue("redis_server"))
    yield RedisMediaGroupAggregator(redis)
    # Remove everything the test created so runs stay independent.
    keys = await redis.keys("media_group:*")
    if keys:
        await redis.delete(*keys)
    await redis.aclose()
class TestMediaGroupAggregator:
    """Contract tests run against both the memory and the Redis backends."""

    async def test_group_creating(self, aggregator: BaseMediaGroupAggregator):
        """Adding messages grows the group; deleting it empties it."""
        msg1 = _get_message(1)
        msg2 = _get_message(2)
        assert await aggregator.add_into_group("42", msg1) == 1
        assert await aggregator.add_into_group("42", msg2) == 2
        assert {msg.message_id for msg in await aggregator.get_group("42")} == {
            msg1.message_id,
            msg2.message_id,
        }
        await aggregator.delete_group("42")
        assert await aggregator.get_group("42") == []

    async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
        """A held lock cannot be re-acquired until it is released."""
        for i in ("key1", "key2"):
            assert await aggregator.acquire_lock("42", i)
            assert not await aggregator.acquire_lock("42", i)
            await aggregator.release_lock("42", i)

    async def test_lock_not_acquired_with_wrong_key(self, aggregator: BaseMediaGroupAggregator):
        """Releasing with a foreign lock id must not free the lock."""
        await aggregator.acquire_lock("42", "key1")
        await aggregator.release_lock("42", "key2")
        assert not await aggregator.acquire_lock("42", "key1")

    async def test_expired_objects_removed(self):
        """Groups older than the TTL are evicted on the next insert."""
        aggregator = MemoryMediaGroupAggregator()
        await aggregator.add_into_group("42", _get_message(1))
        # NOTE(review): patches time.monotonic with a wall-clock based value;
        # this forces "now" far past the stored monotonic timestamp, which
        # triggers expiry, but it mixes two clocks — confirm intent.
        with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1):
            new_msg = _get_message(2)
            await aggregator.add_into_group("24", new_msg)
        assert await aggregator.get_group("42") == []
        assert await aggregator.get_group("24") == [new_msg]

    async def test_get_current_time_memory_aggregator(self):
        """The memory backend reports time via time.monotonic."""
        aggregator = MemoryMediaGroupAggregator()
        with mock.patch("time.monotonic", return_value=1.1):
            assert await aggregator.get_current_time() == 1.1

    async def test_get_current_time_redis_aggregator(self):
        """The Redis backend converts (seconds, microseconds) to a float."""
        aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis))
        aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456))
        assert await aggregator.get_current_time() == 1.123456

    async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
        """The stored timestamp tracks the latest add_into_group call."""
        assert await aggregator.get_last_message_time("42") is None
        await aggregator.add_into_group("42", _get_message(1))
        time_before_second_message = await aggregator.get_current_time()
        assert await aggregator.get_last_message_time("42") <= time_before_second_message
        await aggregator.add_into_group("42", _get_message(2))
        assert await aggregator.get_last_message_time("42") >= time_before_second_message