Forum topic in FSM (#1161)

* Base implementation

* Added tests, fixed arguments priority

* Use `Optional[X]` instead of `X | None`

* Added changelog

* Added tests
This commit is contained in:
Alex Root Junior 2023-04-22 19:35:41 +03:00 committed by GitHub
parent 1538bc2e2d
commit 942ba0d520
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 164 additions and 60 deletions

17
CHANGES/1161.feature.rst Normal file
View file

@ -0,0 +1,17 @@
Added support for FSM in Forum topics.
The strategy can be changed in dispatcher:
.. code-block:: python
from aiogram.fsm.strategy import FSMStrategy
...
dispatcher = Dispatcher(
fsm_strategy=FSMStrategy.USER_IN_THREAD,
storage=..., # Any persistent storage
)
.. note::
If you have implemented you own storages you should extend record key generation
with new one attribute - `thread_id`

View file

@ -4,6 +4,10 @@ from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Chat, TelegramObject, Update, User
EVENT_FROM_USER_KEY = "event_from_user"
EVENT_CHAT_KEY = "event_chat"
EVENT_THREAD_ID_KEY = "event_thread_id"
class UserContextMiddleware(BaseMiddleware):
async def __call__(
@ -14,61 +18,64 @@ class UserContextMiddleware(BaseMiddleware):
) -> Any:
if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
chat, user = self.resolve_event_context(event=event)
with self.context(chat=chat, user=user):
if user is not None:
data["event_from_user"] = user
if chat is not None:
data["event_chat"] = chat
return await handler(event, data)
@contextmanager
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
chat_token = None
user_token = None
if chat:
chat_token = chat.set_current(chat)
if user:
user_token = user.set_current(user)
try:
yield
finally:
if chat and chat_token:
chat.reset_current(chat_token)
if user and user_token:
user.reset_current(user_token)
chat, user, thread_id = self.resolve_event_context(event=event)
if user is not None:
data[EVENT_FROM_USER_KEY] = user
if chat is not None:
data[EVENT_CHAT_KEY] = chat
if thread_id is not None:
data[EVENT_THREAD_ID_KEY] = thread_id
return await handler(event, data)
@classmethod
def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]:
def resolve_event_context(
cls, event: Update
) -> Tuple[Optional[Chat], Optional[User], Optional[int]]:
"""
Resolve chat and user instance from Update object
"""
if event.message:
return event.message.chat, event.message.from_user
return (
event.message.chat,
event.message.from_user,
event.message.message_thread_id if event.message.is_topic_message else None,
)
if event.edited_message:
return event.edited_message.chat, event.edited_message.from_user
return (
event.edited_message.chat,
event.edited_message.from_user,
event.edited_message.message_thread_id
if event.edited_message.is_topic_message
else None,
)
if event.channel_post:
return event.channel_post.chat, None
return event.channel_post.chat, None, None
if event.edited_channel_post:
return event.edited_channel_post.chat, None
return event.edited_channel_post.chat, None, None
if event.inline_query:
return None, event.inline_query.from_user
return None, event.inline_query.from_user, None
if event.chosen_inline_result:
return None, event.chosen_inline_result.from_user
return None, event.chosen_inline_result.from_user, None
if event.callback_query:
if event.callback_query.message:
return event.callback_query.message.chat, event.callback_query.from_user
return None, event.callback_query.from_user
return (
event.callback_query.message.chat,
event.callback_query.from_user,
event.callback_query.message.message_thread_id
if event.callback_query.message.is_topic_message
else None,
)
return None, event.callback_query.from_user, None
if event.shipping_query:
return None, event.shipping_query.from_user
return None, event.shipping_query.from_user, None
if event.pre_checkout_query:
return None, event.pre_checkout_query.from_user
return None, event.pre_checkout_query.from_user, None
if event.poll_answer:
return None, event.poll_answer.user
return None, event.poll_answer.user, None
if event.my_chat_member:
return event.my_chat_member.chat, event.my_chat_member.from_user
return event.my_chat_member.chat, event.my_chat_member.from_user, None
if event.chat_member:
return event.chat_member.chat, event.chat_member.from_user
return event.chat_member.chat, event.chat_member.from_user, None
if event.chat_join_request:
return event.chat_join_request.chat, event.chat_join_request.from_user
return None, None
return event.chat_join_request.chat, event.chat_join_request.from_user, None
return None, None, None

View file

@ -47,25 +47,42 @@ class FSMContextMiddleware(BaseMiddleware):
) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
thread_id = data.get("event_thread_id")
chat_id = chat.id if chat else None
user_id = user.id if user else None
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return self.resolve_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
destiny=destiny,
)
def resolve_context(
self,
bot: Bot,
chat_id: Optional[int],
user_id: Optional[int],
thread_id: Optional[int] = None,
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
if chat_id is not None and user_id is not None:
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
chat_id, user_id, thread_id = apply_strategy(
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
strategy=self.strategy,
)
return self.get_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
destiny=destiny,
)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return None
def get_context(
@ -73,6 +90,7 @@ class FSMContextMiddleware(BaseMiddleware):
bot: Bot,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(
@ -81,6 +99,7 @@ class FSMContextMiddleware(BaseMiddleware):
user_id=user_id,
chat_id=chat_id,
bot_id=bot.id,
thread_id=thread_id,
destiny=destiny,
),
)

View file

@ -15,6 +15,7 @@ class StorageKey:
bot_id: int
chat_id: int
user_id: int
thread_id: Optional[int] = None
destiny: str = DEFAULT_DESTINY

View file

@ -70,7 +70,10 @@ class DefaultKeyBuilder(KeyBuilder):
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.extend([str(key.chat_id), str(key.user_id)])
parts.append(str(key.chat_id))
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_destiny:
parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY:

View file

@ -1,16 +1,24 @@
from enum import Enum, auto
from typing import Tuple
from typing import Optional, Tuple
class FSMStrategy(Enum):
USER_IN_CHAT = auto()
CHAT = auto()
GLOBAL_USER = auto()
USER_IN_THREAD = auto()
def apply_strategy(chat_id: int, user_id: int, strategy: FSMStrategy) -> Tuple[int, int]:
def apply_strategy(
strategy: FSMStrategy,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
) -> Tuple[int, int, Optional[int]]:
if strategy == FSMStrategy.CHAT:
return chat_id, chat_id
return chat_id, chat_id, None
if strategy == FSMStrategy.GLOBAL_USER:
return user_id, user_id
return chat_id, user_id
return user_id, user_id, None
if strategy == FSMStrategy.USER_IN_THREAD:
return chat_id, user_id, thread_id
return chat_id, user_id, None

View file

@ -14,7 +14,7 @@ from aiogram import Bot
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
from aiogram.dispatcher.router import Router
from aiogram.methods import GetMe, GetUpdates, Request, SendMessage, TelegramMethod
from aiogram.methods import GetMe, GetUpdates, SendMessage, TelegramMethod
from aiogram.types import (
CallbackQuery,
Chat,
@ -462,9 +462,9 @@ class TestDispatcher:
async def my_handler(event: Any, **kwargs: Any):
assert event == getattr(update, event_type)
if has_chat:
assert Chat.get_current(False)
assert kwargs["event_chat"]
if has_user:
assert User.get_current(False)
assert kwargs["event_from_user"]
return kwargs
result = await router.feed_update(bot, update, test="PASS")

View file

@ -1,6 +1,9 @@
from unittest.mock import patch
import pytest
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.types import Update
async def next_handler(*args, **kwargs):
@ -11,3 +14,13 @@ class TestUserContextMiddleware:
async def test_unexpected_event_type(self):
with pytest.raises(RuntimeError):
await UserContextMiddleware()(next_handler, object(), {})
async def test_call(self):
middleware = UserContextMiddleware()
data = {}
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
await middleware(next_handler, Update(update_id=42), data)
assert data["event_chat"] == 1
assert data["event_from_user"] == 2
assert data["event_thread_id"] == 3

View file

@ -11,6 +11,7 @@ PREFIX = "test"
BOT_ID = 42
CHAT_ID = -1
USER_ID = 2
THREAD_ID = 3
FIELD = "data"
@ -46,6 +47,19 @@ class TestRedisDefaultKeyBuilder:
with pytest.raises(ValueError):
key_builder.build(key, FIELD)
def test_thread_id(self):
key_builder = DefaultKeyBuilder(
prefix=PREFIX,
)
key = StorageKey(
chat_id=CHAT_ID,
user_id=USER_ID,
bot_id=BOT_ID,
thread_id=THREAD_ID,
destiny=DEFAULT_DESTINY,
)
assert key_builder.build(key, FIELD) == f"{PREFIX}:{CHAT_ID}:{THREAD_ID}:{USER_ID}:{FIELD}"
def test_create_isolation(self):
fake_redis = object()
storage = RedisStorage(redis=fake_redis)

View file

@ -2,19 +2,41 @@ import pytest
from aiogram.fsm.strategy import FSMStrategy, apply_strategy
CHAT_ID = -42
USER_ID = 42
THREAD_ID = 1
PRIVATE = (USER_ID, USER_ID, None)
CHAT = (CHAT_ID, USER_ID, None)
THREAD = (CHAT_ID, USER_ID, THREAD_ID)
class TestStrategy:
@pytest.mark.parametrize(
"strategy,case,expected",
[
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
[FSMStrategy.CHAT, (-42, 42), (-42, -42)],
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
[FSMStrategy.CHAT, (42, 42), (42, 42)],
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
[FSMStrategy.USER_IN_CHAT, CHAT, CHAT],
[FSMStrategy.USER_IN_CHAT, PRIVATE, PRIVATE],
[FSMStrategy.USER_IN_CHAT, THREAD, CHAT],
[FSMStrategy.CHAT, CHAT, (CHAT_ID, CHAT_ID, None)],
[FSMStrategy.CHAT, PRIVATE, (USER_ID, USER_ID, None)],
[FSMStrategy.CHAT, THREAD, (CHAT_ID, CHAT_ID, None)],
[FSMStrategy.GLOBAL_USER, CHAT, PRIVATE],
[FSMStrategy.GLOBAL_USER, PRIVATE, PRIVATE],
[FSMStrategy.GLOBAL_USER, THREAD, PRIVATE],
[FSMStrategy.USER_IN_THREAD, CHAT, CHAT],
[FSMStrategy.USER_IN_THREAD, PRIVATE, PRIVATE],
[FSMStrategy.USER_IN_THREAD, THREAD, THREAD],
],
)
def test_strategy(self, strategy, case, expected):
chat_id, user_id = case
assert apply_strategy(chat_id=chat_id, user_id=user_id, strategy=strategy) == expected
chat_id, user_id, thread_id = case
assert (
apply_strategy(
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
strategy=strategy,
)
== expected
)