mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
feat: add TTL to MemoryMediaGroupAggregator
This commit is contained in:
parent
80db313c16
commit
998d6c3742
1 changed files with 22 additions and 1 deletions
|
|
@ -91,7 +91,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
await self.redis.delete(self.get_group_lock_key(media_group_id))
|
||||
|
||||
async def get_group(self, media_group_id: str) -> list[Message]:
|
||||
result = await self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
|
||||
result = await cast(
|
||||
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
|
||||
)
|
||||
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
|
||||
|
||||
async def delete_group(self, media_group_id: str) -> None:
|
||||
|
|
@ -113,9 +115,25 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
self.last_message_timers: dict[str, float] = {}
|
||||
self.locks: dict[str, bool] = {}
|
||||
|
||||
def remove_expired_objects(self) -> None:
|
||||
expired_group_ids = []
|
||||
current_time = time.time()
|
||||
for group_id, last_message_time in self.last_message_timers.items():
|
||||
if current_time - last_message_time > TTL_SEC:
|
||||
expired_group_ids.append(group_id)
|
||||
else:
|
||||
break # the list is sorted in ascending order
|
||||
# because python 3.7+ save dict in insertion order
|
||||
for group_id in expired_group_ids:
|
||||
self.groups.pop(group_id, None)
|
||||
self.last_message_timers.pop(group_id, None)
|
||||
self.locks.pop(group_id, None)
|
||||
|
||||
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||
self.remove_expired_objects()
|
||||
if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]):
|
||||
self.groups[media_group_id].append(media)
|
||||
self.last_message_timers.pop(media_group_id, None)
|
||||
self.last_message_timers[media_group_id] = time.time()
|
||||
return len(self.groups[media_group_id])
|
||||
|
||||
|
|
@ -145,6 +163,9 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
|||
media_group_aggregator: BaseMediaGroupAggregator | None = None,
|
||||
delay: float = DELAY_SEC,
|
||||
) -> None:
|
||||
"""
|
||||
:param delay: delay between last received message in media group and processing it
|
||||
"""
|
||||
self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator()
|
||||
self.delay = delay
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue