mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
fix lock releasing
This commit is contained in:
parent
10fa7315da
commit
80b4abd3b5
2 changed files with 47 additions and 29 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
|
@ -20,27 +21,27 @@ TTL_SEC = 600
|
|||
class BaseMediaGroupAggregator(ABC):
|
||||
@abstractmethod
|
||||
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_lock(self, media_group_id: str) -> bool:
|
||||
raise NotImplementedError
|
||||
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def release_lock(self, media_group_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_group(self, media_group_id: str) -> list[Message]:
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_group(self, media_group_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_last_message_time(self, media_group_id: str) -> float | None:
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def deduplicate_messages(messages: list[Message]) -> list[Message]:
|
||||
|
|
@ -53,8 +54,9 @@ class BaseMediaGroupAggregator(ABC):
|
|||
message_ids.add(message.message_id)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def get_current_time(self) -> float:
|
||||
return time.time()
|
||||
pass
|
||||
|
||||
|
||||
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||
|
|
@ -98,16 +100,24 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
res = await pipe.execute()
|
||||
return cast(int, res[1])
|
||||
|
||||
async def acquire_lock(self, media_group_id: str) -> bool:
|
||||
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
|
||||
return cast(
|
||||
bool,
|
||||
await self.redis.set(
|
||||
self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
|
||||
self.get_group_lock_key(media_group_id), lock_id, nx=True, ex=self.lock_ttl_sec
|
||||
),
|
||||
)
|
||||
|
||||
async def release_lock(self, media_group_id: str) -> None:
|
||||
await self.redis.delete(self.get_group_lock_key(media_group_id))
|
||||
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
|
||||
release_script = (
|
||||
'if redis.call("get", KEYS[1]) == ARGV[1] then '
|
||||
'return redis.call("del", KEYS[1]) '
|
||||
"else return 0 end"
|
||||
)
|
||||
await cast(
|
||||
Awaitable[str],
|
||||
self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id),
|
||||
)
|
||||
|
||||
async def get_group(self, media_group_id: str) -> list[Message]:
|
||||
result = await cast(
|
||||
|
|
@ -136,12 +146,12 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
|
||||
self.groups: dict[str, list[Message]] = defaultdict(list)
|
||||
self.last_message_timers: dict[str, float] = {}
|
||||
self.locks: dict[str, bool] = {}
|
||||
self.locks: dict[str, str] = {}
|
||||
self.ttl_sec = ttl_sec
|
||||
|
||||
def remove_expired_objects(self) -> None:
|
||||
expired_group_ids = []
|
||||
current_time = time.time()
|
||||
current_time = time.monotonic()
|
||||
for group_id, last_message_time in self.last_message_timers.items():
|
||||
if current_time - last_message_time > self.ttl_sec:
|
||||
expired_group_ids.append(group_id)
|
||||
|
|
@ -158,17 +168,18 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]):
|
||||
self.groups[media_group_id].append(media)
|
||||
self.last_message_timers.pop(media_group_id, None)
|
||||
self.last_message_timers[media_group_id] = time.time()
|
||||
self.last_message_timers[media_group_id] = time.monotonic()
|
||||
return len(self.groups[media_group_id])
|
||||
|
||||
async def acquire_lock(self, media_group_id: str) -> bool:
|
||||
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
|
||||
if self.locks.get(media_group_id):
|
||||
return False
|
||||
self.locks[media_group_id] = True
|
||||
self.locks[media_group_id] = lock_id
|
||||
return True
|
||||
|
||||
async def release_lock(self, media_group_id: str) -> None:
|
||||
self.locks.pop(media_group_id, None)
|
||||
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
|
||||
if self.locks.get(media_group_id) == lock_id:
|
||||
self.locks.pop(media_group_id)
|
||||
|
||||
async def get_group(self, media_group_id: str) -> list[Message]:
|
||||
return self.groups.get(media_group_id, [])
|
||||
|
|
@ -180,6 +191,9 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
async def get_last_message_time(self, media_group_id: str) -> float | None:
|
||||
return self.last_message_timers.get(media_group_id)
|
||||
|
||||
async def get_current_time(self) -> float:
|
||||
return time.monotonic()
|
||||
|
||||
|
||||
class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
||||
def __init__(
|
||||
|
|
@ -202,7 +216,8 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
|||
if not isinstance(event, Message) or not event.media_group_id:
|
||||
return await handler(event, data)
|
||||
await self.media_group_aggregator.add_into_group(event.media_group_id, event)
|
||||
if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
|
||||
lock_id = str(uuid.uuid4())
|
||||
if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id):
|
||||
return None
|
||||
try:
|
||||
last_message_time = await self.media_group_aggregator.get_current_time()
|
||||
|
|
@ -230,4 +245,4 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
|||
return None
|
||||
last_message_time = new_last_message_time
|
||||
finally:
|
||||
await self.media_group_aggregator.release_lock(event.media_group_id)
|
||||
await self.media_group_aggregator.release_lock(event.media_group_id, lock_id)
|
||||
|
|
|
|||
|
|
@ -189,15 +189,18 @@ class TestMediaGroupAggregator:
|
|||
assert await aggregator.get_group("42") == []
|
||||
|
||||
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
|
||||
for _ in range(2):
|
||||
assert await aggregator.acquire_lock("42")
|
||||
assert not await aggregator.acquire_lock("42")
|
||||
await aggregator.release_lock("42")
|
||||
await aggregator.acquire_lock("42", "key1")
|
||||
assert not await aggregator.acquire_lock("42", "key2")
|
||||
await aggregator.release_lock("42", "key1")
|
||||
for i in ("key2", "key3"):
|
||||
assert await aggregator.acquire_lock("42", i)
|
||||
assert not await aggregator.acquire_lock("42", i)
|
||||
await aggregator.release_lock("42", i)
|
||||
|
||||
async def test_expired_objects_removed(self):
|
||||
aggregator = MemoryMediaGroupAggregator()
|
||||
await aggregator.add_into_group("42", _get_message(1))
|
||||
with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1):
|
||||
with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1):
|
||||
new_msg = _get_message(2)
|
||||
await aggregator.add_into_group("24", new_msg)
|
||||
assert await aggregator.get_group("42") == []
|
||||
|
|
@ -205,7 +208,7 @@ class TestMediaGroupAggregator:
|
|||
|
||||
async def test_get_current_time_memory_aggregator(self):
|
||||
aggregator = MemoryMediaGroupAggregator()
|
||||
with mock.patch("time.time", return_value=1.1):
|
||||
with mock.patch("time.monotonic", return_value=1.1):
|
||||
assert await aggregator.get_current_time() == 1.1
|
||||
|
||||
async def test_get_current_time_redis_aggregator(self):
|
||||
|
|
@ -216,7 +219,7 @@ class TestMediaGroupAggregator:
|
|||
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
|
||||
assert await aggregator.get_last_message_time("42") is None
|
||||
await aggregator.add_into_group("42", _get_message(1))
|
||||
time_before_second_message = time.time()
|
||||
time_before_second_message = await aggregator.get_current_time()
|
||||
assert await aggregator.get_last_message_time("42") <= time_before_second_message
|
||||
await aggregator.add_into_group("42", _get_message(2))
|
||||
assert await aggregator.get_last_message_time("42") >= time_before_second_message
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue