mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
233 lines
8.8 KiB
Python
233 lines
8.8 KiB
Python
import asyncio
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import TYPE_CHECKING, Any, cast
|
|
|
|
from aiogram import Bot
|
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
|
from aiogram.types import Message, TelegramObject
|
|
|
|
if TYPE_CHECKING:
|
|
from redis.asyncio.client import Redis
|
|
|
|
# Default quiet period (seconds) the middleware waits after the most recent
# message of a media group before handing the collected album to the handler.
DELAY_SEC: float = 1.0
# Default expiry (seconds) of the per-group processing lock kept in Redis.
LOCK_TTL_SEC: int = 30
# Default expiry (seconds) of stored media-group data.
TTL_SEC: int = 600
|
|
|
|
|
|
class BaseMediaGroupAggregator(ABC):
    """
    Abstract storage backend that collects the messages of media groups.

    A backend stores incoming messages per ``media_group_id``, remembers when the
    latest one arrived and offers a simple per-group processing lock.
    """

    @abstractmethod
    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """
        Store ``media`` in the group ``media_group_id``.

        :return: the backend-reported size of the group after the insert
        """
        raise NotImplementedError

    @abstractmethod
    async def acquire_lock(self, media_group_id: str) -> bool:
        """
        Try to take the processing lock of the group.

        :return: truthy if the lock was taken, falsy if it is already held
        """
        raise NotImplementedError

    @abstractmethod
    async def release_lock(self, media_group_id: str) -> None:
        """Release the processing lock of the group."""
        raise NotImplementedError

    @abstractmethod
    async def get_group(self, media_group_id: str) -> list[Message]:
        """Return the messages collected so far for the group."""
        raise NotImplementedError

    @abstractmethod
    async def delete_group(self, media_group_id: str) -> None:
        """Forget the stored messages and the last-message timestamp of the group."""
        raise NotImplementedError

    @abstractmethod
    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """
        Return the timestamp recorded for the latest message of the group,
        or ``None`` when nothing is recorded.
        """
        raise NotImplementedError

    @staticmethod
    def deduplicate_messages(messages: list[Message]) -> list[Message]:
        """
        Drop messages whose ``message_id`` was already seen.

        The first occurrence of each ``message_id`` is kept and the input order
        is preserved.
        """
        seen_ids = set()
        unique = []
        for candidate in messages:
            if candidate.message_id not in seen_ids:
                seen_ids.add(candidate.message_id)
                unique.append(candidate)
        return unique

    async def get_current_time(self) -> float:
        """
        Return the current time in seconds (local ``time.time()`` by default).

        The middleware compares this value with ``get_last_message_time``, so a
        backend that timestamps messages with another clock should override it.
        """
        return time.time()
|
|
|
|
|
|
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
    """
    Aggregates media groups in Redis.

    Three keys are used per media group: a list holding the JSON-serialized
    messages, a string holding the timestamp of the last received message, and
    a lock key.
    """

    redis: "Redis"

    def __init__(
        self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
    ) -> None:
        """
        :param redis: asyncio Redis client used for storage
        :param ttl_sec: ttl for media group data in seconds
        :param lock_ttl_sec: ttl for the processing lock in seconds. The value should be
            large enough that the lock does not expire while the handler is still
            running, yet small enough that it expires before Telegram retries the
            update if processing failed without releasing it.
        """
        self.redis = redis
        self.ttl_sec = ttl_sec
        self.lock_ttl_sec = lock_ttl_sec

    @staticmethod
    def get_group_key(media_group_id: str) -> str:
        """Key of the list holding the serialized messages of the group."""
        return f"media_group:{media_group_id}:album"

    @staticmethod
    def get_last_message_time_key(media_group_id: str) -> str:
        """Key holding the timestamp of the latest message of the group."""
        return f"media_group:{media_group_id}:last_message_time"

    @staticmethod
    def get_group_lock_key(media_group_id: str) -> str:
        """Key used as the processing lock of the group."""
        return f"media_group:{media_group_id}:lock"

    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """
        Append ``media`` to the group and refresh its last-message timestamp.

        Both keys get ``ttl_sec`` expiry; all commands run in one transaction.

        :return: length of the group list after the push (duplicates included)
        """
        current_time = await self.get_current_time()
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec)
            pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
            pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
            res = await pipe.execute()
        # res[1] is the RPUSH reply, i.e. the new list length.
        return cast(int, res[1])

    async def acquire_lock(self, media_group_id: str) -> bool:
        """
        Try to create the lock key (SET NX) with ``lock_ttl_sec`` expiry.

        :return: ``True`` if the lock was created, ``False`` if it already exists
        """
        reply = await self.redis.set(
            self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
        )
        # SET NX replies with nil when the key exists; typing.cast would not
        # convert that, so coerce explicitly to honour the ``bool`` contract.
        return bool(reply)

    async def release_lock(self, media_group_id: str) -> None:
        """
        Delete the lock key of the group.

        NOTE(review): the key is deleted unconditionally, without checking that this
        caller still owns it. If the lock already expired and was re-acquired
        elsewhere, that other holder's lock is removed as well.
        """
        await self.redis.delete(self.get_group_lock_key(media_group_id))

    async def get_group(self, media_group_id: str) -> list[Message]:
        """
        Load and deserialize the stored messages of the group.

        Identical payloads are dropped first, then messages repeating a
        ``message_id``. Storage order is preserved, so the result is deterministic,
        and so is the choice of which copy of a repeated message is kept.
        """
        raw_messages = await cast(
            Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
        )
        # dict.fromkeys de-duplicates while keeping insertion order; a set would
        # produce an arbitrary, hash-seed dependent order.
        unique_payloads = dict.fromkeys(raw_messages)
        return self.deduplicate_messages(
            [Message.model_validate_json(payload) for payload in unique_payloads]
        )

    async def delete_group(self, media_group_id: str) -> None:
        """Delete the message list and last-message timestamp of the group (lock untouched)."""
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.delete(self.get_group_key(media_group_id))
            pipe.delete(self.get_last_message_time_key(media_group_id))
            await pipe.execute()

    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """Return the stored last-message timestamp, or ``None`` if the key is missing."""
        result = await self.redis.get(self.get_last_message_time_key(media_group_id))
        if result is None:
            return None
        return float(result)

    async def get_current_time(self) -> float:
        """
        Return the Redis server time in seconds.

        Using the server clock keeps these values on the same clock as the
        timestamps written by ``add_into_group``, regardless of which client
        wrote them.
        """
        seconds, microseconds = cast(tuple[int, int], await self.redis.time())
        return seconds + microseconds / 1e6
|
|
|
|
|
|
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
    """
    Aggregates media groups in the memory of the current process.

    Groups whose last message is older than ``ttl_sec`` are purged lazily, each
    time a new message is added.
    """

    def __init__(self, ttl_sec: int = TTL_SEC) -> None:
        """
        :param ttl_sec: ttl for media group data in seconds
        """
        self.groups: dict[str, list[Message]] = defaultdict(list)
        # add_into_group re-inserts a group's key on every message, so insertion
        # order here is also chronological order (oldest group first).
        self.last_message_timers: dict[str, float] = {}
        self.locks: dict[str, bool] = {}
        self.ttl_sec = ttl_sec

    def remove_expired_objects(self) -> None:
        """Purge messages, timestamps and locks of groups idle for more than ``ttl_sec``."""
        expired_group_ids = []
        current_time = time.time()
        for group_id, last_message_time in self.last_message_timers.items():
            if current_time - last_message_time > self.ttl_sec:
                expired_group_ids.append(group_id)
            else:
                # Dicts keep insertion order (Python 3.7+) and timestamps are inserted
                # in ascending order (assuming the wall clock does not go backwards),
                # so every remaining entry is newer and can be skipped.
                break
        for group_id in expired_group_ids:
            self.groups.pop(group_id, None)
            self.last_message_timers.pop(group_id, None)
            self.locks.pop(group_id, None)

    async def add_into_group(self, media_group_id: str, media: Message) -> int:
        """
        Store ``media`` unless a message with the same ``message_id`` is already
        in the group, and refresh the group's last-message timestamp.

        :return: number of distinct messages in the group
        """
        self.remove_expired_objects()
        if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]):
            self.groups[media_group_id].append(media)
        # Pop before re-inserting so the group moves to the end of the dict,
        # keeping last_message_timers ordered by time.
        self.last_message_timers.pop(media_group_id, None)
        self.last_message_timers[media_group_id] = time.time()
        return len(self.groups[media_group_id])

    async def acquire_lock(self, media_group_id: str) -> bool:
        """
        Take the group lock if it is not held.

        The lock has no TTL of its own. It is dropped by ``release_lock`` or when the
        group expires.

        :return: ``True`` if the lock was taken, ``False`` if it is already held
        """
        if self.locks.get(media_group_id):
            return False
        self.locks[media_group_id] = True
        return True

    async def release_lock(self, media_group_id: str) -> None:
        """Release the group lock (no-op if it is not held)."""
        self.locks.pop(media_group_id, None)

    async def get_group(self, media_group_id: str) -> list[Message]:
        """
        Return the messages collected for the group, in insertion order.

        A copy is returned so callers cannot mutate the aggregator's internal
        state, matching the Redis backend, which always builds a fresh list.
        """
        return list(self.groups.get(media_group_id, []))

    async def delete_group(self, media_group_id: str) -> None:
        """Forget the messages and last-message timestamp of the group (lock untouched)."""
        self.groups.pop(media_group_id, None)
        self.last_message_timers.pop(media_group_id, None)

    async def get_last_message_time(self, media_group_id: str) -> float | None:
        """Return the ``time.time()`` value of the group's latest message, or ``None``."""
        return self.last_message_timers.get(media_group_id)
|
|
|
|
|
|
class MediaGroupAggregatorMiddleware(BaseMiddleware):
    """
    Collect the messages of a media group and call the handler once per album.

    Every message of a group is stored in the aggregator. The update that takes
    the group lock waits until no new message has arrived for ``delay`` seconds.
    It then calls the handler with the message that has the lowest ``message_id``
    and puts the whole sorted album into ``data["album"]``. Other updates of the
    same group return ``None`` without reaching the handler.
    """

    def __init__(
        self,
        media_group_aggregator: BaseMediaGroupAggregator | None = None,
        delay: float = DELAY_SEC,
    ) -> None:
        """
        :param media_group_aggregator: storage for media groups; an in-memory
            aggregator is created when none is given
        :param delay: delay between last received message in media group and processing it
        """
        self.delay = delay
        self.media_group_aggregator = (
            media_group_aggregator if media_group_aggregator else MemoryMediaGroupAggregator()
        )

    async def __call__(
        self,
        handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
        event: TelegramObject,
        data: dict[str, Any],
    ) -> Any:
        # Anything that is not a media-group message goes straight to the handler.
        if not (isinstance(event, Message) and event.media_group_id):
            return await handler(event, data)

        aggregator = self.media_group_aggregator
        group_id = event.media_group_id

        await aggregator.add_into_group(group_id, event)
        # Only the lock holder processes the album; every other update has
        # already contributed its message and stops here.
        if not await aggregator.acquire_lock(group_id):
            return None

        try:
            reference_time = await aggregator.get_current_time()
            # Sleep until `delay` seconds pass without a newer message in the group.
            while True:
                elapsed = await aggregator.get_current_time() - reference_time
                remaining = self.delay - elapsed
                if remaining <= 0:
                    break
                await asyncio.sleep(remaining)
                latest_time = await aggregator.get_last_message_time(group_id)
                if not latest_time:
                    # No timestamp left for the group (e.g. it expired or was deleted).
                    return None
                reference_time = latest_time

            stored = await aggregator.get_group(group_id)
            if not stored:
                return None
            bot = cast(Bot, data.get("bot"))
            album = sorted((item.as_(bot) for item in stored), key=lambda item: item.message_id)
            data["album"] = album
            outcome = await handler(album[0], data)
            # The stored group is only removed once the handler has returned normally.
            await aggregator.delete_group(group_id)
            return outcome
        finally:
            await aggregator.release_lock(group_id)
|