Add getting user from chat_boost (#1474)

* Add getting user from `chat_boost`

* Update import

* Add changelog

* Add test for `resolve_event_context`

* Lint changes

* Parametrize test
This commit is contained in:
Desiders 2024-08-14 02:12:39 +03:00 committed by GitHub
parent e2e1bc5573
commit 1c323ecc97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 63 additions and 2 deletions

1
CHANGES/1474.feature.rst Normal file
View file

@ -0,0 +1 @@
Added getting user from `chat_boost` with source `ChatBoostSourcePremium` in `UserContextMiddleware` for `EventContext`

View file

@ -2,7 +2,14 @@ from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User
from aiogram.types import (
Chat,
ChatBoostSourcePremium,
InaccessibleMessage,
TelegramObject,
Update,
User,
)
EVENT_CONTEXT_KEY = "event_context"
@ -125,6 +132,14 @@ class UserContextMiddleware(BaseMiddleware):
if event.message_reaction_count:
return EventContext(chat=event.message_reaction_count.chat)
if event.chat_boost:
# We only check the premium source, because only it has a sender user,
# other sources have a user, but it is not the sender, but the recipient
if isinstance(event.chat_boost.boost.source, ChatBoostSourcePremium):
return EventContext(
chat=event.chat_boost.chat,
user=event.chat_boost.boost.source.user,
)
return EventContext(chat=event.chat_boost.chat)
if event.removed_chat_boost:
return EventContext(chat=event.removed_chat_boost.chat)

View file

@ -1,4 +1,5 @@
from unittest.mock import patch
from datetime import datetime
import pytest
@ -6,7 +7,16 @@ from aiogram.dispatcher.middlewares.user_context import (
EventContext,
UserContextMiddleware,
)
from aiogram.types import Chat, Update, User
from aiogram.types import (
Chat,
Update,
User,
ChatBoostUpdated,
ChatBoost,
ChatBoostSourcePremium,
ChatBoostSourceGiftCode,
ChatBoostSourceGiveaway,
)
async def next_handler(*args, **kwargs):
@ -41,3 +51,38 @@ class TestUserContextMiddleware:
assert data["event_chat"] is chat
assert data["event_from_user"] is user
assert data["event_thread_id"] == thread_id
@pytest.mark.parametrize(
"source, expected_user",
[
(
ChatBoostSourcePremium(user=User(id=2, first_name="Test", is_bot=False)),
User(id=2, first_name="Test", is_bot=False),
),
(ChatBoostSourceGiftCode(user=User(id=2, first_name="Test", is_bot=False)), None),
(
ChatBoostSourceGiveaway(
giveaway_message_id=1, user=User(id=2, first_name="Test", is_bot=False)
),
None,
),
],
)
async def test_resolve_event_context(self, source, expected_user):
middleware = UserContextMiddleware()
data = {}
chat = Chat(id=1, type="private", title="Test")
add_date = datetime.now()
expiration_date = datetime.now()
boost = ChatBoost(
boost_id="Test", add_date=add_date, expiration_date=expiration_date, source=source
)
update = Update(update_id=42, chat_boost=ChatBoostUpdated(chat=chat, boost=boost))
await middleware(next_handler, update, data)
event_context = data["event_context"]
assert isinstance(event_context, EventContext)
assert event_context.user == expected_user