Add getting user from chat_boost (#1474)

* Add getting user from `chat_boost`

* Update import

* Add changelog

* Add test for `resolve_event_context`

* Lint changes

* Parametrize test
This commit is contained in:
Desiders 2024-08-14 02:12:39 +03:00 committed by GitHub
parent e2e1bc5573
commit 1c323ecc97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 63 additions and 2 deletions

1
CHANGES/1474.feature.rst Normal file
View file

@ -0,0 +1 @@
Added getting user from `chat_boost` with source `ChatBoostSourcePremium` in `UserContextMiddleware` for `EventContext`

View file

@ -2,7 +2,14 @@ from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional from typing import Any, Awaitable, Callable, Dict, Optional
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User from aiogram.types import (
Chat,
ChatBoostSourcePremium,
InaccessibleMessage,
TelegramObject,
Update,
User,
)
EVENT_CONTEXT_KEY = "event_context" EVENT_CONTEXT_KEY = "event_context"
@ -125,6 +132,14 @@ class UserContextMiddleware(BaseMiddleware):
if event.message_reaction_count: if event.message_reaction_count:
return EventContext(chat=event.message_reaction_count.chat) return EventContext(chat=event.message_reaction_count.chat)
if event.chat_boost: if event.chat_boost:
# We only check the premium source, because only it has a sender user,
# other sources have a user, but it is not the sender, but the recipient
if isinstance(event.chat_boost.boost.source, ChatBoostSourcePremium):
return EventContext(
chat=event.chat_boost.chat,
user=event.chat_boost.boost.source.user,
)
return EventContext(chat=event.chat_boost.chat) return EventContext(chat=event.chat_boost.chat)
if event.removed_chat_boost: if event.removed_chat_boost:
return EventContext(chat=event.removed_chat_boost.chat) return EventContext(chat=event.removed_chat_boost.chat)

View file

@ -1,4 +1,5 @@
from unittest.mock import patch from unittest.mock import patch
from datetime import datetime
import pytest import pytest
@ -6,7 +7,16 @@ from aiogram.dispatcher.middlewares.user_context import (
EventContext, EventContext,
UserContextMiddleware, UserContextMiddleware,
) )
from aiogram.types import Chat, Update, User from aiogram.types import (
Chat,
Update,
User,
ChatBoostUpdated,
ChatBoost,
ChatBoostSourcePremium,
ChatBoostSourceGiftCode,
ChatBoostSourceGiveaway,
)
async def next_handler(*args, **kwargs): async def next_handler(*args, **kwargs):
@ -41,3 +51,38 @@ class TestUserContextMiddleware:
assert data["event_chat"] is chat assert data["event_chat"] is chat
assert data["event_from_user"] is user assert data["event_from_user"] is user
assert data["event_thread_id"] == thread_id assert data["event_thread_id"] == thread_id
@pytest.mark.parametrize(
"source, expected_user",
[
(
ChatBoostSourcePremium(user=User(id=2, first_name="Test", is_bot=False)),
User(id=2, first_name="Test", is_bot=False),
),
(ChatBoostSourceGiftCode(user=User(id=2, first_name="Test", is_bot=False)), None),
(
ChatBoostSourceGiveaway(
giveaway_message_id=1, user=User(id=2, first_name="Test", is_bot=False)
),
None,
),
],
)
async def test_resolve_event_context(self, source, expected_user):
middleware = UserContextMiddleware()
data = {}
chat = Chat(id=1, type="private", title="Test")
add_date = datetime.now()
expiration_date = datetime.now()
boost = ChatBoost(
boost_id="Test", add_date=add_date, expiration_date=expiration_date, source=source
)
update = Update(update_id=42, chat_boost=ChatBoostUpdated(chat=chat, boost=boost))
await middleware(next_handler, update, data)
event_context = data["event_context"]
assert isinstance(event_context, EventContext)
assert event_context.user == expected_user