mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
update lock TTL
This commit is contained in:
parent
b7f2da391b
commit
10fa7315da
5 changed files with 62 additions and 13 deletions
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||||
from redis.asyncio.client import Redis
|
from redis.asyncio.client import Redis
|
||||||
|
|
||||||
DELAY_SEC = 1.0
|
DELAY_SEC = 1.0
|
||||||
TIMEOUT_SEC = 10
|
LOCK_TTL_SEC = 30
|
||||||
TTL_SEC = 600
|
TTL_SEC = 600
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -53,13 +53,29 @@ class BaseMediaGroupAggregator(ABC):
|
||||||
message_ids.add(message.message_id)
|
message_ids.add(message.message_id)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def get_current_time(self) -> float:
|
||||||
|
return time.time()
|
||||||
|
|
||||||
|
|
||||||
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
|
"""
|
||||||
|
Aggregates media groups in Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
redis: "Redis"
|
redis: "Redis"
|
||||||
|
|
||||||
def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None:
|
def __init__(
|
||||||
|
self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
:param ttl_sec: ttl for media group data in seconds
|
||||||
|
:param lock_ttl_sec: ttl for lock in seconds. Value should be too big to prevent lock
|
||||||
|
releasing until handler finished and too small to expire until telegram send retry if
|
||||||
|
handler failed.
|
||||||
|
"""
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
self.ttl_sec = ttl_sec
|
self.ttl_sec = ttl_sec
|
||||||
|
self.lock_ttl_sec = lock_ttl_sec
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_group_key(media_group_id: str) -> str:
|
def get_group_key(media_group_id: str) -> str:
|
||||||
|
|
@ -74,8 +90,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
return f"media_group:{media_group_id}:lock"
|
return f"media_group:{media_group_id}:lock"
|
||||||
|
|
||||||
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||||
|
current_time = await self.get_current_time()
|
||||||
async with self.redis.pipeline(transaction=True) as pipe:
|
async with self.redis.pipeline(transaction=True) as pipe:
|
||||||
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec)
|
pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec)
|
||||||
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
|
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
|
||||||
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
|
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
|
||||||
res = await pipe.execute()
|
res = await pipe.execute()
|
||||||
|
|
@ -85,7 +102,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
return cast(
|
return cast(
|
||||||
bool,
|
bool,
|
||||||
await self.redis.set(
|
await self.redis.set(
|
||||||
self.get_group_lock_key(media_group_id), "1", nx=True, ex=TIMEOUT_SEC
|
self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -96,7 +113,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
result = await cast(
|
result = await cast(
|
||||||
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
|
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
|
||||||
)
|
)
|
||||||
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
|
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)])
|
||||||
|
|
||||||
async def delete_group(self, media_group_id: str) -> None:
|
async def delete_group(self, media_group_id: str) -> None:
|
||||||
async with self.redis.pipeline(transaction=True) as pipe:
|
async with self.redis.pipeline(transaction=True) as pipe:
|
||||||
|
|
@ -110,6 +127,10 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
return None
|
return None
|
||||||
return float(result)
|
return float(result)
|
||||||
|
|
||||||
|
async def get_current_time(self) -> float:
|
||||||
|
seconds, microseconds = cast(tuple[int, int], await self.redis.time())
|
||||||
|
return seconds + microseconds / 1e6
|
||||||
|
|
||||||
|
|
||||||
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
|
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
|
||||||
|
|
@ -184,9 +205,11 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
||||||
if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
|
if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
last_message_time = time.time()
|
last_message_time = await self.media_group_aggregator.get_current_time()
|
||||||
while True:
|
while True:
|
||||||
delta = self.delay - (time.time() - last_message_time)
|
delta = self.delay - (
|
||||||
|
await self.media_group_aggregator.get_current_time() - last_message_time
|
||||||
|
)
|
||||||
if delta <= 0:
|
if delta <= 0:
|
||||||
album = await self.media_group_aggregator.get_group(event.media_group_id)
|
album = await self.media_group_aggregator.get_group(event.media_group_id)
|
||||||
if not album:
|
if not album:
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,10 @@ class MediaGroupFilter(Filter):
|
||||||
:param max_media_count: max count of media in the group, inclusively
|
:param max_media_count: max count of media in the group, inclusively
|
||||||
"""
|
"""
|
||||||
if count is None:
|
if count is None:
|
||||||
min_media_count = min_media_count or MIN_MEDIA_COUNT
|
if min_media_count is None:
|
||||||
max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT
|
min_media_count = MIN_MEDIA_COUNT
|
||||||
|
if max_media_count is None:
|
||||||
|
max_media_count = max(DEFAULT_MAX_MEDIA_COUNT, min_media_count)
|
||||||
else:
|
else:
|
||||||
if min_media_count is not None or max_media_count is not None:
|
if min_media_count is not None or max_media_count is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,10 @@ By default each media in the group is processed separately.
|
||||||
|
|
||||||
You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware`
|
You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware`
|
||||||
to process media groups as one. If you do, only one message from the group will be processed, and updates for
|
to process media groups as one. If you do, only one message from the group will be processed, and updates for
|
||||||
other messages with the same media group ID will be suppressed.
|
other messages with the same media group ID will be suppressed. There are two options to store media groups:
|
||||||
|
|
||||||
|
- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default
|
||||||
|
- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - support distributed environment
|
||||||
|
|
||||||
You also can use :class:`aiogram.filters.media_group.MediaGroupFilter`
|
You also can use :class:`aiogram.filters.media_group.MediaGroupFilter`
|
||||||
to filter media groups.
|
to filter media groups.
|
||||||
|
|
@ -72,7 +75,7 @@ Usage
|
||||||
|
|
||||||
# use middleware
|
# use middleware
|
||||||
@router.message(
|
@router.message(
|
||||||
MediaGroupFilter(max_count=5),
|
MediaGroupFilter(max_media_count=5),
|
||||||
F.caption == "album_caption" # other filters will be applied to the first message in the group
|
F.caption == "album_caption" # other filters will be applied to the first message in the group
|
||||||
)
|
)
|
||||||
async def start(message: Message, album: list[Message]):
|
async def start(message: Message, album: list[Message]):
|
||||||
|
|
@ -93,3 +96,8 @@ References
|
||||||
:members:
|
:members:
|
||||||
.. autoclass:: aiogram.filters.media_group.MediaGroupFilter
|
.. autoclass:: aiogram.filters.media_group.MediaGroupFilter
|
||||||
:members:
|
:members:
|
||||||
|
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator
|
||||||
|
:members:
|
||||||
|
.. autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator
|
||||||
|
:members:
|
||||||
|
:special-members: __init__
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,8 @@ class TestMediaGroupAggregatorMiddleware:
|
||||||
assert isinstance(first_message, Message)
|
assert isinstance(first_message, Message)
|
||||||
assert first_message.message_id == 1
|
assert first_message.message_id == 1
|
||||||
|
|
||||||
async def test_skip_propagating_if_data_deleted(self):
|
@pytest.mark.parametrize("deleted_object", ["album", "last_message_time"])
|
||||||
|
async def test_skip_propagating_if_data_deleted(self, deleted_object):
|
||||||
middleware = self.get_middleware()
|
middleware = self.get_middleware()
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
||||||
|
|
@ -107,7 +108,10 @@ class TestMediaGroupAggregatorMiddleware:
|
||||||
task1 = await wait_until_func_call_sleep(
|
task1 = await wait_until_func_call_sleep(
|
||||||
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||||
)
|
)
|
||||||
await middleware.media_group_aggregator.delete_group("42")
|
if deleted_object == "album":
|
||||||
|
middleware.media_group_aggregator.groups.pop("42")
|
||||||
|
else:
|
||||||
|
middleware.media_group_aggregator.last_message_timers.pop("42")
|
||||||
await task1
|
await task1
|
||||||
assert counter == 0
|
assert counter == 0
|
||||||
|
|
||||||
|
|
@ -199,6 +203,16 @@ class TestMediaGroupAggregator:
|
||||||
assert await aggregator.get_group("42") == []
|
assert await aggregator.get_group("42") == []
|
||||||
assert await aggregator.get_group("24") == [new_msg]
|
assert await aggregator.get_group("24") == [new_msg]
|
||||||
|
|
||||||
|
async def test_get_current_time_memory_aggregator(self):
|
||||||
|
aggregator = MemoryMediaGroupAggregator()
|
||||||
|
with mock.patch("time.time", return_value=1.1):
|
||||||
|
assert await aggregator.get_current_time() == 1.1
|
||||||
|
|
||||||
|
async def test_get_current_time_redis_aggregator(self):
|
||||||
|
aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis))
|
||||||
|
aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456))
|
||||||
|
assert await aggregator.get_current_time() == 1.123456
|
||||||
|
|
||||||
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
|
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
|
||||||
assert await aggregator.get_last_message_time("42") is None
|
assert await aggregator.get_last_message_time("42") is None
|
||||||
await aggregator.add_into_group("42", _get_message(1))
|
await aggregator.add_into_group("42", _get_message(1))
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ class TestMediaGroupFilter:
|
||||||
[
|
[
|
||||||
((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT),
|
((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT),
|
||||||
((3,), 3, 3),
|
((3,), 3, 3),
|
||||||
|
((11,), 11, 11),
|
||||||
|
((None, 11, None), 11, 11),
|
||||||
((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT),
|
((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT),
|
||||||
((None, None, 3), MIN_MEDIA_COUNT, 3),
|
((None, None, 3), MIN_MEDIA_COUNT, 3),
|
||||||
],
|
],
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue