diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 5f3f48fd..2978d5bd 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -4,7 +4,7 @@ import asyncio import contextvars import warnings from asyncio import CancelledError, Future, Lock -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import Any, AsyncGenerator, Dict, Optional, Union, cast from .. import loggers from ..client.bot import Bot @@ -378,5 +378,5 @@ class Dispatcher(Router): # Allow to graceful shutdown pass - def current_state(self, user_id: int, chat_id: int) -> FSMContext: - return self.fsm.get_context(user_id=user_id, chat_id=chat_id) + def current_state(self, chat_id: int, user_id: int) -> FSMContext: + return cast(FSMContext, self.fsm.resolve_context(chat_id=chat_id, user_id=user_id)) diff --git a/aiogram/dispatcher/fsm/middleware.py b/aiogram/dispatcher/fsm/middleware.py index 0d9bfac0..d3d5d8c2 100644 --- a/aiogram/dispatcher/fsm/middleware.py +++ b/aiogram/dispatcher/fsm/middleware.py @@ -24,7 +24,7 @@ class FSMContextMiddleware(BaseMiddleware[Update]): event: Update, data: Dict[str, Any], ) -> Any: - context = self._resolve_context(data) + context = self.resolve_event_context(data) data["fsm_storage"] = self.storage if context: data.update({"state": context, "raw_state": await context.get_state()}) @@ -33,12 +33,16 @@ class FSMContextMiddleware(BaseMiddleware[Update]): return await handler(event, data) return await handler(event, data) - def _resolve_context(self, data: Dict[str, Any]) -> Optional[FSMContext]: + def resolve_event_context(self, data: Dict[str, Any]) -> Optional[FSMContext]: user = data.get("event_from_user") chat = data.get("event_chat") chat_id = chat.id if chat else None user_id = user.id if user else None + return self.resolve_context(chat_id=chat_id, user_id=user_id) + def resolve_context( + self, chat_id: Optional[int], user_id: Optional[int] + ) -> Optional[FSMContext]: if chat_id is None: chat_id = user_id diff --git a/tests/test_api/test_types/test_message_entity.py b/tests/test_api/test_types/test_message_entity.py index d5195d98..363824a5 100644 --- a/tests/test_api/test_types/test_message_entity.py +++ b/tests/test_api/test_types/test_message_entity.py @@ -1,5 +1,3 @@ -import pytest - from aiogram.types import MessageEntity from tests.deprecated import check_deprecated diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index da04527b..ecf44712 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -9,6 +9,7 @@ import pytest from aiogram import Bot from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler +from aiogram.dispatcher.fsm.strategy import FSMStrategy from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware from aiogram.dispatcher.router import Router from aiogram.methods import GetMe, GetUpdates, SendMessage @@ -78,8 +79,9 @@ class TestDispatcher: assert dp.parent_router is None @pytest.mark.asyncio - async def test_feed_update(self): - dp = Dispatcher() + @pytest.mark.parametrize("isolate_events", (True, False)) + async def test_feed_update(self, isolate_events): + dp = Dispatcher(isolate_events=isolate_events) bot = Bot("42:TEST") @dp.message() @@ -652,3 +654,20 @@ class TestDispatcher: log_records = [rec.message for rec in caplog.records] assert "Cause exception while process update" in log_records[0] + + @pytest.mark.parametrize( + "strategy,case,expected", + [ + [FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)], + [FSMStrategy.CHAT, (-42, 42), (-42, -42)], + [FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)], + [FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)], + [FSMStrategy.CHAT, (42, 42), (42, 42)], + [FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)], + ], + ) + def test_get_current_state_context(self, strategy, case, expected): + dp = Dispatcher(fsm_strategy=strategy) + chat_id, user_id = case + state = dp.current_state(chat_id=chat_id, user_id=user_id) + assert (state.chat_id, state.user_id) == expected diff --git a/tests/test_dispatcher/fsm/__init__.py b/tests/test_dispatcher/test_fsm/__init__.py similarity index 100% rename from tests/test_dispatcher/fsm/__init__.py rename to tests/test_dispatcher/test_fsm/__init__.py diff --git a/tests/test_dispatcher/fsm/storage/__init__.py b/tests/test_dispatcher/test_fsm/storage/__init__.py similarity index 100% rename from tests/test_dispatcher/fsm/storage/__init__.py rename to tests/test_dispatcher/test_fsm/storage/__init__.py diff --git a/tests/test_dispatcher/fsm/storage/test_base.py b/tests/test_dispatcher/test_fsm/storage/test_memory.py similarity index 73% rename from tests/test_dispatcher/fsm/storage/test_base.py rename to tests/test_dispatcher/test_fsm/storage/test_memory.py index 8b7129a4..2f587075 100644 --- a/tests/test_dispatcher/fsm/storage/test_base.py +++ b/tests/test_dispatcher/test_fsm/storage/test_memory.py @@ -32,3 +32,14 @@ class TestMemoryStorage: assert 42 in storage.storage[-42] assert isinstance(storage.storage[-42][42], MemoryStorageRecord) assert storage.storage[-42][42].data == {"foo": "bar"} + + @pytest.mark.asyncio + async def test_update_data(self, storage: MemoryStorage): + assert await storage.get_data(chat_id=-42, user_id=42) == {} + assert await storage.update_data(chat_id=-42, user_id=42, data={"foo": "bar"}) == { + "foo": "bar" + } + assert await storage.update_data(chat_id=-42, user_id=42, data={"baz": "spam"}) == { + "foo": "bar", + "baz": "spam", + } diff --git a/tests/test_dispatcher/test_fsm/test_context.py b/tests/test_dispatcher/test_fsm/test_context.py new file mode 100644 index 00000000..6c444c44 --- /dev/null +++ b/tests/test_dispatcher/test_fsm/test_context.py @@ -0,0 +1,49 @@ +import pytest + +from aiogram.dispatcher.fsm.context import FSMContext +from aiogram.dispatcher.fsm.storage.memory import MemoryStorage + + +@pytest.fixture() +def state(): + storage = MemoryStorage() + ctx = storage.storage[-42][42] + ctx.state = "test" + ctx.data = {"foo": "bar"} + return FSMContext(storage=storage, user_id=-42, chat_id=42) + + +class TestFSMContext: + @pytest.mark.asyncio + async def test_address_mapping(self): + storage = MemoryStorage() + ctx = storage.storage[-42][42] + ctx.state = "test" + ctx.data = {"foo": "bar"} + state = FSMContext(storage=storage, chat_id=-42, user_id=42) + state2 = FSMContext(storage=storage, chat_id=42, user_id=42) + state3 = FSMContext(storage=storage, chat_id=69, user_id=69) + + assert await state.get_state() == "test" + assert await state2.get_state() is None + assert await state3.get_state() is None + + assert await state.get_data() == {"foo": "bar"} + assert await state2.get_data() == {} + assert await state3.get_data() == {} + + await state2.set_state("experiments") + assert await state.get_state() == "test" + assert await state3.get_state() is None + + await state3.set_data({"key": "value"}) + assert await state2.get_data() == {} + + await state.update_data({"key": "value"}) + assert await state.get_data() == {"foo": "bar", "key": "value"} + + await state.clear() + assert await state.get_state() is None + assert await state.get_data() == {} + + assert await state2.get_state() == "experiments" diff --git a/tests/test_dispatcher/fsm/test_state.py b/tests/test_dispatcher/test_fsm/test_state.py similarity index 100% rename from tests/test_dispatcher/fsm/test_state.py rename to tests/test_dispatcher/test_fsm/test_state.py diff --git a/tests/test_dispatcher/test_fsm/test_strategy.py b/tests/test_dispatcher/test_fsm/test_strategy.py new file mode 100644 index 00000000..5a297679 --- /dev/null +++ b/tests/test_dispatcher/test_fsm/test_strategy.py @@ -0,0 +1,20 @@ +import pytest + +from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy + + +class TestStrategy: + @pytest.mark.parametrize( + "strategy,case,expected", + [ + [FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)], + [FSMStrategy.CHAT, (-42, 42), (-42, -42)], + [FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)], + [FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)], + [FSMStrategy.CHAT, (42, 42), (42, 42)], + [FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)], + ], + ) + def test_strategy(self, strategy, case, expected): + chat_id, user_id = case + assert apply_strategy(chat_id=chat_id, user_id=user_id, strategy=strategy) == expected diff --git a/tests/test_dispatcher/test_handler/test_chat_member.py b/tests/test_dispatcher/test_handler/test_chat_member.py new file mode 100644 index 00000000..baf1ee85 --- /dev/null +++ b/tests/test_dispatcher/test_handler/test_chat_member.py @@ -0,0 +1,32 @@ +import datetime +from typing import Any + +import pytest + +from aiogram.dispatcher.handler.chat_member import ChatMemberHandler +from aiogram.types import Chat, ChatMember, ChatMemberUpdated, User + + +class TestChatMemberUpdated: + @pytest.mark.asyncio + async def test_attributes_aliases(self): + event = ChatMemberUpdated( + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + date=datetime.datetime.now(), + old_chat_member=ChatMember( + user=User(id=42, is_bot=False, first_name="Test"), status="restricted" + ), + new_chat_member=ChatMember( + user=User(id=42, is_bot=False, first_name="Test"), status="restricted" + ), + ) + + class MyHandler(ChatMemberHandler): + async def handle(self) -> Any: + assert self.event == event + assert self.from_user == self.event.from_user + + return True + + assert await MyHandler(event) diff --git a/tests/test_dispatcher/test_router.py b/tests/test_dispatcher/test_router.py index 093715b1..c84239b1 100644 --- a/tests/test_dispatcher/test_router.py +++ b/tests/test_dispatcher/test_router.py @@ -1,26 +1,7 @@ -import datetime -from typing import Any - import pytest -from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler, skip -from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware +from aiogram.dispatcher.event.bases import SkipHandler, skip from aiogram.dispatcher.router import Router -from aiogram.types import ( - CallbackQuery, - Chat, - ChosenInlineResult, - InlineQuery, - Message, - Poll, - PollAnswer, - PollOption, - PreCheckoutQuery, - ShippingAddress, - ShippingQuery, - Update, - User, -) from aiogram.utils.warnings import CodeHasNoEffect importable_router = Router()