mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Add getting user from chat_boost (#1474)
* Add getting user from `chat_boost` * Update import * Add changelog * Add test for `resolve_event_context` * Lint changes * Parametrize test
This commit is contained in:
parent
e2e1bc5573
commit
1c323ecc97
3 changed files with 63 additions and 2 deletions
1
CHANGES/1474.feature.rst
Normal file
1
CHANGES/1474.feature.rst
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
Added getting user from `chat_boost` with source `ChatBoostSourcePremium` in `UserContextMiddleware` for `EventContext`
|
||||||
|
|
@ -2,7 +2,14 @@ from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User
|
from aiogram.types import (
|
||||||
|
Chat,
|
||||||
|
ChatBoostSourcePremium,
|
||||||
|
InaccessibleMessage,
|
||||||
|
TelegramObject,
|
||||||
|
Update,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
EVENT_CONTEXT_KEY = "event_context"
|
EVENT_CONTEXT_KEY = "event_context"
|
||||||
|
|
||||||
|
|
@ -125,6 +132,14 @@ class UserContextMiddleware(BaseMiddleware):
|
||||||
if event.message_reaction_count:
|
if event.message_reaction_count:
|
||||||
return EventContext(chat=event.message_reaction_count.chat)
|
return EventContext(chat=event.message_reaction_count.chat)
|
||||||
if event.chat_boost:
|
if event.chat_boost:
|
||||||
|
# We only check the premium source, because only it has a sender user,
|
||||||
|
# other sources have a user, but it is not the sender, but the recipient
|
||||||
|
if isinstance(event.chat_boost.boost.source, ChatBoostSourcePremium):
|
||||||
|
return EventContext(
|
||||||
|
chat=event.chat_boost.chat,
|
||||||
|
user=event.chat_boost.boost.source.user,
|
||||||
|
)
|
||||||
|
|
||||||
return EventContext(chat=event.chat_boost.chat)
|
return EventContext(chat=event.chat_boost.chat)
|
||||||
if event.removed_chat_boost:
|
if event.removed_chat_boost:
|
||||||
return EventContext(chat=event.removed_chat_boost.chat)
|
return EventContext(chat=event.removed_chat_boost.chat)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -6,7 +7,16 @@ from aiogram.dispatcher.middlewares.user_context import (
|
||||||
EventContext,
|
EventContext,
|
||||||
UserContextMiddleware,
|
UserContextMiddleware,
|
||||||
)
|
)
|
||||||
from aiogram.types import Chat, Update, User
|
from aiogram.types import (
|
||||||
|
Chat,
|
||||||
|
Update,
|
||||||
|
User,
|
||||||
|
ChatBoostUpdated,
|
||||||
|
ChatBoost,
|
||||||
|
ChatBoostSourcePremium,
|
||||||
|
ChatBoostSourceGiftCode,
|
||||||
|
ChatBoostSourceGiveaway,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def next_handler(*args, **kwargs):
|
async def next_handler(*args, **kwargs):
|
||||||
|
|
@ -41,3 +51,38 @@ class TestUserContextMiddleware:
|
||||||
assert data["event_chat"] is chat
|
assert data["event_chat"] is chat
|
||||||
assert data["event_from_user"] is user
|
assert data["event_from_user"] is user
|
||||||
assert data["event_thread_id"] == thread_id
|
assert data["event_thread_id"] == thread_id
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"source, expected_user",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
ChatBoostSourcePremium(user=User(id=2, first_name="Test", is_bot=False)),
|
||||||
|
User(id=2, first_name="Test", is_bot=False),
|
||||||
|
),
|
||||||
|
(ChatBoostSourceGiftCode(user=User(id=2, first_name="Test", is_bot=False)), None),
|
||||||
|
(
|
||||||
|
ChatBoostSourceGiveaway(
|
||||||
|
giveaway_message_id=1, user=User(id=2, first_name="Test", is_bot=False)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_resolve_event_context(self, source, expected_user):
|
||||||
|
middleware = UserContextMiddleware()
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
chat = Chat(id=1, type="private", title="Test")
|
||||||
|
add_date = datetime.now()
|
||||||
|
expiration_date = datetime.now()
|
||||||
|
|
||||||
|
boost = ChatBoost(
|
||||||
|
boost_id="Test", add_date=add_date, expiration_date=expiration_date, source=source
|
||||||
|
)
|
||||||
|
update = Update(update_id=42, chat_boost=ChatBoostUpdated(chat=chat, boost=boost))
|
||||||
|
|
||||||
|
await middleware(next_handler, update, data)
|
||||||
|
|
||||||
|
event_context = data["event_context"]
|
||||||
|
assert isinstance(event_context, EventContext)
|
||||||
|
assert event_context.user == expected_user
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue