Cover 100%

This commit is contained in:
Alex Root Junior 2021-05-13 22:04:10 +03:00
parent 03ccebd8be
commit 9cf189ffd2
12 changed files with 143 additions and 29 deletions

View file

@ -4,7 +4,7 @@ import asyncio
import contextvars import contextvars
import warnings import warnings
from asyncio import CancelledError, Future, Lock from asyncio import CancelledError, Future, Lock
from typing import Any, AsyncGenerator, Dict, Optional, Union from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
from .. import loggers from .. import loggers
from ..client.bot import Bot from ..client.bot import Bot
@ -378,5 +378,5 @@ class Dispatcher(Router):
# Allow to graceful shutdown # Allow to graceful shutdown
pass pass
def current_state(self, user_id: int, chat_id: int) -> FSMContext: def current_state(self, chat_id: int, user_id: int) -> FSMContext:
return self.fsm.get_context(user_id=user_id, chat_id=chat_id) return cast(FSMContext, self.fsm.resolve_context(chat_id=chat_id, user_id=user_id))

View file

@ -24,7 +24,7 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
event: Update, event: Update,
data: Dict[str, Any], data: Dict[str, Any],
) -> Any: ) -> Any:
context = self._resolve_context(data) context = self.resolve_event_context(data)
data["fsm_storage"] = self.storage data["fsm_storage"] = self.storage
if context: if context:
data.update({"state": context, "raw_state": await context.get_state()}) data.update({"state": context, "raw_state": await context.get_state()})
@ -33,12 +33,16 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
return await handler(event, data) return await handler(event, data)
return await handler(event, data) return await handler(event, data)
def _resolve_context(self, data: Dict[str, Any]) -> Optional[FSMContext]: def resolve_event_context(self, data: Dict[str, Any]) -> Optional[FSMContext]:
user = data.get("event_from_user") user = data.get("event_from_user")
chat = data.get("event_chat") chat = data.get("event_chat")
chat_id = chat.id if chat else None chat_id = chat.id if chat else None
user_id = user.id if user else None user_id = user.id if user else None
return self.resolve_context(chat_id=chat_id, user_id=user_id)
def resolve_context(
self, chat_id: Optional[int], user_id: Optional[int]
) -> Optional[FSMContext]:
if chat_id is None: if chat_id is None:
chat_id = user_id chat_id = user_id

View file

@ -1,5 +1,3 @@
import pytest
from aiogram.types import MessageEntity from aiogram.types import MessageEntity
from tests.deprecated import check_deprecated from tests.deprecated import check_deprecated

View file

@ -9,6 +9,7 @@ import pytest
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
from aiogram.dispatcher.fsm.strategy import FSMStrategy
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from aiogram.methods import GetMe, GetUpdates, SendMessage from aiogram.methods import GetMe, GetUpdates, SendMessage
@ -78,8 +79,9 @@ class TestDispatcher:
assert dp.parent_router is None assert dp.parent_router is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_feed_update(self): @pytest.mark.parametrize("isolate_events", (True, False))
dp = Dispatcher() async def test_feed_update(self, isolate_events):
dp = Dispatcher(isolate_events=isolate_events)
bot = Bot("42:TEST") bot = Bot("42:TEST")
@dp.message() @dp.message()
@ -652,3 +654,20 @@ class TestDispatcher:
log_records = [rec.message for rec in caplog.records] log_records = [rec.message for rec in caplog.records]
assert "Cause exception while process update" in log_records[0] assert "Cause exception while process update" in log_records[0]
@pytest.mark.parametrize(
"strategy,case,expected",
[
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
[FSMStrategy.CHAT, (-42, 42), (-42, -42)],
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
[FSMStrategy.CHAT, (42, 42), (42, 42)],
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
],
)
def test_get_current_state_context(self, strategy, case, expected):
dp = Dispatcher(fsm_strategy=strategy)
chat_id, user_id = case
state = dp.current_state(chat_id=chat_id, user_id=user_id)
assert (state.chat_id, state.user_id) == expected

View file

@ -32,3 +32,14 @@ class TestMemoryStorage:
assert 42 in storage.storage[-42] assert 42 in storage.storage[-42]
assert isinstance(storage.storage[-42][42], MemoryStorageRecord) assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
assert storage.storage[-42][42].data == {"foo": "bar"} assert storage.storage[-42][42].data == {"foo": "bar"}
@pytest.mark.asyncio
async def test_update_data(self, storage: MemoryStorage):
assert await storage.get_data(chat_id=-42, user_id=42) == {}
assert await storage.update_data(chat_id=-42, user_id=42, data={"foo": "bar"}) == {
"foo": "bar"
}
assert await storage.update_data(chat_id=-42, user_id=42, data={"baz": "spam"}) == {
"foo": "bar",
"baz": "spam",
}

View file

@ -0,0 +1,49 @@
import pytest
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
@pytest.fixture()
def state():
storage = MemoryStorage()
ctx = storage.storage[-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
return FSMContext(storage=storage, user_id=-42, chat_id=42)
class TestFSMContext:
@pytest.mark.asyncio
async def test_address_mapping(self):
storage = MemoryStorage()
ctx = storage.storage[-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
state = FSMContext(storage=storage, chat_id=-42, user_id=42)
state2 = FSMContext(storage=storage, chat_id=42, user_id=42)
state3 = FSMContext(storage=storage, chat_id=69, user_id=69)
assert await state.get_state() == "test"
assert await state2.get_state() is None
assert await state3.get_state() is None
assert await state.get_data() == {"foo": "bar"}
assert await state2.get_data() == {}
assert await state3.get_data() == {}
await state2.set_state("experiments")
assert await state.get_state() == "test"
assert await state3.get_state() is None
await state3.set_data({"key": "value"})
assert await state2.get_data() == {}
await state.update_data({"key": "value"})
assert await state.get_data() == {"foo": "bar", "key": "value"}
await state.clear()
assert await state.get_state() is None
assert await state.get_data() == {}
assert await state2.get_state() == "experiments"

View file

@ -0,0 +1,20 @@
import pytest
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
class TestStrategy:
@pytest.mark.parametrize(
"strategy,case,expected",
[
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
[FSMStrategy.CHAT, (-42, 42), (-42, -42)],
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
[FSMStrategy.CHAT, (42, 42), (42, 42)],
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
],
)
def test_strategy(self, strategy, case, expected):
chat_id, user_id = case
assert apply_strategy(chat_id=chat_id, user_id=user_id, strategy=strategy) == expected

View file

@ -0,0 +1,32 @@
import datetime
from typing import Any
import pytest
from aiogram.dispatcher.handler.chat_member import ChatMemberHandler
from aiogram.types import Chat, ChatMember, ChatMemberUpdated, User
class TestChatMemberUpdated:
@pytest.mark.asyncio
async def test_attributes_aliases(self):
event = ChatMemberUpdated(
chat=Chat(id=42, type="private"),
from_user=User(id=42, is_bot=False, first_name="Test"),
date=datetime.datetime.now(),
old_chat_member=ChatMember(
user=User(id=42, is_bot=False, first_name="Test"), status="restricted"
),
new_chat_member=ChatMember(
user=User(id=42, is_bot=False, first_name="Test"), status="restricted"
),
)
class MyHandler(ChatMemberHandler):
async def handle(self) -> Any:
assert self.event == event
assert self.from_user == self.event.from_user
return True
assert await MyHandler(event)

View file

@ -1,26 +1,7 @@
import datetime
from typing import Any
import pytest import pytest
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler, skip from aiogram.dispatcher.event.bases import SkipHandler, skip
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from aiogram.types import (
CallbackQuery,
Chat,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PollOption,
PreCheckoutQuery,
ShippingAddress,
ShippingQuery,
Update,
User,
)
from aiogram.utils.warnings import CodeHasNoEffect from aiogram.utils.warnings import CodeHasNoEffect
importable_router = Router() importable_router = Router()