mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
update lock TTL
This commit is contained in:
parent
b7f2da391b
commit
10fa7315da
5 changed files with 62 additions and 13 deletions
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from redis.asyncio.client import Redis
|
||||
|
||||
DELAY_SEC = 1.0
|
||||
TIMEOUT_SEC = 10
|
||||
LOCK_TTL_SEC = 30
|
||||
TTL_SEC = 600
|
||||
|
||||
|
||||
|
|
@ -53,13 +53,29 @@ class BaseMediaGroupAggregator(ABC):
|
|||
message_ids.add(message.message_id)
|
||||
return result
|
||||
|
||||
async def get_current_time(self) -> float:
|
||||
return time.time()
|
||||
|
||||
|
||||
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||
"""
|
||||
Aggregates media groups in Redis.
|
||||
"""
|
||||
|
||||
redis: "Redis"
|
||||
|
||||
def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None:
|
||||
def __init__(
|
||||
self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
|
||||
) -> None:
|
||||
"""
|
||||
:param ttl_sec: ttl for media group data in seconds
|
||||
:param lock_ttl_sec: ttl for lock in seconds. Value should be too big to prevent lock
|
||||
releasing until handler finished and too small to expire until telegram send retry if
|
||||
handler failed.
|
||||
"""
|
||||
self.redis = redis
|
||||
self.ttl_sec = ttl_sec
|
||||
self.lock_ttl_sec = lock_ttl_sec
|
||||
|
||||
@staticmethod
|
||||
def get_group_key(media_group_id: str) -> str:
|
||||
|
|
@ -74,8 +90,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
return f"media_group:{media_group_id}:lock"
|
||||
|
||||
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||
current_time = await self.get_current_time()
|
||||
async with self.redis.pipeline(transaction=True) as pipe:
|
||||
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec)
|
||||
pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec)
|
||||
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
|
||||
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
|
||||
res = await pipe.execute()
|
||||
|
|
@ -85,7 +102,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
return cast(
|
||||
bool,
|
||||
await self.redis.set(
|
||||
self.get_group_lock_key(media_group_id), "1", nx=True, ex=TIMEOUT_SEC
|
||||
self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -96,7 +113,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
result = await cast(
|
||||
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
|
||||
)
|
||||
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
|
||||
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)])
|
||||
|
||||
async def delete_group(self, media_group_id: str) -> None:
|
||||
async with self.redis.pipeline(transaction=True) as pipe:
|
||||
|
|
@ -110,6 +127,10 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
return None
|
||||
return float(result)
|
||||
|
||||
async def get_current_time(self) -> float:
|
||||
seconds, microseconds = cast(tuple[int, int], await self.redis.time())
|
||||
return seconds + microseconds / 1e6
|
||||
|
||||
|
||||
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
|
||||
|
|
@ -184,9 +205,11 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
|||
if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
|
||||
return None
|
||||
try:
|
||||
last_message_time = time.time()
|
||||
last_message_time = await self.media_group_aggregator.get_current_time()
|
||||
while True:
|
||||
delta = self.delay - (time.time() - last_message_time)
|
||||
delta = self.delay - (
|
||||
await self.media_group_aggregator.get_current_time() - last_message_time
|
||||
)
|
||||
if delta <= 0:
|
||||
album = await self.media_group_aggregator.get_group(event.media_group_id)
|
||||
if not album:
|
||||
|
|
|
|||
|
|
@ -29,8 +29,10 @@ class MediaGroupFilter(Filter):
|
|||
:param max_media_count: max count of media in the group, inclusively
|
||||
"""
|
||||
if count is None:
|
||||
min_media_count = min_media_count or MIN_MEDIA_COUNT
|
||||
max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT
|
||||
if min_media_count is None:
|
||||
min_media_count = MIN_MEDIA_COUNT
|
||||
if max_media_count is None:
|
||||
max_media_count = max(DEFAULT_MAX_MEDIA_COUNT, min_media_count)
|
||||
else:
|
||||
if min_media_count is not None or max_media_count is not None:
|
||||
raise ValueError(
|
||||
|
|
|
|||
|
|
@ -51,7 +51,10 @@ By default each media in the group is processed separately.
|
|||
|
||||
You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware`
|
||||
to process media groups as one. If you do, only one message from the group will be processed, and updates for
|
||||
other messages with the same media group ID will be suppressed.
|
||||
other messages with the same media group ID will be suppressed. There are two options to store media groups:
|
||||
|
||||
- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default
|
||||
- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - support distributed environment
|
||||
|
||||
You also can use :class:`aiogram.filters.media_group.MediaGroupFilter`
|
||||
to filter media groups.
|
||||
|
|
@ -72,7 +75,7 @@ Usage
|
|||
|
||||
# use middleware
|
||||
@router.message(
|
||||
MediaGroupFilter(max_count=5),
|
||||
MediaGroupFilter(max_media_count=5),
|
||||
F.caption == "album_caption" # other filters will be applied to the first message in the group
|
||||
)
|
||||
async def start(message: Message, album: list[Message]):
|
||||
|
|
@ -93,3 +96,8 @@ References
|
|||
:members:
|
||||
.. autoclass:: aiogram.filters.media_group.MediaGroupFilter
|
||||
:members:
|
||||
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator
|
||||
:members:
|
||||
.. autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator
|
||||
:members:
|
||||
:special-members: __init__
|
||||
|
|
|
|||
|
|
@ -96,7 +96,8 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
assert isinstance(first_message, Message)
|
||||
assert first_message.message_id == 1
|
||||
|
||||
async def test_skip_propagating_if_data_deleted(self):
|
||||
@pytest.mark.parametrize("deleted_object", ["album", "last_message_time"])
|
||||
async def test_skip_propagating_if_data_deleted(self, deleted_object):
|
||||
middleware = self.get_middleware()
|
||||
counter = 0
|
||||
|
||||
|
|
@ -107,7 +108,10 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
task1 = await wait_until_func_call_sleep(
|
||||
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||
)
|
||||
await middleware.media_group_aggregator.delete_group("42")
|
||||
if deleted_object == "album":
|
||||
middleware.media_group_aggregator.groups.pop("42")
|
||||
else:
|
||||
middleware.media_group_aggregator.last_message_timers.pop("42")
|
||||
await task1
|
||||
assert counter == 0
|
||||
|
||||
|
|
@ -199,6 +203,16 @@ class TestMediaGroupAggregator:
|
|||
assert await aggregator.get_group("42") == []
|
||||
assert await aggregator.get_group("24") == [new_msg]
|
||||
|
||||
async def test_get_current_time_memory_aggregator(self):
|
||||
aggregator = MemoryMediaGroupAggregator()
|
||||
with mock.patch("time.time", return_value=1.1):
|
||||
assert await aggregator.get_current_time() == 1.1
|
||||
|
||||
async def test_get_current_time_redis_aggregator(self):
|
||||
aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis))
|
||||
aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456))
|
||||
assert await aggregator.get_current_time() == 1.123456
|
||||
|
||||
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
|
||||
assert await aggregator.get_last_message_time("42") is None
|
||||
await aggregator.add_into_group("42", _get_message(1))
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ class TestMediaGroupFilter:
|
|||
[
|
||||
((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT),
|
||||
((3,), 3, 3),
|
||||
((11,), 11, 11),
|
||||
((None, 11, None), 11, 11),
|
||||
((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT),
|
||||
((None, None, 3), MIN_MEDIA_COUNT, 3),
|
||||
],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue