fix lock releasing

This commit is contained in:
Vitaly312 2026-03-01 15:25:59 +03:00
parent 10fa7315da
commit 80b4abd3b5
2 changed files with 47 additions and 29 deletions

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
import time import time
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
@ -20,27 +21,27 @@ TTL_SEC = 600
class BaseMediaGroupAggregator(ABC): class BaseMediaGroupAggregator(ABC):
@abstractmethod @abstractmethod
async def add_into_group(self, media_group_id: str, media: Message) -> int: async def add_into_group(self, media_group_id: str, media: Message) -> int:
raise NotImplementedError pass
@abstractmethod @abstractmethod
async def acquire_lock(self, media_group_id: str) -> bool: async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
raise NotImplementedError pass
@abstractmethod @abstractmethod
async def release_lock(self, media_group_id: str) -> None: async def release_lock(self, media_group_id: str, lock_id: str) -> None:
raise NotImplementedError pass
@abstractmethod @abstractmethod
async def get_group(self, media_group_id: str) -> list[Message]: async def get_group(self, media_group_id: str) -> list[Message]:
raise NotImplementedError pass
@abstractmethod @abstractmethod
async def delete_group(self, media_group_id: str) -> None: async def delete_group(self, media_group_id: str) -> None:
raise NotImplementedError pass
@abstractmethod @abstractmethod
async def get_last_message_time(self, media_group_id: str) -> float | None: async def get_last_message_time(self, media_group_id: str) -> float | None:
raise NotImplementedError pass
@staticmethod @staticmethod
def deduplicate_messages(messages: list[Message]) -> list[Message]: def deduplicate_messages(messages: list[Message]) -> list[Message]:
@ -53,8 +54,9 @@ class BaseMediaGroupAggregator(ABC):
message_ids.add(message.message_id) message_ids.add(message.message_id)
return result return result
@abstractmethod
async def get_current_time(self) -> float: async def get_current_time(self) -> float:
return time.time() pass
class RedisMediaGroupAggregator(BaseMediaGroupAggregator): class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
@ -98,16 +100,24 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
res = await pipe.execute() res = await pipe.execute()
return cast(int, res[1]) return cast(int, res[1])
async def acquire_lock(self, media_group_id: str) -> bool: async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
return cast( return cast(
bool, bool,
await self.redis.set( await self.redis.set(
self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec self.get_group_lock_key(media_group_id), lock_id, nx=True, ex=self.lock_ttl_sec
), ),
) )
async def release_lock(self, media_group_id: str) -> None: async def release_lock(self, media_group_id: str, lock_id: str) -> None:
await self.redis.delete(self.get_group_lock_key(media_group_id)) release_script = (
'if redis.call("get", KEYS[1]) == ARGV[1] then '
'return redis.call("del", KEYS[1]) '
"else return 0 end"
)
await cast(
Awaitable[str],
self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id),
)
async def get_group(self, media_group_id: str) -> list[Message]: async def get_group(self, media_group_id: str) -> list[Message]:
result = await cast( result = await cast(
@ -136,12 +146,12 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
def __init__(self, ttl_sec: int = TTL_SEC) -> None: def __init__(self, ttl_sec: int = TTL_SEC) -> None:
self.groups: dict[str, list[Message]] = defaultdict(list) self.groups: dict[str, list[Message]] = defaultdict(list)
self.last_message_timers: dict[str, float] = {} self.last_message_timers: dict[str, float] = {}
self.locks: dict[str, bool] = {} self.locks: dict[str, str] = {}
self.ttl_sec = ttl_sec self.ttl_sec = ttl_sec
def remove_expired_objects(self) -> None: def remove_expired_objects(self) -> None:
expired_group_ids = [] expired_group_ids = []
current_time = time.time() current_time = time.monotonic()
for group_id, last_message_time in self.last_message_timers.items(): for group_id, last_message_time in self.last_message_timers.items():
if current_time - last_message_time > self.ttl_sec: if current_time - last_message_time > self.ttl_sec:
expired_group_ids.append(group_id) expired_group_ids.append(group_id)
@ -158,17 +168,18 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]): if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]):
self.groups[media_group_id].append(media) self.groups[media_group_id].append(media)
self.last_message_timers.pop(media_group_id, None) self.last_message_timers.pop(media_group_id, None)
self.last_message_timers[media_group_id] = time.time() self.last_message_timers[media_group_id] = time.monotonic()
return len(self.groups[media_group_id]) return len(self.groups[media_group_id])
async def acquire_lock(self, media_group_id: str) -> bool: async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
if self.locks.get(media_group_id): if self.locks.get(media_group_id):
return False return False
self.locks[media_group_id] = True self.locks[media_group_id] = lock_id
return True return True
async def release_lock(self, media_group_id: str) -> None: async def release_lock(self, media_group_id: str, lock_id: str) -> None:
self.locks.pop(media_group_id, None) if self.locks.get(media_group_id) == lock_id:
self.locks.pop(media_group_id)
async def get_group(self, media_group_id: str) -> list[Message]: async def get_group(self, media_group_id: str) -> list[Message]:
return self.groups.get(media_group_id, []) return self.groups.get(media_group_id, [])
@ -180,6 +191,9 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
async def get_last_message_time(self, media_group_id: str) -> float | None: async def get_last_message_time(self, media_group_id: str) -> float | None:
return self.last_message_timers.get(media_group_id) return self.last_message_timers.get(media_group_id)
async def get_current_time(self) -> float:
return time.monotonic()
class MediaGroupAggregatorMiddleware(BaseMiddleware): class MediaGroupAggregatorMiddleware(BaseMiddleware):
def __init__( def __init__(
@ -202,7 +216,8 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
if not isinstance(event, Message) or not event.media_group_id: if not isinstance(event, Message) or not event.media_group_id:
return await handler(event, data) return await handler(event, data)
await self.media_group_aggregator.add_into_group(event.media_group_id, event) await self.media_group_aggregator.add_into_group(event.media_group_id, event)
if not await self.media_group_aggregator.acquire_lock(event.media_group_id): lock_id = str(uuid.uuid4())
if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id):
return None return None
try: try:
last_message_time = await self.media_group_aggregator.get_current_time() last_message_time = await self.media_group_aggregator.get_current_time()
@ -230,4 +245,4 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
return None return None
last_message_time = new_last_message_time last_message_time = new_last_message_time
finally: finally:
await self.media_group_aggregator.release_lock(event.media_group_id) await self.media_group_aggregator.release_lock(event.media_group_id, lock_id)

View file

@ -189,15 +189,18 @@ class TestMediaGroupAggregator:
assert await aggregator.get_group("42") == [] assert await aggregator.get_group("42") == []
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator): async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
for _ in range(2): await aggregator.acquire_lock("42", "key1")
assert await aggregator.acquire_lock("42") assert not await aggregator.acquire_lock("42", "key2")
assert not await aggregator.acquire_lock("42") await aggregator.release_lock("42", "key1")
await aggregator.release_lock("42") for i in ("key2", "key3"):
assert await aggregator.acquire_lock("42", i)
assert not await aggregator.acquire_lock("42", i)
await aggregator.release_lock("42", i)
async def test_expired_objects_removed(self): async def test_expired_objects_removed(self):
aggregator = MemoryMediaGroupAggregator() aggregator = MemoryMediaGroupAggregator()
await aggregator.add_into_group("42", _get_message(1)) await aggregator.add_into_group("42", _get_message(1))
with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1): with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1):
new_msg = _get_message(2) new_msg = _get_message(2)
await aggregator.add_into_group("24", new_msg) await aggregator.add_into_group("24", new_msg)
assert await aggregator.get_group("42") == [] assert await aggregator.get_group("42") == []
@ -205,7 +208,7 @@ class TestMediaGroupAggregator:
async def test_get_current_time_memory_aggregator(self): async def test_get_current_time_memory_aggregator(self):
aggregator = MemoryMediaGroupAggregator() aggregator = MemoryMediaGroupAggregator()
with mock.patch("time.time", return_value=1.1): with mock.patch("time.monotonic", return_value=1.1):
assert await aggregator.get_current_time() == 1.1 assert await aggregator.get_current_time() == 1.1
async def test_get_current_time_redis_aggregator(self): async def test_get_current_time_redis_aggregator(self):
@ -216,7 +219,7 @@ class TestMediaGroupAggregator:
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator): async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
assert await aggregator.get_last_message_time("42") is None assert await aggregator.get_last_message_time("42") is None
await aggregator.add_into_group("42", _get_message(1)) await aggregator.add_into_group("42", _get_message(1))
time_before_second_message = time.time() time_before_second_message = await aggregator.get_current_time()
assert await aggregator.get_last_message_time("42") <= time_before_second_message assert await aggregator.get_last_message_time("42") <= time_before_second_message
await aggregator.add_into_group("42", _get_message(2)) await aggregator.add_into_group("42", _get_message(2))
assert await aggregator.get_last_message_time("42") >= time_before_second_message assert await aggregator.get_last_message_time("42") >= time_before_second_message