Rework FSM storage key

This commit is contained in:
Alex Root Junior 2021-10-11 01:30:19 +03:00
parent 8c4d4ef30a
commit 7c6cf3c122
10 changed files with 213 additions and 160 deletions

View file

@ -37,5 +37,5 @@ __all__ = (
"md", "md",
) )
__version__ = "3.0.0a17" __version__ = "3.0.0a18"
__api_version__ = "5.3" __api_version__ = "5.3"

View file

@ -1,44 +1,33 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
class FSMContext: class FSMContext:
def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None: def __init__(self, bot: Bot, storage: BaseStorage, key: StorageKey) -> None:
self.bot = bot self.bot = bot
self.storage = storage self.storage = storage
self.chat_id = chat_id self.key = key
self.user_id = user_id
async def set_state(self, state: StateType = None) -> None: async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state( await self.storage.set_state(bot=self.bot, key=self.key, state=state)
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
)
async def get_state(self) -> Optional[str]: async def get_state(self) -> Optional[str]:
return await self.storage.get_state( return await self.storage.get_state(bot=self.bot, key=self.key)
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def set_data(self, data: Dict[str, Any]) -> None: async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data( await self.storage.set_data(bot=self.bot, key=self.key, data=data)
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
)
async def get_data(self) -> Dict[str, Any]: async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data( return await self.storage.get_data(bot=self.bot, key=self.key)
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def update_data( async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if data: if data:
kwargs.update(data) kwargs.update(data)
return await self.storage.update_data( return await self.storage.update_data(bot=self.bot, key=self.key, data=kwargs)
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
)
async def clear(self) -> None: async def clear(self) -> None:
await self.set_state(state=None) await self.set_state(state=None)

View file

@ -2,7 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.fsm.context import FSMContext from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import BaseStorage from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StorageKey
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import TelegramObject from aiogram.types import TelegramObject
@ -31,21 +31,28 @@ class FSMContextMiddleware(BaseMiddleware):
if context: if context:
data.update({"state": context, "raw_state": await context.get_state()}) data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events: if self.isolate_events:
async with self.storage.lock( async with self.storage.lock(bot=bot, key=context.key):
bot=bot, chat_id=context.chat_id, user_id=context.user_id
):
return await handler(event, data) return await handler(event, data)
return await handler(event, data) return await handler(event, data)
def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]: def resolve_event_context(
self,
bot: Bot,
data: Dict[str, Any],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
user = data.get("event_from_user") user = data.get("event_from_user")
chat = data.get("event_chat") chat = data.get("event_chat")
chat_id = chat.id if chat else None chat_id = chat.id if chat else None
user_id = user.id if user else None user_id = user.id if user else None
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id) return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
def resolve_context( def resolve_context(
self, bot: Bot, chat_id: Optional[int], user_id: Optional[int] self,
bot: Bot,
chat_id: Optional[int],
user_id: Optional[int],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]: ) -> Optional[FSMContext]:
if chat_id is None: if chat_id is None:
chat_id = user_id chat_id = user_id
@ -54,8 +61,23 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id, user_id = apply_strategy( chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy chat_id=chat_id, user_id=user_id, strategy=self.strategy
) )
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id) return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return None return None
def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext: def get_context(
return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id) self,
bot: Bot,
chat_id: int,
user_id: int,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(
bot=bot,
storage=self.storage,
key=StorageKey(
user_id=user_id,
chat_id=chat_id,
bot_id=bot.id,
destiny=destiny,
),
)

View file

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, Union from typing import Any, AsyncGenerator, Dict, Optional, Union
from aiogram import Bot from aiogram import Bot
@ -7,45 +8,43 @@ from aiogram.dispatcher.fsm.state import State
StateType = Optional[Union[str, State]] StateType = Optional[Union[str, State]]
DEFAULT_DESTINY = "default"
@dataclass(frozen=True)
class StorageKey:
bot_id: int
chat_id: int
user_id: int
destiny: str = DEFAULT_DESTINY
class BaseStorage(ABC): class BaseStorage(ABC):
@abstractmethod @abstractmethod
@asynccontextmanager @asynccontextmanager
async def lock( async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
self, bot: Bot, chat_id: int, user_id: int
) -> AsyncGenerator[None, None]: # pragma: no cover
yield None yield None
@abstractmethod @abstractmethod
async def set_state( async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None: # pragma: no cover
pass pass
@abstractmethod @abstractmethod
async def get_state( async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
self, bot: Bot, chat_id: int, user_id: int
) -> Optional[str]: # pragma: no cover
pass pass
@abstractmethod @abstractmethod
async def set_data( async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> None: # pragma: no cover
pass pass
@abstractmethod @abstractmethod
async def get_data( async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
self, bot: Bot, chat_id: int, user_id: int
) -> Dict[str, Any]: # pragma: no cover
pass pass
async def update_data( async def update_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any] current_data = await self.get_data(bot=bot, key=key)
) -> Dict[str, Any]:
current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
current_data.update(data) current_data.update(data)
await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data) await self.set_data(bot=bot, key=key, data=current_data)
return current_data.copy() return current_data.copy()
@abstractmethod @abstractmethod

View file

@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.fsm.state import State from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
@dataclass @dataclass
@ -18,30 +18,26 @@ class MemoryStorageRecord:
class MemoryStorage(BaseStorage): class MemoryStorage(BaseStorage):
def __init__(self) -> None: def __init__(self) -> None:
self.storage: DefaultDict[ self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] MemoryStorageRecord
] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord))) )
async def close(self) -> None: async def close(self) -> None:
pass pass
@asynccontextmanager @asynccontextmanager
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]: async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
async with self.storage[bot][chat_id][user_id].lock: async with self.storage[key].lock:
yield None yield None
async def set_state( async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None self.storage[key].state = state.state if isinstance(state, State) else state
) -> None:
self.storage[bot][chat_id][user_id].state = (
state.state if isinstance(state, State) else state
)
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]: async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
return self.storage[bot][chat_id][user_id].state return self.storage[key].state
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None: async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
self.storage[bot][chat_id][user_id].data = data.copy() self.storage[key].data = data.copy()
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]: async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
return self.storage[bot][chat_id][user_id].data.copy() return self.storage[key].data.copy()

View file

@ -1,35 +1,67 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
from aioredis import ConnectionPool, Redis from aioredis import ConnectionPool, Redis
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.fsm.state import State from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
PrefixFactoryType = Callable[[Bot], str]
STATE_KEY = "state"
STATE_DATA_KEY = "data"
STATE_LOCK_KEY = "lock"
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60} DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
class KeyBuilder(ABC):
"""
Base class for Redis key builder
"""
@abstractmethod
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
pass
class DefaultKeyBuilder(KeyBuilder):
"""
Simple Redis key builder with default prefix.
Generates a colon-joined string with prefix, chat_id, user_id,
optional bot_id and optional destiny.
"""
def __init__(
self, prefix: str = "fsm", with_bot_id: bool = False, with_destiny: bool = False
) -> None:
self.prefix = prefix
self.with_bot_id = with_bot_id
self.with_destiny = with_destiny
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.extend([str(key.chat_id), str(key.user_id)])
if self.with_destiny:
parts.append(key.destiny)
parts.append(part)
return ":".join(parts)
class RedisStorage(BaseStorage): class RedisStorage(BaseStorage):
def __init__( def __init__(
self, self,
redis: Redis, redis: Redis,
prefix: str = "fsm", key_builder: Optional[KeyBuilder] = None,
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
state_ttl: Optional[int] = None, state_ttl: Optional[int] = None,
data_ttl: Optional[int] = None, data_ttl: Optional[int] = None,
lock_kwargs: Optional[Dict[str, Any]] = None, lock_kwargs: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
if key_builder is None:
key_builder = DefaultKeyBuilder()
if lock_kwargs is None: if lock_kwargs is None:
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
self.redis = redis self.redis = redis
self.prefix = prefix self.key_builder = key_builder
self.prefix_bot = prefix_bot
self.state_ttl = state_ttl self.state_ttl = state_ttl
self.data_ttl = data_ttl self.data_ttl = data_ttl
self.lock_kwargs = lock_kwargs self.lock_kwargs = lock_kwargs
@ -47,40 +79,28 @@ class RedisStorage(BaseStorage):
async def close(self) -> None: async def close(self) -> None:
await self.redis.close() # type: ignore await self.redis.close() # type: ignore
def generate_key(self, bot: Bot, *parts: Any) -> str:
prefix_parts = [self.prefix]
if self.prefix_bot:
if isinstance(self.prefix_bot, dict):
prefix_parts.append(self.prefix_bot[bot.id])
elif callable(self.prefix_bot):
prefix_parts.append(self.prefix_bot(bot))
else:
prefix_parts.append(str(bot.id))
prefix_parts.extend(parts)
return ":".join(map(str, prefix_parts))
@asynccontextmanager @asynccontextmanager
async def lock( async def lock(
self, bot: Bot, chat_id: int, user_id: int, state_lock_key: str = STATE_LOCK_KEY self,
bot: Bot,
key: StorageKey,
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
key = self.generate_key(bot, chat_id, user_id, state_lock_key) redis_key = self.key_builder.build(key, "lock")
async with self.redis.lock(name=key, **self.lock_kwargs): async with self.redis.lock(name=redis_key, **self.lock_kwargs):
yield None yield None
async def set_state( async def set_state(
self, self,
bot: Bot, bot: Bot,
chat_id: int, key: StorageKey,
user_id: int,
state: StateType = None, state: StateType = None,
state_key: str = STATE_KEY,
) -> None: ) -> None:
key = self.generate_key(bot, chat_id, user_id, state_key) redis_key = self.key_builder.build(key, "state")
if state is None: if state is None:
await self.redis.delete(key) await self.redis.delete(redis_key)
else: else:
await self.redis.set( await self.redis.set(
key, redis_key,
state.state if isinstance(state, State) else state, # type: ignore[arg-type] state.state if isinstance(state, State) else state, # type: ignore[arg-type]
ex=self.state_ttl, # type: ignore[arg-type] ex=self.state_ttl, # type: ignore[arg-type]
) )
@ -88,12 +108,10 @@ class RedisStorage(BaseStorage):
async def get_state( async def get_state(
self, self,
bot: Bot, bot: Bot,
chat_id: int, key: StorageKey,
user_id: int,
state_key: str = STATE_KEY,
) -> Optional[str]: ) -> Optional[str]:
key = self.generate_key(bot, chat_id, user_id, state_key) redis_key = self.key_builder.build(key, "state")
value = await self.redis.get(key) value = await self.redis.get(redis_key)
if isinstance(value, bytes): if isinstance(value, bytes):
return value.decode("utf-8") return value.decode("utf-8")
return cast(Optional[str], value) return cast(Optional[str], value)
@ -101,27 +119,26 @@ class RedisStorage(BaseStorage):
async def set_data( async def set_data(
self, self,
bot: Bot, bot: Bot,
chat_id: int, key: StorageKey,
user_id: int,
data: Dict[str, Any], data: Dict[str, Any],
state_data_key: str = STATE_DATA_KEY,
) -> None: ) -> None:
key = self.generate_key(bot, chat_id, user_id, state_data_key) redis_key = self.key_builder.build(key, "data")
if not data: if not data:
await self.redis.delete(key) await self.redis.delete(redis_key)
return return
json_data = bot.session.json_dumps(data) await self.redis.set(
await self.redis.set(key, json_data, ex=self.data_ttl) # type: ignore[arg-type] redis_key,
bot.session.json_dumps(data),
ex=self.data_ttl, # type: ignore[arg-type]
)
async def get_data( async def get_data(
self, self,
bot: Bot, bot: Bot,
chat_id: int, key: StorageKey,
user_id: int,
state_data_key: str = STATE_DATA_KEY,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
key = self.generate_key(bot, chat_id, user_id, state_data_key) redis_key = self.key_builder.build(key, "data")
value = await self.redis.get(key) value = await self.redis.get(redis_key)
if value is None: if value is None:
return {} return {}
if isinstance(value, bytes): if isinstance(value, bytes):

View file

@ -1,22 +1,33 @@
import pytest import pytest
from aiogram.dispatcher.fsm.storage.redis import RedisStorage from aiogram.dispatcher.fsm.storage.base import StorageKey
from tests.mocked_bot import MockedBot from aiogram.dispatcher.fsm.storage.redis import DefaultKeyBuilder
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
PREFIX = "test"
BOT_ID = 42
CHAT_ID = -1
USER_ID = 2
FIELD = "data"
DESTINY = "testing"
@pytest.mark.redis
class TestRedisStorage: class TestRedisDefaultKeyBuilder:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"prefix_bot,result", "with_bot_id,with_destiny,result",
[ [
[False, "fsm:-1:2"], [False, False, f"{PREFIX}:{CHAT_ID}:{USER_ID}:{FIELD}"],
[True, "fsm:42:-1:2"], [True, False, f"{PREFIX}:{BOT_ID}:{CHAT_ID}:{USER_ID}:{FIELD}"],
[{42: "kaboom"}, "fsm:kaboom:-1:2"], [True, True, f"{PREFIX}:{BOT_ID}:{CHAT_ID}:{USER_ID}:{DESTINY}:{FIELD}"],
[lambda bot: "kaboom", "fsm:kaboom:-1:2"], [False, True, f"{PREFIX}:{CHAT_ID}:{USER_ID}:{DESTINY}:{FIELD}"],
], ],
) )
async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result): async def test_generate_key(self, with_bot_id: bool, with_destiny: bool, result: str):
storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot) key_builder = DefaultKeyBuilder(
assert storage.generate_key(bot, -1, 2) == result prefix=PREFIX,
with_bot_id=with_bot_id,
with_destiny=with_destiny,
)
key = StorageKey(chat_id=CHAT_ID, user_id=USER_ID, bot_id=BOT_ID, destiny=DESTINY)
assert key_builder.build(key, FIELD) == result

View file

@ -1,46 +1,54 @@
import pytest import pytest
from aiogram.dispatcher.fsm.storage.base import BaseStorage from aiogram.dispatcher.fsm.storage.base import BaseStorage, StorageKey
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@pytest.fixture(name="storage_key")
def create_storate_key(bot: MockedBot):
return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"storage", "storage",
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")], [pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
) )
class TestStorages: class TestStorages:
async def test_lock(self, bot: MockedBot, storage: BaseStorage): async def test_lock(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
# TODO: ?!? # TODO: ?!?
async with storage.lock(bot=bot, chat_id=-42, user_id=42): async with storage.lock(bot=bot, key=storage_key):
assert True, "You are kidding me?" assert True, "You are kidding me?"
async def test_set_state(self, bot: MockedBot, storage: BaseStorage): async def test_set_state(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None assert await storage.get_state(bot=bot, key=storage_key) is None
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state") await storage.set_state(bot=bot, key=storage_key, state="state")
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state" assert await storage.get_state(bot=bot, key=storage_key) == "state"
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None) await storage.set_state(bot=bot, key=storage_key, state=None)
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None assert await storage.get_state(bot=bot, key=storage_key) is None
async def test_set_data(self, bot: MockedBot, storage: BaseStorage): async def test_set_data(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {} assert await storage.get_data(bot=bot, key=storage_key) == {}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}) await storage.set_data(bot=bot, key=storage_key, data={"foo": "bar"})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"} assert await storage.get_data(bot=bot, key=storage_key) == {"foo": "bar"}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={}) await storage.set_data(bot=bot, key=storage_key, data={})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {} assert await storage.get_data(bot=bot, key=storage_key) == {}
async def test_update_data(self, bot: MockedBot, storage: BaseStorage): async def test_update_data(
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {} self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey
assert await storage.update_data( ):
bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"} assert await storage.get_data(bot=bot, key=storage_key) == {}
) == {"foo": "bar"} assert await storage.update_data(bot=bot, key=storage_key, data={"foo": "bar"}) == {
assert await storage.update_data( "foo": "bar"
bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"} }
) == {"foo": "bar", "baz": "spam"} assert await storage.update_data(bot=bot, key=storage_key, data={"baz": "spam"}) == {
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == { "foo": "bar",
"baz": "spam",
}
assert await storage.get_data(bot=bot, key=storage_key) == {
"foo": "bar", "foo": "bar",
"baz": "spam", "baz": "spam",
} }

View file

@ -1,6 +1,7 @@
import pytest import pytest
from aiogram.dispatcher.fsm.context import FSMContext from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import StorageKey
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot
@ -10,21 +11,28 @@ pytestmark = pytest.mark.asyncio
@pytest.fixture() @pytest.fixture()
def state(bot: MockedBot): def state(bot: MockedBot):
storage = MemoryStorage() storage = MemoryStorage()
ctx = storage.storage[bot][-42][42] key = StorageKey(user_id=42, chat_id=-42, bot_id=bot.id)
ctx = storage.storage[key]
ctx.state = "test" ctx.state = "test"
ctx.data = {"foo": "bar"} ctx.data = {"foo": "bar"}
return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42) return FSMContext(bot=bot, storage=storage, key=key)
class TestFSMContext: class TestFSMContext:
async def test_address_mapping(self, bot: MockedBot): async def test_address_mapping(self, bot: MockedBot):
storage = MemoryStorage() storage = MemoryStorage()
ctx = storage.storage[bot][-42][42] ctx = storage.storage[StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)]
ctx.state = "test" ctx.state = "test"
ctx.data = {"foo": "bar"} ctx.data = {"foo": "bar"}
state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42) state = FSMContext(
state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42) bot=bot, storage=storage, key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69) )
state2 = FSMContext(
bot=bot, storage=storage, key=StorageKey(chat_id=42, user_id=42, bot_id=bot.id)
)
state3 = FSMContext(
bot=bot, storage=storage, key=StorageKey(chat_id=69, user_id=69, bot_id=bot.id)
)
assert await state.get_state() == "test" assert await state.get_state() == "test"
assert await state2.get_state() is None assert await state2.get_state() is None

View file

@ -4,6 +4,7 @@ import pytest
from aiogram import Dispatcher from aiogram import Dispatcher
from aiogram.dispatcher.fsm.context import FSMContext from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import StorageKey
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from aiogram.types import Update, User from aiogram.types import Update, User
from aiogram.utils.i18n import ConstI18nMiddleware, FSMI18nMiddleware, I18n, SimpleI18nMiddleware from aiogram.utils.i18n import ConstI18nMiddleware, FSMI18nMiddleware, I18n, SimpleI18nMiddleware
@ -131,7 +132,9 @@ class TestFSMI18nMiddleware:
async def test_middleware(self, i18n: I18n, bot: MockedBot): async def test_middleware(self, i18n: I18n, bot: MockedBot):
middleware = FSMI18nMiddleware(i18n=i18n) middleware = FSMI18nMiddleware(i18n=i18n)
storage = MemoryStorage() storage = MemoryStorage()
state = FSMContext(bot=bot, storage=storage, user_id=42, chat_id=42) state = FSMContext(
bot=bot, storage=storage, key=StorageKey(user_id=42, chat_id=42, bot_id=bot.id)
)
data = { data = {
"event_from_user": User(id=42, is_bot=False, language_code="it", first_name="Test"), "event_from_user": User(id=42, is_bot=False, language_code="it", first_name="Test"),
"state": state, "state": state,