mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge 28626d124a into e4d3692ac2
This commit is contained in:
commit
17539c777a
4 changed files with 533 additions and 4 deletions
1
CHANGES/1697.feature.rst
Normal file
1
CHANGES/1697.feature.rst
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
Added ``MediaGroupAggregatorMiddleware`` to aggregate media groups into a single event with an album list in the handler data.
|
||||||
246
aiogram/dispatcher/middlewares/media_group.py
Normal file
246
aiogram/dispatcher/middlewares/media_group.py
Normal file
|
|
@ -0,0 +1,246 @@
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from aiogram import Bot
|
||||||
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
|
from aiogram.types import Message, TelegramObject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio.client import Redis
|
||||||
|
|
||||||
|
DELAY_SEC = 1.0
|
||||||
|
LOCK_TTL_SEC = 30
|
||||||
|
TTL_SEC = 600
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMediaGroupAggregator(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_group(self, media_group_id: str) -> list[Message]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_group(self, media_group_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_last_message_time(self, media_group_id: str) -> float | None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deduplicate_messages(messages: list[Message]) -> list[Message]:
|
||||||
|
message_ids = set()
|
||||||
|
result = []
|
||||||
|
for message in messages:
|
||||||
|
if message.message_id in message_ids:
|
||||||
|
continue
|
||||||
|
result.append(message)
|
||||||
|
message_ids.add(message.message_id)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_current_time(self) -> float:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
|
"""
|
||||||
|
Aggregates media groups in Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
redis: "Redis"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
:param ttl_sec: ttl for media group data in seconds
|
||||||
|
:param lock_ttl_sec: ttl for lock in seconds. Value should be large enough to prevent the
|
||||||
|
lock from expiring before the handler finishes, but small enough to expire before
|
||||||
|
Telegram retries a failed delivery.
|
||||||
|
"""
|
||||||
|
self.redis = redis
|
||||||
|
self.ttl_sec = ttl_sec
|
||||||
|
self.lock_ttl_sec = lock_ttl_sec
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_group_key(media_group_id: str) -> str:
|
||||||
|
return f"media_group:{media_group_id}:album"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_last_message_time_key(media_group_id: str) -> str:
|
||||||
|
return f"media_group:{media_group_id}:last_message_time"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_group_lock_key(media_group_id: str) -> str:
|
||||||
|
return f"media_group:{media_group_id}:lock"
|
||||||
|
|
||||||
|
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||||
|
current_time = await self.get_current_time()
|
||||||
|
async with self.redis.pipeline(transaction=True) as pipe:
|
||||||
|
pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec)
|
||||||
|
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
|
||||||
|
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
|
||||||
|
res = await pipe.execute()
|
||||||
|
return cast(int, res[1])
|
||||||
|
|
||||||
|
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
|
||||||
|
return cast(
|
||||||
|
bool,
|
||||||
|
await self.redis.set(
|
||||||
|
self.get_group_lock_key(media_group_id), lock_id, nx=True, ex=self.lock_ttl_sec
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
|
||||||
|
release_script = (
|
||||||
|
'if redis.call("get", KEYS[1]) == ARGV[1] then '
|
||||||
|
'return redis.call("del", KEYS[1]) '
|
||||||
|
"else return 0 end"
|
||||||
|
)
|
||||||
|
await cast(
|
||||||
|
Awaitable[int],
|
||||||
|
self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_group(self, media_group_id: str) -> list[Message]:
|
||||||
|
result = await cast(
|
||||||
|
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
|
||||||
|
)
|
||||||
|
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
|
||||||
|
|
||||||
|
async def delete_group(self, media_group_id: str) -> None:
|
||||||
|
async with self.redis.pipeline(transaction=True) as pipe:
|
||||||
|
pipe.delete(self.get_group_key(media_group_id))
|
||||||
|
pipe.delete(self.get_last_message_time_key(media_group_id))
|
||||||
|
await pipe.execute()
|
||||||
|
|
||||||
|
async def get_last_message_time(self, media_group_id: str) -> float | None:
|
||||||
|
result = await self.redis.get(self.get_last_message_time_key(media_group_id))
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
return float(result)
|
||||||
|
|
||||||
|
async def get_current_time(self) -> float:
|
||||||
|
seconds, microseconds = cast(tuple[int, int], await self.redis.time())
|
||||||
|
return seconds + microseconds / 1e6
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||||
|
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
|
||||||
|
self.groups: dict[str, list[Message]] = defaultdict(list)
|
||||||
|
self.last_message_timers: dict[str, float] = {}
|
||||||
|
self.locks: dict[str, str] = {}
|
||||||
|
self.ttl_sec = ttl_sec
|
||||||
|
|
||||||
|
def remove_expired_objects(self) -> None:
|
||||||
|
expired_group_ids = []
|
||||||
|
current_time = time.monotonic()
|
||||||
|
for group_id, last_message_time in self.last_message_timers.items():
|
||||||
|
if current_time - last_message_time > self.ttl_sec:
|
||||||
|
expired_group_ids.append(group_id)
|
||||||
|
else:
|
||||||
|
break # the list is sorted in ascending order
|
||||||
|
# because python 3.7+ save dict in insertion order
|
||||||
|
for group_id in expired_group_ids:
|
||||||
|
self.groups.pop(group_id, None)
|
||||||
|
self.last_message_timers.pop(group_id, None)
|
||||||
|
self.locks.pop(group_id, None)
|
||||||
|
|
||||||
|
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||||
|
self.remove_expired_objects()
|
||||||
|
if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]):
|
||||||
|
self.groups[media_group_id].append(media)
|
||||||
|
self.last_message_timers.pop(media_group_id, None)
|
||||||
|
self.last_message_timers[media_group_id] = time.monotonic()
|
||||||
|
return len(self.groups[media_group_id])
|
||||||
|
|
||||||
|
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
|
||||||
|
if self.locks.get(media_group_id) is not None:
|
||||||
|
return False
|
||||||
|
self.locks[media_group_id] = lock_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
|
||||||
|
if self.locks.get(media_group_id) == lock_id:
|
||||||
|
self.locks.pop(media_group_id)
|
||||||
|
|
||||||
|
async def get_group(self, media_group_id: str) -> list[Message]:
|
||||||
|
return self.groups.get(media_group_id, [])
|
||||||
|
|
||||||
|
async def delete_group(self, media_group_id: str) -> None:
|
||||||
|
self.groups.pop(media_group_id, None)
|
||||||
|
self.last_message_timers.pop(media_group_id, None)
|
||||||
|
|
||||||
|
async def get_last_message_time(self, media_group_id: str) -> float | None:
|
||||||
|
return self.last_message_timers.get(media_group_id)
|
||||||
|
|
||||||
|
async def get_current_time(self) -> float:
|
||||||
|
return time.monotonic()
|
||||||
|
|
||||||
|
|
||||||
|
class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
media_group_aggregator: BaseMediaGroupAggregator | None = None,
|
||||||
|
delay: float = DELAY_SEC,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
:param delay: delay between last received message in media group and processing it
|
||||||
|
"""
|
||||||
|
self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator()
|
||||||
|
self.delay = delay
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
|
event: TelegramObject,
|
||||||
|
data: dict[str, Any],
|
||||||
|
) -> Any:
|
||||||
|
if not isinstance(event, Message) or not event.media_group_id:
|
||||||
|
return await handler(event, data)
|
||||||
|
await self.media_group_aggregator.add_into_group(event.media_group_id, event)
|
||||||
|
lock_id = str(uuid.uuid4())
|
||||||
|
if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
last_message_time = await self.media_group_aggregator.get_last_message_time(
|
||||||
|
event.media_group_id
|
||||||
|
)
|
||||||
|
if not last_message_time:
|
||||||
|
return None
|
||||||
|
delta = self.delay - (
|
||||||
|
await self.media_group_aggregator.get_current_time() - last_message_time
|
||||||
|
)
|
||||||
|
if delta <= 0:
|
||||||
|
album = await self.media_group_aggregator.get_group(event.media_group_id)
|
||||||
|
if not album:
|
||||||
|
return None
|
||||||
|
album = sorted(
|
||||||
|
(msg.as_(cast(Bot, data.get("bot"))) for msg in album),
|
||||||
|
key=lambda msg: msg.message_id,
|
||||||
|
)
|
||||||
|
data.update(album=album)
|
||||||
|
result = await handler(album[0], data)
|
||||||
|
await self.media_group_aggregator.delete_group(event.media_group_id)
|
||||||
|
return result
|
||||||
|
await asyncio.sleep(delta)
|
||||||
|
finally:
|
||||||
|
await self.media_group_aggregator.release_lock(event.media_group_id, lock_id)
|
||||||
|
|
@ -1,8 +1,13 @@
|
||||||
===================
|
===========
|
||||||
Media group builder
|
Media group
|
||||||
===================
|
===========
|
||||||
|
|
||||||
This module provides a builder for media groups, it can be used to build media groups
|
This module provides tools for media groups.
|
||||||
|
|
||||||
|
Building media groups
|
||||||
|
=====================
|
||||||
|
|
||||||
|
Media group builder can be used to build media groups
|
||||||
for :class:`aiogram.types.input_media_photo.InputMediaPhoto`, :class:`aiogram.types.input_media_video.InputMediaVideo`,
|
for :class:`aiogram.types.input_media_photo.InputMediaPhoto`, :class:`aiogram.types.input_media_video.InputMediaVideo`,
|
||||||
:class:`aiogram.types.input_media_document.InputMediaDocument` and :class:`aiogram.types.input_media_audio.InputMediaAudio`.
|
:class:`aiogram.types.input_media_document.InputMediaDocument` and :class:`aiogram.types.input_media_audio.InputMediaAudio`.
|
||||||
|
|
||||||
|
|
@ -39,8 +44,58 @@ it will be used as ``caption`` for first media in group.
|
||||||
await bot.send_media_group(chat_id=chat_id, media=media_group.build())
|
await bot.send_media_group(chat_id=chat_id, media=media_group.build())
|
||||||
|
|
||||||
|
|
||||||
|
Handling media groups
|
||||||
|
=====================
|
||||||
|
|
||||||
|
By default each media in the group is processed separately.
|
||||||
|
|
||||||
|
You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware`
|
||||||
|
to process media groups as one. If you do, only one message from the group will be processed, and updates for
|
||||||
|
other messages with the same media group ID will be suppressed. There are two options to store media groups:
|
||||||
|
|
||||||
|
- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default
|
||||||
|
- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - support distributed environment
|
||||||
|
|
||||||
|
You also can use :class:`aiogram.filters.magic_data.MagicData` with ``F.album``
|
||||||
|
to filter media groups.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from aiogram import F
|
||||||
|
from aiogram.types import Message
|
||||||
|
|
||||||
|
# register middleware
|
||||||
|
from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware
|
||||||
|
from aiogram.filters import MagicData
|
||||||
|
|
||||||
|
router.message.outer_middleware(MediaGroupAggregatorMiddleware())
|
||||||
|
|
||||||
|
# use middleware
|
||||||
|
@router.message(
|
||||||
|
MagicData(F.album.len() <= 5),
|
||||||
|
F.caption == "album_caption" # other filters will be applied to the first message in the group
|
||||||
|
)
|
||||||
|
async def start(message: Message, album: list[Message]):
|
||||||
|
# message is the first media in this group
|
||||||
|
# album is list of all messages with the same mediaGroupId, including current message
|
||||||
|
await message.answer(
|
||||||
|
f"You sent {len(album)} media in the group. "
|
||||||
|
f"Media group ID: {message.media_group_id}. "
|
||||||
|
f"Album messages: {', '.join(str(m.message_id) for m in album)}"
|
||||||
|
)
|
||||||
|
|
||||||
References
|
References
|
||||||
==========
|
==========
|
||||||
|
|
||||||
.. autoclass:: aiogram.utils.media_group.MediaGroupBuilder
|
.. autoclass:: aiogram.utils.media_group.MediaGroupBuilder
|
||||||
:members:
|
:members:
|
||||||
|
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware
|
||||||
|
:members:
|
||||||
|
.. autoclass:: aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator
|
||||||
|
:members:
|
||||||
|
.. autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator
|
||||||
|
:members:
|
||||||
|
:special-members: __init__
|
||||||
|
|
|
||||||
227
tests/test_dispatcher/test_middlewares/test_media_group.py
Normal file
227
tests/test_dispatcher/test_middlewares/test_media_group.py
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Awaitable, Callable
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from redis.asyncio.client import Redis
|
||||||
|
|
||||||
|
from aiogram.dispatcher.middlewares.media_group import (
|
||||||
|
BaseMediaGroupAggregator,
|
||||||
|
MediaGroupAggregatorMiddleware,
|
||||||
|
MemoryMediaGroupAggregator,
|
||||||
|
RedisMediaGroupAggregator,
|
||||||
|
)
|
||||||
|
from aiogram.types import Chat, Message
|
||||||
|
|
||||||
|
|
||||||
|
def _get_message(message_id: int, **kwargs):
|
||||||
|
chat = Chat(id=1, type="private", title="Test")
|
||||||
|
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
|
||||||
|
start_sleep = asyncio.Event()
|
||||||
|
real_sleep = asyncio.sleep
|
||||||
|
|
||||||
|
async def mock_sleep(*args, **kwargs):
|
||||||
|
start_sleep.set()
|
||||||
|
await real_sleep(0)
|
||||||
|
|
||||||
|
with mock.patch("asyncio.sleep", mock_sleep):
|
||||||
|
task1 = func(*args, **kwargs)
|
||||||
|
await start_sleep.wait()
|
||||||
|
return task1
|
||||||
|
|
||||||
|
|
||||||
|
class TestMediaGroupAggregatorMiddleware:
|
||||||
|
def get_middleware(self):
|
||||||
|
return MediaGroupAggregatorMiddleware(delay=0.1)
|
||||||
|
|
||||||
|
async def test_skip_non_media_group(self):
|
||||||
|
is_called = False
|
||||||
|
|
||||||
|
async def next_handler(*args, **kwargs):
|
||||||
|
nonlocal is_called
|
||||||
|
is_called = True
|
||||||
|
|
||||||
|
await self.get_middleware()(next_handler, _get_message(1), {})
|
||||||
|
assert is_called
|
||||||
|
|
||||||
|
async def test_called_once_for_album(self):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
counter = 0
|
||||||
|
album = None
|
||||||
|
|
||||||
|
async def next_handler(_, data: dict[str, Any]):
|
||||||
|
nonlocal counter, album
|
||||||
|
counter += 1
|
||||||
|
album = data.get("album")
|
||||||
|
|
||||||
|
await asyncio.gather(
|
||||||
|
middleware(next_handler, _get_message(1, media_group_id="42"), {}),
|
||||||
|
middleware(next_handler, _get_message(2, media_group_id="42"), {}),
|
||||||
|
)
|
||||||
|
assert album is not None
|
||||||
|
assert len(album) == 2
|
||||||
|
assert counter == 1
|
||||||
|
|
||||||
|
async def test_bot_object_saved(self, bot):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
event = album = None
|
||||||
|
|
||||||
|
async def next_handler(message: Message, data: dict[str, Any]):
|
||||||
|
nonlocal event, album
|
||||||
|
event = message
|
||||||
|
album = data.get("album")
|
||||||
|
|
||||||
|
await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot})
|
||||||
|
assert event.bot is bot
|
||||||
|
assert all(msg.bot is bot for msg in album)
|
||||||
|
|
||||||
|
async def test_propagate_first_media_in_album(self):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
first_message = None
|
||||||
|
|
||||||
|
async def next_handler(message: Message, _):
|
||||||
|
nonlocal first_message
|
||||||
|
first_message = message
|
||||||
|
|
||||||
|
task1 = await wait_until_func_call_sleep(
|
||||||
|
asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {})
|
||||||
|
)
|
||||||
|
await middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||||
|
await task1
|
||||||
|
assert isinstance(first_message, Message)
|
||||||
|
assert first_message.message_id == 1
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("deleted_object", ["album", "last_message_time"])
|
||||||
|
async def test_skip_propagating_if_data_deleted(self, deleted_object):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
async def next_handler(*args, **kwargs):
|
||||||
|
nonlocal counter
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
task1 = await wait_until_func_call_sleep(
|
||||||
|
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||||
|
)
|
||||||
|
if deleted_object == "album":
|
||||||
|
middleware.media_group_aggregator.groups.pop("42")
|
||||||
|
else:
|
||||||
|
middleware.media_group_aggregator.last_message_timers.pop("42")
|
||||||
|
await task1
|
||||||
|
assert counter == 0
|
||||||
|
|
||||||
|
async def test_different_albums_non_interfere(self):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
counter = 0
|
||||||
|
albums = []
|
||||||
|
|
||||||
|
async def next_handler(_, data: dict[str, Any]):
|
||||||
|
nonlocal counter, albums
|
||||||
|
counter += 1
|
||||||
|
albums.append(data.get("album"))
|
||||||
|
|
||||||
|
await asyncio.gather(
|
||||||
|
middleware(next_handler, _get_message(1, media_group_id="1"), {}),
|
||||||
|
middleware(next_handler, _get_message(2, media_group_id="2"), {}),
|
||||||
|
)
|
||||||
|
assert counter == 2
|
||||||
|
assert len(albums) == 2
|
||||||
|
|
||||||
|
async def test_retry_handling(self):
|
||||||
|
middleware = self.get_middleware()
|
||||||
|
album = None
|
||||||
|
|
||||||
|
async def failed_handler(*args, **kwargs):
|
||||||
|
raise RuntimeError("Failed")
|
||||||
|
|
||||||
|
async def working_handler(_, data: dict[str, Any]):
|
||||||
|
nonlocal album
|
||||||
|
album = data.get("album")
|
||||||
|
|
||||||
|
first_message = _get_message(1, media_group_id="42")
|
||||||
|
second_message = _get_message(2, media_group_id="42")
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await asyncio.gather(
|
||||||
|
middleware(failed_handler, first_message, {}),
|
||||||
|
middleware(failed_handler, second_message, {}),
|
||||||
|
)
|
||||||
|
await middleware(working_handler, first_message, {})
|
||||||
|
assert len(album) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_deduplication():
|
||||||
|
message_1, message_2 = _get_message(1), _get_message(2)
|
||||||
|
res = [message_1, message_2]
|
||||||
|
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2]) == res
|
||||||
|
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_2]) == res
|
||||||
|
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_1]) == res
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["memory", "redis"], scope="function")
|
||||||
|
async def aggregator(request):
|
||||||
|
if request.param == "memory":
|
||||||
|
yield MemoryMediaGroupAggregator()
|
||||||
|
else:
|
||||||
|
redis = Redis.from_url(request.getfixturevalue("redis_server"))
|
||||||
|
yield RedisMediaGroupAggregator(redis)
|
||||||
|
keys = await redis.keys("media_group:*")
|
||||||
|
if keys:
|
||||||
|
await redis.delete(*keys)
|
||||||
|
await redis.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMediaGroupAggregator:
|
||||||
|
async def test_group_creating(self, aggregator: BaseMediaGroupAggregator):
|
||||||
|
msg1 = _get_message(1)
|
||||||
|
msg2 = _get_message(2)
|
||||||
|
assert await aggregator.add_into_group("42", msg1) == 1
|
||||||
|
assert await aggregator.add_into_group("42", msg2) == 2
|
||||||
|
assert {msg.message_id for msg in await aggregator.get_group("42")} == {
|
||||||
|
msg1.message_id,
|
||||||
|
msg2.message_id,
|
||||||
|
}
|
||||||
|
await aggregator.delete_group("42")
|
||||||
|
assert await aggregator.get_group("42") == []
|
||||||
|
|
||||||
|
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
|
||||||
|
for i in ("key1", "key2"):
|
||||||
|
assert await aggregator.acquire_lock("42", i)
|
||||||
|
assert not await aggregator.acquire_lock("42", i)
|
||||||
|
await aggregator.release_lock("42", i)
|
||||||
|
|
||||||
|
async def test_lock_not_acquired_with_wrong_key(self, aggregator: BaseMediaGroupAggregator):
|
||||||
|
await aggregator.acquire_lock("42", "key1")
|
||||||
|
await aggregator.release_lock("42", "key2")
|
||||||
|
assert not await aggregator.acquire_lock("42", "key1")
|
||||||
|
|
||||||
|
async def test_expired_objects_removed(self):
|
||||||
|
aggregator = MemoryMediaGroupAggregator()
|
||||||
|
await aggregator.add_into_group("42", _get_message(1))
|
||||||
|
with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1):
|
||||||
|
new_msg = _get_message(2)
|
||||||
|
await aggregator.add_into_group("24", new_msg)
|
||||||
|
assert await aggregator.get_group("42") == []
|
||||||
|
assert await aggregator.get_group("24") == [new_msg]
|
||||||
|
|
||||||
|
async def test_get_current_time_memory_aggregator(self):
|
||||||
|
aggregator = MemoryMediaGroupAggregator()
|
||||||
|
with mock.patch("time.monotonic", return_value=1.1):
|
||||||
|
assert await aggregator.get_current_time() == 1.1
|
||||||
|
|
||||||
|
async def test_get_current_time_redis_aggregator(self):
|
||||||
|
aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis))
|
||||||
|
aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456))
|
||||||
|
assert await aggregator.get_current_time() == 1.123456
|
||||||
|
|
||||||
|
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
|
||||||
|
assert await aggregator.get_last_message_time("42") is None
|
||||||
|
await aggregator.add_into_group("42", _get_message(1))
|
||||||
|
time_before_second_message = await aggregator.get_current_time()
|
||||||
|
assert await aggregator.get_last_message_time("42") <= time_before_second_message
|
||||||
|
await aggregator.add_into_group("42", _get_message(2))
|
||||||
|
assert await aggregator.get_last_message_time("42") >= time_before_second_message
|
||||||
Loading…
Add table
Add a link
Reference in a new issue