mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-06 16:15:51 +00:00
Forum topic in FSM (#1161)
* Base implementation * Added tests, fixed arguments priority * Use `Optional[X]` instead of `X | None` * Added changelog * Added tests
This commit is contained in:
parent
1538bc2e2d
commit
942ba0d520
10 changed files with 164 additions and 60 deletions
17
CHANGES/1161.feature.rst
Normal file
17
CHANGES/1161.feature.rst
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
Added support for FSM in Forum topics.
|
||||
|
||||
The strategy can be changed in dispatcher:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from aiogram.fsm.strategy import FSMStrategy
|
||||
...
|
||||
dispatcher = Dispatcher(
|
||||
fsm_strategy=FSMStrategy.USER_IN_THREAD,
|
||||
storage=..., # Any persistent storage
|
||||
)
|
||||
|
||||
.. note::
|
||||
|
||||
If you have implemented you own storages you should extend record key generation
|
||||
with new one attribute - `thread_id`
|
||||
|
|
@ -4,6 +4,10 @@ from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple
|
|||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.types import Chat, TelegramObject, Update, User
|
||||
|
||||
EVENT_FROM_USER_KEY = "event_from_user"
|
||||
EVENT_CHAT_KEY = "event_chat"
|
||||
EVENT_THREAD_ID_KEY = "event_thread_id"
|
||||
|
||||
|
||||
class UserContextMiddleware(BaseMiddleware):
|
||||
async def __call__(
|
||||
|
|
@ -14,61 +18,64 @@ class UserContextMiddleware(BaseMiddleware):
|
|||
) -> Any:
|
||||
if not isinstance(event, Update):
|
||||
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
|
||||
chat, user = self.resolve_event_context(event=event)
|
||||
with self.context(chat=chat, user=user):
|
||||
if user is not None:
|
||||
data["event_from_user"] = user
|
||||
if chat is not None:
|
||||
data["event_chat"] = chat
|
||||
return await handler(event, data)
|
||||
|
||||
@contextmanager
|
||||
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
|
||||
chat_token = None
|
||||
user_token = None
|
||||
if chat:
|
||||
chat_token = chat.set_current(chat)
|
||||
if user:
|
||||
user_token = user.set_current(user)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if chat and chat_token:
|
||||
chat.reset_current(chat_token)
|
||||
if user and user_token:
|
||||
user.reset_current(user_token)
|
||||
chat, user, thread_id = self.resolve_event_context(event=event)
|
||||
if user is not None:
|
||||
data[EVENT_FROM_USER_KEY] = user
|
||||
if chat is not None:
|
||||
data[EVENT_CHAT_KEY] = chat
|
||||
if thread_id is not None:
|
||||
data[EVENT_THREAD_ID_KEY] = thread_id
|
||||
return await handler(event, data)
|
||||
|
||||
@classmethod
|
||||
def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]:
|
||||
def resolve_event_context(
|
||||
cls, event: Update
|
||||
) -> Tuple[Optional[Chat], Optional[User], Optional[int]]:
|
||||
"""
|
||||
Resolve chat and user instance from Update object
|
||||
"""
|
||||
if event.message:
|
||||
return event.message.chat, event.message.from_user
|
||||
return (
|
||||
event.message.chat,
|
||||
event.message.from_user,
|
||||
event.message.message_thread_id if event.message.is_topic_message else None,
|
||||
)
|
||||
if event.edited_message:
|
||||
return event.edited_message.chat, event.edited_message.from_user
|
||||
return (
|
||||
event.edited_message.chat,
|
||||
event.edited_message.from_user,
|
||||
event.edited_message.message_thread_id
|
||||
if event.edited_message.is_topic_message
|
||||
else None,
|
||||
)
|
||||
if event.channel_post:
|
||||
return event.channel_post.chat, None
|
||||
return event.channel_post.chat, None, None
|
||||
if event.edited_channel_post:
|
||||
return event.edited_channel_post.chat, None
|
||||
return event.edited_channel_post.chat, None, None
|
||||
if event.inline_query:
|
||||
return None, event.inline_query.from_user
|
||||
return None, event.inline_query.from_user, None
|
||||
if event.chosen_inline_result:
|
||||
return None, event.chosen_inline_result.from_user
|
||||
return None, event.chosen_inline_result.from_user, None
|
||||
if event.callback_query:
|
||||
if event.callback_query.message:
|
||||
return event.callback_query.message.chat, event.callback_query.from_user
|
||||
return None, event.callback_query.from_user
|
||||
return (
|
||||
event.callback_query.message.chat,
|
||||
event.callback_query.from_user,
|
||||
event.callback_query.message.message_thread_id
|
||||
if event.callback_query.message.is_topic_message
|
||||
else None,
|
||||
)
|
||||
return None, event.callback_query.from_user, None
|
||||
if event.shipping_query:
|
||||
return None, event.shipping_query.from_user
|
||||
return None, event.shipping_query.from_user, None
|
||||
if event.pre_checkout_query:
|
||||
return None, event.pre_checkout_query.from_user
|
||||
return None, event.pre_checkout_query.from_user, None
|
||||
if event.poll_answer:
|
||||
return None, event.poll_answer.user
|
||||
return None, event.poll_answer.user, None
|
||||
if event.my_chat_member:
|
||||
return event.my_chat_member.chat, event.my_chat_member.from_user
|
||||
return event.my_chat_member.chat, event.my_chat_member.from_user, None
|
||||
if event.chat_member:
|
||||
return event.chat_member.chat, event.chat_member.from_user
|
||||
return event.chat_member.chat, event.chat_member.from_user, None
|
||||
if event.chat_join_request:
|
||||
return event.chat_join_request.chat, event.chat_join_request.from_user
|
||||
return None, None
|
||||
return event.chat_join_request.chat, event.chat_join_request.from_user, None
|
||||
return None, None, None
|
||||
|
|
|
|||
|
|
@ -47,25 +47,42 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
) -> Optional[FSMContext]:
|
||||
user = data.get("event_from_user")
|
||||
chat = data.get("event_chat")
|
||||
thread_id = data.get("event_thread_id")
|
||||
chat_id = chat.id if chat else None
|
||||
user_id = user.id if user else None
|
||||
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||
return self.resolve_context(
|
||||
bot=bot,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
destiny=destiny,
|
||||
)
|
||||
|
||||
def resolve_context(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: Optional[int],
|
||||
user_id: Optional[int],
|
||||
thread_id: Optional[int] = None,
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> Optional[FSMContext]:
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
|
||||
if chat_id is not None and user_id is not None:
|
||||
chat_id, user_id = apply_strategy(
|
||||
chat_id=chat_id, user_id=user_id, strategy=self.strategy
|
||||
chat_id, user_id, thread_id = apply_strategy(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
strategy=self.strategy,
|
||||
)
|
||||
return self.get_context(
|
||||
bot=bot,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
destiny=destiny,
|
||||
)
|
||||
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||
return None
|
||||
|
||||
def get_context(
|
||||
|
|
@ -73,6 +90,7 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
bot: Bot,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
thread_id: Optional[int] = None,
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> FSMContext:
|
||||
return FSMContext(
|
||||
|
|
@ -81,6 +99,7 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
bot_id=bot.id,
|
||||
thread_id=thread_id,
|
||||
destiny=destiny,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class StorageKey:
|
|||
bot_id: int
|
||||
chat_id: int
|
||||
user_id: int
|
||||
thread_id: Optional[int] = None
|
||||
destiny: str = DEFAULT_DESTINY
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,10 @@ class DefaultKeyBuilder(KeyBuilder):
|
|||
parts = [self.prefix]
|
||||
if self.with_bot_id:
|
||||
parts.append(str(key.bot_id))
|
||||
parts.extend([str(key.chat_id), str(key.user_id)])
|
||||
parts.append(str(key.chat_id))
|
||||
if key.thread_id:
|
||||
parts.append(str(key.thread_id))
|
||||
parts.append(str(key.user_id))
|
||||
if self.with_destiny:
|
||||
parts.append(key.destiny)
|
||||
elif key.destiny != DEFAULT_DESTINY:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,24 @@
|
|||
from enum import Enum, auto
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class FSMStrategy(Enum):
|
||||
USER_IN_CHAT = auto()
|
||||
CHAT = auto()
|
||||
GLOBAL_USER = auto()
|
||||
USER_IN_THREAD = auto()
|
||||
|
||||
|
||||
def apply_strategy(chat_id: int, user_id: int, strategy: FSMStrategy) -> Tuple[int, int]:
|
||||
def apply_strategy(
|
||||
strategy: FSMStrategy,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
thread_id: Optional[int] = None,
|
||||
) -> Tuple[int, int, Optional[int]]:
|
||||
if strategy == FSMStrategy.CHAT:
|
||||
return chat_id, chat_id
|
||||
return chat_id, chat_id, None
|
||||
if strategy == FSMStrategy.GLOBAL_USER:
|
||||
return user_id, user_id
|
||||
return chat_id, user_id
|
||||
return user_id, user_id, None
|
||||
if strategy == FSMStrategy.USER_IN_THREAD:
|
||||
return chat_id, user_id, thread_id
|
||||
return chat_id, user_id, None
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from aiogram import Bot
|
|||
from aiogram.dispatcher.dispatcher import Dispatcher
|
||||
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
|
||||
from aiogram.dispatcher.router import Router
|
||||
from aiogram.methods import GetMe, GetUpdates, Request, SendMessage, TelegramMethod
|
||||
from aiogram.methods import GetMe, GetUpdates, SendMessage, TelegramMethod
|
||||
from aiogram.types import (
|
||||
CallbackQuery,
|
||||
Chat,
|
||||
|
|
@ -462,9 +462,9 @@ class TestDispatcher:
|
|||
async def my_handler(event: Any, **kwargs: Any):
|
||||
assert event == getattr(update, event_type)
|
||||
if has_chat:
|
||||
assert Chat.get_current(False)
|
||||
assert kwargs["event_chat"]
|
||||
if has_user:
|
||||
assert User.get_current(False)
|
||||
assert kwargs["event_from_user"]
|
||||
return kwargs
|
||||
|
||||
result = await router.feed_update(bot, update, test="PASS")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
|
||||
from aiogram.types import Update
|
||||
|
||||
|
||||
async def next_handler(*args, **kwargs):
|
||||
|
|
@ -11,3 +14,13 @@ class TestUserContextMiddleware:
|
|||
async def test_unexpected_event_type(self):
|
||||
with pytest.raises(RuntimeError):
|
||||
await UserContextMiddleware()(next_handler, object(), {})
|
||||
|
||||
async def test_call(self):
|
||||
middleware = UserContextMiddleware()
|
||||
data = {}
|
||||
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
|
||||
await middleware(next_handler, Update(update_id=42), data)
|
||||
|
||||
assert data["event_chat"] == 1
|
||||
assert data["event_from_user"] == 2
|
||||
assert data["event_thread_id"] == 3
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ PREFIX = "test"
|
|||
BOT_ID = 42
|
||||
CHAT_ID = -1
|
||||
USER_ID = 2
|
||||
THREAD_ID = 3
|
||||
FIELD = "data"
|
||||
|
||||
|
||||
|
|
@ -46,6 +47,19 @@ class TestRedisDefaultKeyBuilder:
|
|||
with pytest.raises(ValueError):
|
||||
key_builder.build(key, FIELD)
|
||||
|
||||
def test_thread_id(self):
|
||||
key_builder = DefaultKeyBuilder(
|
||||
prefix=PREFIX,
|
||||
)
|
||||
key = StorageKey(
|
||||
chat_id=CHAT_ID,
|
||||
user_id=USER_ID,
|
||||
bot_id=BOT_ID,
|
||||
thread_id=THREAD_ID,
|
||||
destiny=DEFAULT_DESTINY,
|
||||
)
|
||||
assert key_builder.build(key, FIELD) == f"{PREFIX}:{CHAT_ID}:{THREAD_ID}:{USER_ID}:{FIELD}"
|
||||
|
||||
def test_create_isolation(self):
|
||||
fake_redis = object()
|
||||
storage = RedisStorage(redis=fake_redis)
|
||||
|
|
|
|||
|
|
@ -2,19 +2,41 @@ import pytest
|
|||
|
||||
from aiogram.fsm.strategy import FSMStrategy, apply_strategy
|
||||
|
||||
CHAT_ID = -42
|
||||
USER_ID = 42
|
||||
THREAD_ID = 1
|
||||
|
||||
PRIVATE = (USER_ID, USER_ID, None)
|
||||
CHAT = (CHAT_ID, USER_ID, None)
|
||||
THREAD = (CHAT_ID, USER_ID, THREAD_ID)
|
||||
|
||||
|
||||
class TestStrategy:
|
||||
@pytest.mark.parametrize(
|
||||
"strategy,case,expected",
|
||||
[
|
||||
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
|
||||
[FSMStrategy.CHAT, (-42, 42), (-42, -42)],
|
||||
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
|
||||
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
|
||||
[FSMStrategy.CHAT, (42, 42), (42, 42)],
|
||||
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
|
||||
[FSMStrategy.USER_IN_CHAT, CHAT, CHAT],
|
||||
[FSMStrategy.USER_IN_CHAT, PRIVATE, PRIVATE],
|
||||
[FSMStrategy.USER_IN_CHAT, THREAD, CHAT],
|
||||
[FSMStrategy.CHAT, CHAT, (CHAT_ID, CHAT_ID, None)],
|
||||
[FSMStrategy.CHAT, PRIVATE, (USER_ID, USER_ID, None)],
|
||||
[FSMStrategy.CHAT, THREAD, (CHAT_ID, CHAT_ID, None)],
|
||||
[FSMStrategy.GLOBAL_USER, CHAT, PRIVATE],
|
||||
[FSMStrategy.GLOBAL_USER, PRIVATE, PRIVATE],
|
||||
[FSMStrategy.GLOBAL_USER, THREAD, PRIVATE],
|
||||
[FSMStrategy.USER_IN_THREAD, CHAT, CHAT],
|
||||
[FSMStrategy.USER_IN_THREAD, PRIVATE, PRIVATE],
|
||||
[FSMStrategy.USER_IN_THREAD, THREAD, THREAD],
|
||||
],
|
||||
)
|
||||
def test_strategy(self, strategy, case, expected):
|
||||
chat_id, user_id = case
|
||||
assert apply_strategy(chat_id=chat_id, user_id=user_id, strategy=strategy) == expected
|
||||
chat_id, user_id, thread_id = case
|
||||
assert (
|
||||
apply_strategy(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
strategy=strategy,
|
||||
)
|
||||
== expected
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue