update lock TTL

This commit is contained in:
Vitaly312 2026-03-01 11:21:44 +03:00
parent b7f2da391b
commit 10fa7315da
5 changed files with 62 additions and 13 deletions

View file

@ -13,7 +13,7 @@ if TYPE_CHECKING:
from redis.asyncio.client import Redis from redis.asyncio.client import Redis
DELAY_SEC = 1.0 DELAY_SEC = 1.0
TIMEOUT_SEC = 10 LOCK_TTL_SEC = 30
TTL_SEC = 600 TTL_SEC = 600
@ -53,13 +53,29 @@ class BaseMediaGroupAggregator(ABC):
message_ids.add(message.message_id) message_ids.add(message.message_id)
return result return result
async def get_current_time(self) -> float:
return time.time()
class RedisMediaGroupAggregator(BaseMediaGroupAggregator): class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
"""
Aggregates media groups in Redis.
"""
redis: "Redis" redis: "Redis"
def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None: def __init__(
self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
) -> None:
"""
:param ttl_sec: ttl for media group data in seconds
:param lock_ttl_sec: ttl for lock in seconds. Value should be too big to prevent lock
releasing until handler finished and too small to expire until telegram send retry if
handler failed.
"""
self.redis = redis self.redis = redis
self.ttl_sec = ttl_sec self.ttl_sec = ttl_sec
self.lock_ttl_sec = lock_ttl_sec
@staticmethod @staticmethod
def get_group_key(media_group_id: str) -> str: def get_group_key(media_group_id: str) -> str:
@ -74,8 +90,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
return f"media_group:{media_group_id}:lock" return f"media_group:{media_group_id}:lock"
async def add_into_group(self, media_group_id: str, media: Message) -> int: async def add_into_group(self, media_group_id: str, media: Message) -> int:
current_time = await self.get_current_time()
async with self.redis.pipeline(transaction=True) as pipe: async with self.redis.pipeline(transaction=True) as pipe:
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec) pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec)
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json()) pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec) pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
res = await pipe.execute() res = await pipe.execute()
@ -85,7 +102,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
return cast( return cast(
bool, bool,
await self.redis.set( await self.redis.set(
self.get_group_lock_key(media_group_id), "1", nx=True, ex=TIMEOUT_SEC self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
), ),
) )
@ -96,7 +113,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
result = await cast( result = await cast(
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1) Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
) )
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result]) return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)])
async def delete_group(self, media_group_id: str) -> None: async def delete_group(self, media_group_id: str) -> None:
async with self.redis.pipeline(transaction=True) as pipe: async with self.redis.pipeline(transaction=True) as pipe:
@ -110,6 +127,10 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
return None return None
return float(result) return float(result)
async def get_current_time(self) -> float:
seconds, microseconds = cast(tuple[int, int], await self.redis.time())
return seconds + microseconds / 1e6
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
def __init__(self, ttl_sec: int = TTL_SEC) -> None: def __init__(self, ttl_sec: int = TTL_SEC) -> None:
@ -184,9 +205,11 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
if not await self.media_group_aggregator.acquire_lock(event.media_group_id): if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
return None return None
try: try:
last_message_time = time.time() last_message_time = await self.media_group_aggregator.get_current_time()
while True: while True:
delta = self.delay - (time.time() - last_message_time) delta = self.delay - (
await self.media_group_aggregator.get_current_time() - last_message_time
)
if delta <= 0: if delta <= 0:
album = await self.media_group_aggregator.get_group(event.media_group_id) album = await self.media_group_aggregator.get_group(event.media_group_id)
if not album: if not album:

View file

@ -29,8 +29,10 @@ class MediaGroupFilter(Filter):
:param max_media_count: max count of media in the group, inclusively :param max_media_count: max count of media in the group, inclusively
""" """
if count is None: if count is None:
min_media_count = min_media_count or MIN_MEDIA_COUNT if min_media_count is None:
max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT min_media_count = MIN_MEDIA_COUNT
if max_media_count is None:
max_media_count = max(DEFAULT_MAX_MEDIA_COUNT, min_media_count)
else: else:
if min_media_count is not None or max_media_count is not None: if min_media_count is not None or max_media_count is not None:
raise ValueError( raise ValueError(

View file

@ -51,7 +51,10 @@ By default each media in the group is processed separately.
You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware` You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware`
to process media groups as one. If you do, only one message from the group will be processed, and updates for to process media groups as one. If you do, only one message from the group will be processed, and updates for
other messages with the same media group ID will be suppressed. other messages with the same media group ID will be suppressed. There are two options to store media groups:
- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default
- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - support distributed environment
You also can use :class:`aiogram.filters.media_group.MediaGroupFilter` You also can use :class:`aiogram.filters.media_group.MediaGroupFilter`
to filter media groups. to filter media groups.
@ -72,7 +75,7 @@ Usage
# use middleware # use middleware
@router.message( @router.message(
MediaGroupFilter(max_count=5), MediaGroupFilter(max_media_count=5),
F.caption == "album_caption" # other filters will be applied to the first message in the group F.caption == "album_caption" # other filters will be applied to the first message in the group
) )
async def start(message: Message, album: list[Message]): async def start(message: Message, album: list[Message]):
@ -93,3 +96,8 @@ References
:members: :members:
.. autoclass:: aiogram.filters.media_group.MediaGroupFilter .. autoclass:: aiogram.filters.media_group.MediaGroupFilter
:members: :members:
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator
:members:
.. autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator
:members:
:special-members: __init__

View file

@ -96,7 +96,8 @@ class TestMediaGroupAggregatorMiddleware:
assert isinstance(first_message, Message) assert isinstance(first_message, Message)
assert first_message.message_id == 1 assert first_message.message_id == 1
async def test_skip_propagating_if_data_deleted(self): @pytest.mark.parametrize("deleted_object", ["album", "last_message_time"])
async def test_skip_propagating_if_data_deleted(self, deleted_object):
middleware = self.get_middleware() middleware = self.get_middleware()
counter = 0 counter = 0
@ -107,7 +108,10 @@ class TestMediaGroupAggregatorMiddleware:
task1 = await wait_until_func_call_sleep( task1 = await wait_until_func_call_sleep(
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {}) asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
) )
await middleware.media_group_aggregator.delete_group("42") if deleted_object == "album":
middleware.media_group_aggregator.groups.pop("42")
else:
middleware.media_group_aggregator.last_message_timers.pop("42")
await task1 await task1
assert counter == 0 assert counter == 0
@ -199,6 +203,16 @@ class TestMediaGroupAggregator:
assert await aggregator.get_group("42") == [] assert await aggregator.get_group("42") == []
assert await aggregator.get_group("24") == [new_msg] assert await aggregator.get_group("24") == [new_msg]
async def test_get_current_time_memory_aggregator(self):
aggregator = MemoryMediaGroupAggregator()
with mock.patch("time.time", return_value=1.1):
assert await aggregator.get_current_time() == 1.1
async def test_get_current_time_redis_aggregator(self):
aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis))
aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456))
assert await aggregator.get_current_time() == 1.123456
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator): async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
assert await aggregator.get_last_message_time("42") is None assert await aggregator.get_last_message_time("42") is None
await aggregator.add_into_group("42", _get_message(1)) await aggregator.add_into_group("42", _get_message(1))

View file

@ -12,6 +12,8 @@ class TestMediaGroupFilter:
[ [
((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT), ((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT),
((3,), 3, 3), ((3,), 3, 3),
((11,), 11, 11),
((None, 11, None), 11, 11),
((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT), ((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT),
((None, None, 3), MIN_MEDIA_COUNT, 3), ((None, None, 3), MIN_MEDIA_COUNT, 3),
], ],