feat: add MediaGroupFilter

This commit is contained in:
Vitaly312 2026-02-28 16:43:54 +03:00
parent 998d6c3742
commit a8bd68eb35
3 changed files with 137 additions and 5 deletions

View file

@ -0,0 +1,60 @@
from typing import Any, Literal
from aiogram.filters.base import Filter
from aiogram.types import Message
MIN_MEDIA_COUNT = 2
DEFAULT_MAX_MEDIA_COUNT = 10
class MediaGroupFilter(Filter):
"""
This filter helps to handle media groups.
Works only with :class:`aiogram.types.message.Message` events which have the :code:`album`
in the handler context.
"""
__slots__ = ("min_media_count", "max_media_count")
def __init__(
self,
count: int | None = None,
min_media_count: int | None = None,
max_media_count: int | None = None,
):
"""
:param count: expected count of media in the group.
:param min_media_count: min count of media in the group, inclusively
:param max_media_count: max count of media in the group, inclusively
"""
if count is None:
min_media_count = min_media_count or MIN_MEDIA_COUNT
max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT
else:
if min_media_count is not None or max_media_count is not None:
raise ValueError(
"count and min_media_count or max_media_count can not be used together"
)
if count < MIN_MEDIA_COUNT:
raise ValueError(f"count should be greater or equal to {MIN_MEDIA_COUNT}")
min_media_count = max_media_count = count
if min_media_count < MIN_MEDIA_COUNT:
raise ValueError(f"min_media_count should be greater or equal to {MIN_MEDIA_COUNT}")
if max_media_count < min_media_count:
raise ValueError("max_media_count should be greater or equal to min_media_count")
self.min_media_count = min_media_count
self.max_media_count = max_media_count
def __str__(self) -> str:
return self._signature_to_string(
min_media_count=self.min_media_count, max_media_count=self.max_media_count
)
async def __call__(
self, message: Message, album: list[Message] = None
) -> Literal[False] | dict[str, Any]:
media_count = len(album or [])
if not (self.min_media_count <= media_count <= self.max_media_count):
return False
return {"media_count": media_count}

View file

@ -6,20 +6,22 @@ from datetime import datetime
from typing import Any
import pytest
class TestMediaGroupAggregatorMiddleware:
def _get_message(self, message_id: int, **kwargs):
chat = Chat(id=1, type="private", title="Test")
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
def get_middleware(self):
return MediaGroupAggregatorMiddleware(delay=0.1)
async def test_skip_non_media_group(self):
is_called = False
async def next_handler(*args, **kwargs):
nonlocal is_called
is_called = True
await self.get_middleware()(next_handler, self._get_message(1), {})
assert is_called
@ -27,13 +29,15 @@ class TestMediaGroupAggregatorMiddleware:
middleware = self.get_middleware()
counter = 0
album = None
async def next_handler(_, data: dict[str, Any]):
nonlocal counter, album
counter += 1
album = data.get("album")
await asyncio.gather(
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
middleware(next_handler, self._get_message(2, media_group_id="42"), {})
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
)
assert album is not None
assert len(album) == 2
@ -42,12 +46,14 @@ class TestMediaGroupAggregatorMiddleware:
async def test_propagate_first_media_in_album(self):
middleware = self.get_middleware()
first_message = None
async def next_handler(message: Message, _):
nonlocal first_message
first_message = message
await asyncio.gather(
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
middleware(next_handler, self._get_message(1, media_group_id="42"), {})
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
)
assert isinstance(first_message, Message)
assert first_message.message_id == 1
@ -56,13 +62,15 @@ class TestMediaGroupAggregatorMiddleware:
middleware = self.get_middleware()
counter = 0
albums = []
async def next_handler(_, data: dict[str, Any]):
nonlocal counter, albums
counter += 1
albums.append(data.get("album"))
await asyncio.gather(
middleware(next_handler, self._get_message(1, media_group_id="1"), {}),
middleware(next_handler, self._get_message(2, media_group_id="2"), {})
middleware(next_handler, self._get_message(2, media_group_id="2"), {}),
)
assert counter == 2
assert len(albums) == 2
@ -70,17 +78,20 @@ class TestMediaGroupAggregatorMiddleware:
async def test_retry_handling(self):
middleware = self.get_middleware()
album = None
async def failed_handler(*args, **kwargs):
raise Exception("Failed")
async def working_handler(_, data: dict[str, Any]):
nonlocal album
album = data.get("album")
first_message = self._get_message(1, media_group_id="42")
second_message = self._get_message(2, media_group_id="42")
with pytest.raises(Exception):
await asyncio.gather(
middleware(failed_handler, first_message, {}),
middleware(failed_handler, second_message, {})
middleware(failed_handler, second_message, {}),
)
await middleware(working_handler, first_message, {})
assert len(album) == 2

View file

@ -0,0 +1,61 @@
from aiogram.filters.media_group import MediaGroupFilter, MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT
import pytest
import datetime
from aiogram.types import Message, Chat
class TestMediaGroupFilter:
@pytest.mark.parametrize(
"args,min_count,max_count",
[
((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT),
((3,), 3, 3),
((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT),
((None, None, 3), MIN_MEDIA_COUNT, 3),
],
)
def test_init_range(self, args, min_count, max_count):
filter = MediaGroupFilter(*args)
assert filter.max_media_count == max_count
assert filter.min_media_count == min_count
@pytest.mark.parametrize(
"count,min_count,max_count",
[
(1, None, 1),
(1, 1, None),
(None, 1, None),
(None, None, 1),
(1, None, None),
(None, 5, 3),
],
)
def test_raise_error(self, count, min_count, max_count):
with pytest.raises(ValueError):
MediaGroupFilter(count, min_count, max_count)
@pytest.mark.parametrize(
"min_count,max_count,media_count,result",
[
[2, 2, 1, False],
[2, 2, 2, True],
[2, 2, 3, False],
[2, 5, 2, True],
[2, 5, 5, True],
[2, 5, 6, False],
],
)
async def test_call(self, min_count, max_count, media_count, result):
filter = MediaGroupFilter(min_media_count=min_count, max_media_count=max_count)
album = [
Message(
message_id=i,
date=datetime.datetime.now(),
chat=Chat(id=42, type="private"),
)
for i in range(media_count)
]
response = await filter(album[0], album)
assert bool(response) is result
if result:
assert response.get("media_count") == media_count