aiogram/aiogram/dispatcher/middlewares/media_group.py
2026-03-01 12:01:04 +03:00

233 lines
8.8 KiB
Python

import asyncio
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Message, TelegramObject
if TYPE_CHECKING:
from redis.asyncio.client import Redis
# Default for MediaGroupAggregatorMiddleware(delay=...): seconds to wait after the
# last received message of a media group before the group is processed.
DELAY_SEC = 1.0
# Default for RedisMediaGroupAggregator(lock_ttl_sec=...): expiry, in seconds, of the
# per-group lock key.
LOCK_TTL_SEC = 30
# Default for the aggregators' ttl_sec argument: seconds stored media group data is kept.
TTL_SEC = 600
class BaseMediaGroupAggregator(ABC):
    """
    Storage contract for collecting the messages of a media group.

    Concrete aggregators keep, per ``media_group_id``: the received messages,
    a lock, and the time the last message was added.
    """

    @abstractmethod
    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """
        Store ``media`` in the group identified by ``media_group_id``.

        :return: an integer defined by the implementation (the aggregators in this
            module return the size of the stored group)
        """
        raise NotImplementedError

    @abstractmethod
    async def acquire_lock(self, media_group_id: str) -> bool:
        """
        Try to take the lock of the group.

        :return: whether the lock was obtained by this call
        """
        raise NotImplementedError

    @abstractmethod
    async def release_lock(self, media_group_id: str) -> None:
        """Give up the lock of the group."""
        raise NotImplementedError

    @abstractmethod
    async def get_group(self, media_group_id: str) -> list[Message]:
        """Return the messages currently stored for the group."""
        raise NotImplementedError

    @abstractmethod
    async def delete_group(self, media_group_id: str) -> None:
        """Drop the stored data of the group."""
        raise NotImplementedError

    @abstractmethod
    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """
        Return the time at which the last message was added to the group.

        :return: timestamp, or ``None`` when nothing is recorded for the group
        """
        raise NotImplementedError

    @staticmethod
    def deduplicate_messages(messages: list[Message]) -> list[Message]:
        """
        Drop messages whose ``message_id`` was already seen earlier in the list.

        The first occurrence of every ``message_id`` is kept and the relative
        order of the kept messages is unchanged.
        """
        # dict keeps insertion order, and setdefault never overwrites an
        # existing entry, so the earliest message per id wins.
        first_by_id: dict[Any, Message] = {}
        for candidate in messages:
            first_by_id.setdefault(candidate.message_id, candidate)
        return list(first_by_id.values())

    async def get_current_time(self) -> float:
        """Return the current time as a float (``time.time()`` by default)."""
        now = time.time()
        return now
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
    """
    Aggregates media groups in Redis.

    Three keys are used per media group: a list with the JSON-serialized messages,
    a value with the time of the last added message, and a lock key.
    """

    redis: "Redis"

    def __init__(
        self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
    ) -> None:
        """
        :param redis: async Redis client used for all operations
        :param ttl_sec: ttl for media group data in seconds
        :param lock_ttl_sec: ttl for lock in seconds. Value should be big enough that the
            lock does not expire before the handler has finished, and small enough that
            it expires before Telegram sends a retry if the handler failed.
        """
        self.redis = redis
        self.ttl_sec = ttl_sec
        self.lock_ttl_sec = lock_ttl_sec

    @staticmethod
    def get_group_key(media_group_id: str) -> str:
        """Return the key of the list that stores the serialized messages."""
        return f"media_group:{media_group_id}:album"

    @staticmethod
    def get_last_message_time_key(media_group_id: str) -> str:
        """Return the key that stores the time of the last added message."""
        return f"media_group:{media_group_id}:last_message_time"

    @staticmethod
    def get_group_lock_key(media_group_id: str) -> str:
        """Return the key used as the lock of the group."""
        return f"media_group:{media_group_id}:lock"

    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """
        Append ``media`` to the group list and refresh the last-message time.

        Both keys get ``ttl_sec`` as expiry; the commands run in one transaction.

        :return: reply of the ``RPUSH`` command
        """
        group_key = self.get_group_key(media_group_id)
        # Timestamp comes from get_current_time() so it is comparable with the values
        # the middleware reads through the same method.
        current_time = await self.get_current_time()
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec)
            pipe.rpush(group_key, media.model_dump_json())
            pipe.expire(group_key, self.ttl_sec)
            res = await pipe.execute()
        # Replies are positional: index 1 belongs to the rpush queued second.
        return cast(int, res[1])

    async def acquire_lock(self, media_group_id: str) -> bool:
        """
        Try to create the lock key with ``nx=True`` and ``lock_ttl_sec`` expiry.

        :return: result of the ``SET`` call, i.e. whether the key was created
        """
        return cast(
            bool,
            await self.redis.set(
                self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
            ),
        )

    async def release_lock(self, media_group_id: str) -> None:
        """Delete the lock key of the group."""
        # NOTE(review): the key is deleted unconditionally (the stored value is a
        # constant), so a lock that already expired and was re-acquired by another
        # worker would be removed as well — confirm this is acceptable.
        await self.redis.delete(self.get_group_lock_key(media_group_id))

    async def get_group(self, media_group_id: str) -> list[Message]:
        """
        Load and parse all stored messages of the group.

        Identical payloads are parsed once, then messages are de-duplicated by
        ``message_id``. Arrival order is preserved and the first stored entry wins.
        """
        raw_messages = await cast(
            Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
        )
        # dict.fromkeys removes exact duplicates while keeping list order; a set would
        # make the order (and thus the surviving duplicate) arbitrary.
        unique_payloads = dict.fromkeys(raw_messages)
        return self.deduplicate_messages(
            [Message.model_validate_json(payload) for payload in unique_payloads]
        )

    async def delete_group(self, media_group_id: str) -> None:
        """Delete the message list and the last-message time in one transaction."""
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.delete(self.get_group_key(media_group_id))
            pipe.delete(self.get_last_message_time_key(media_group_id))
            await pipe.execute()

    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """
        Read the stored time of the last added message.

        :return: the value converted to ``float``, or ``None`` when the key is absent
        """
        result = await self.redis.get(self.get_last_message_time_key(media_group_id))
        if result is None:
            return None
        return float(result)

    async def get_current_time(self) -> float:
        """Return the time reported by the Redis server as fractional seconds."""
        seconds, microseconds = cast(tuple[int, int], await self.redis.time())
        return seconds + microseconds / 1e6
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
    """
    Aggregates media groups in the memory of the current process.

    Expired groups are purged whenever a message is added.
    """

    def __init__(self, ttl_sec: int = TTL_SEC) -> None:
        """
        :param ttl_sec: seconds since the last message after which a group is expired
        """
        self.groups: dict[str, list[Message]] = defaultdict(list)
        # Key order is significant: add_into_group re-inserts the key on every message,
        # so iteration goes from the least to the most recently updated group.
        self.last_message_timers: dict[str, float] = {}
        self.locks: dict[str, bool] = {}
        self.ttl_sec = ttl_sec

    def remove_expired_objects(self, current_time: float | None = None) -> None:
        """
        Remove messages, timer and lock of every group older than ``ttl_sec``.

        :param current_time: reference time; ``time.time()`` is used when omitted
        """
        if current_time is None:
            current_time = time.time()
        expired_group_ids = []
        for group_id, last_message_time in self.last_message_timers.items():
            if current_time - last_message_time <= self.ttl_sec:
                # Timers are kept in ascending order (dicts preserve insertion order
                # and keys are re-inserted on update), so nothing later is expired.
                break
            expired_group_ids.append(group_id)
        # Removal happens in a second pass: the dict must not change while iterated.
        for group_id in expired_group_ids:
            self.groups.pop(group_id, None)
            self.last_message_timers.pop(group_id, None)
            self.locks.pop(group_id, None)

    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """
        Add ``media`` to the group unless a message with the same id is already stored.

        The last-message time of the group is refreshed in either case.

        :return: number of messages stored for the group
        """
        # Use the aggregator clock, not time.time() directly: the middleware compares
        # get_last_message_time() with get_current_time(), so both must share a source.
        current_time = await self.get_current_time()
        self.remove_expired_objects(current_time)
        group = self.groups[media_group_id]
        if all(msg.message_id != media.message_id for msg in group):
            group.append(media)
        # pop + assign moves the key to the end, keeping the dict ordered by time.
        self.last_message_timers.pop(media_group_id, None)
        self.last_message_timers[media_group_id] = current_time
        return len(group)

    async def acquire_lock(self, media_group_id: str) -> bool:
        """
        Mark the group as locked.

        :return: ``False`` when it is already locked, ``True`` otherwise
        """
        if self.locks.get(media_group_id):
            return False
        self.locks[media_group_id] = True
        return True

    async def release_lock(self, media_group_id: str) -> None:
        """Remove the lock mark of the group, if any."""
        self.locks.pop(media_group_id, None)

    async def get_group(self, media_group_id: str) -> list[Message]:
        """Return the stored messages of the group, or an empty list."""
        # .get() instead of indexing: avoids creating an entry in the defaultdict.
        return self.groups.get(media_group_id, [])

    async def delete_group(self, media_group_id: str) -> None:
        """Remove messages and last-message time of the group; the lock is kept."""
        self.groups.pop(media_group_id, None)
        self.last_message_timers.pop(media_group_id, None)

    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """Return the time of the last added message, or ``None`` if unknown."""
        return self.last_message_timers.get(media_group_id)
class MediaGroupAggregatorMiddleware(BaseMiddleware):
    """
    Collects messages sharing a ``media_group_id`` and calls the handler once per group.

    The handler receives the message with the lowest ``message_id`` as event and the
    whole group, sorted by ``message_id``, in ``data["album"]``.
    """

    def __init__(
        self,
        media_group_aggregator: BaseMediaGroupAggregator | None = None,
        delay: float = DELAY_SEC,
    ) -> None:
        """
        :param media_group_aggregator: storage for groups and locks; a
            ``MemoryMediaGroupAggregator`` is created when omitted
        :param delay: delay between last received message in media group and processing it
        """
        self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator()
        self.delay = delay

    async def __call__(
        self,
        handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
        event: TelegramObject,
        data: dict[str, Any],
    ) -> Any:
        """
        Pass non-album events through; aggregate album messages before handling.

        :return: the handler result, or ``None`` when this call did not obtain the
            group lock or the group data was gone by the time it was read
        """
        if not isinstance(event, Message) or not event.media_group_id:
            return await handler(event, data)

        aggregator = self.media_group_aggregator
        media_group_id = event.media_group_id

        await aggregator.add_into_group(media_group_id, event)
        # Only the call that obtains the lock goes on to wait and run the handler;
        # every other call just contributed its message above.
        if not await aggregator.acquire_lock(media_group_id):
            return None
        try:
            last_message_time = await aggregator.get_current_time()
            while True:
                elapsed = (await aggregator.get_current_time()) - last_message_time
                delta = self.delay - elapsed
                if delta <= 0:
                    album = await aggregator.get_group(media_group_id)
                    if not album:
                        return None
                    album = sorted(
                        (msg.as_(cast(Bot, data.get("bot"))) for msg in album),
                        key=lambda msg: msg.message_id,
                    )
                    data.update(album=album)
                    result = await handler(album[0], data)
                    # Reached only when the handler returned normally; if it raised,
                    # the stored group is left in place.
                    await aggregator.delete_group(media_group_id)
                    return result
                await asyncio.sleep(delta)
                # A message added while sleeping moves the last-message time forward,
                # which makes the next iteration wait again.
                new_last_message_time = await aggregator.get_last_message_time(media_group_id)
                # Explicit None check: a timestamp of 0.0 is a valid value, not "missing".
                if new_last_message_time is None:
                    return None
                last_message_time = new_last_message_time
        finally:
            await aggregator.release_lock(media_group_id)