update lock TTL

This commit is contained in:
Vitaly312 2026-03-01 11:21:44 +03:00
parent b7f2da391b
commit 10fa7315da
5 changed files with 62 additions and 13 deletions

View file

@ -13,7 +13,7 @@ if TYPE_CHECKING:
from redis.asyncio.client import Redis
DELAY_SEC = 1.0
TIMEOUT_SEC = 10
LOCK_TTL_SEC = 30
TTL_SEC = 600
@ -53,13 +53,29 @@ class BaseMediaGroupAggregator(ABC):
message_ids.add(message.message_id)
return result
async def get_current_time(self) -> float:
return time.time()
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
"""
Aggregates media groups in Redis.
"""
redis: "Redis"
def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None:
def __init__(
self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
) -> None:
"""
:param ttl_sec: ttl for media group data in seconds
        :param lock_ttl_sec: ttl for the lock in seconds. The value should be large enough that
            the lock is not released before the handler finishes, and small enough that it expires
            before Telegram retries sending the update if the handler failed.
"""
self.redis = redis
self.ttl_sec = ttl_sec
self.lock_ttl_sec = lock_ttl_sec
@staticmethod
def get_group_key(media_group_id: str) -> str:
@ -74,8 +90,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
return f"media_group:{media_group_id}:lock"
async def add_into_group(self, media_group_id: str, media: Message) -> int:
current_time = await self.get_current_time()
async with self.redis.pipeline(transaction=True) as pipe:
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec)
pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec)
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
res = await pipe.execute()
@ -85,7 +102,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
return cast(
bool,
await self.redis.set(
self.get_group_lock_key(media_group_id), "1", nx=True, ex=TIMEOUT_SEC
self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
),
)
@ -96,7 +113,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
result = await cast(
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
)
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)])
async def delete_group(self, media_group_id: str) -> None:
async with self.redis.pipeline(transaction=True) as pipe:
@ -110,6 +127,10 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
return None
return float(result)
async def get_current_time(self) -> float:
seconds, microseconds = cast(tuple[int, int], await self.redis.time())
return seconds + microseconds / 1e6
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
@ -184,9 +205,11 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
return None
try:
last_message_time = time.time()
last_message_time = await self.media_group_aggregator.get_current_time()
while True:
delta = self.delay - (time.time() - last_message_time)
delta = self.delay - (
await self.media_group_aggregator.get_current_time() - last_message_time
)
if delta <= 0:
album = await self.media_group_aggregator.get_group(event.media_group_id)
if not album:

View file

@ -29,8 +29,10 @@ class MediaGroupFilter(Filter):
:param max_media_count: max count of media in the group, inclusively
"""
if count is None:
min_media_count = min_media_count or MIN_MEDIA_COUNT
max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT
if min_media_count is None:
min_media_count = MIN_MEDIA_COUNT
if max_media_count is None:
max_media_count = max(DEFAULT_MAX_MEDIA_COUNT, min_media_count)
else:
if min_media_count is not None or max_media_count is not None:
raise ValueError(

View file

@ -51,7 +51,10 @@ By default each media in the group is processed separately.
You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware`
to process media groups as one. If you do, only one message from the group will be processed, and updates for
other messages with the same media group ID will be suppressed.
other messages with the same media group ID will be suppressed. There are two options to store media groups:
- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default
- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - supports distributed environments
You also can use :class:`aiogram.filters.media_group.MediaGroupFilter`
to filter media groups.
@ -72,7 +75,7 @@ Usage
# use middleware
@router.message(
MediaGroupFilter(max_count=5),
MediaGroupFilter(max_media_count=5),
F.caption == "album_caption" # other filters will be applied to the first message in the group
)
async def start(message: Message, album: list[Message]):
@ -93,3 +96,8 @@ References
:members:
.. autoclass:: aiogram.filters.media_group.MediaGroupFilter
:members:
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator
:members:
.. autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator
:members:
:special-members: __init__

View file

@ -96,7 +96,8 @@ class TestMediaGroupAggregatorMiddleware:
assert isinstance(first_message, Message)
assert first_message.message_id == 1
async def test_skip_propagating_if_data_deleted(self):
@pytest.mark.parametrize("deleted_object", ["album", "last_message_time"])
async def test_skip_propagating_if_data_deleted(self, deleted_object):
middleware = self.get_middleware()
counter = 0
@ -107,7 +108,10 @@ class TestMediaGroupAggregatorMiddleware:
task1 = await wait_until_func_call_sleep(
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
)
await middleware.media_group_aggregator.delete_group("42")
if deleted_object == "album":
middleware.media_group_aggregator.groups.pop("42")
else:
middleware.media_group_aggregator.last_message_timers.pop("42")
await task1
assert counter == 0
@ -199,6 +203,16 @@ class TestMediaGroupAggregator:
assert await aggregator.get_group("42") == []
assert await aggregator.get_group("24") == [new_msg]
async def test_get_current_time_memory_aggregator(self):
aggregator = MemoryMediaGroupAggregator()
with mock.patch("time.time", return_value=1.1):
assert await aggregator.get_current_time() == 1.1
async def test_get_current_time_redis_aggregator(self):
aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis))
aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456))
assert await aggregator.get_current_time() == 1.123456
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
assert await aggregator.get_last_message_time("42") is None
await aggregator.add_into_group("42", _get_message(1))

View file

@ -12,6 +12,8 @@ class TestMediaGroupFilter:
[
((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT),
((3,), 3, 3),
((11,), 11, 11),
((None, 11, None), 11, 11),
((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT),
((None, None, 3), MIN_MEDIA_COUNT, 3),
],