Rework FSM storage key

This commit is contained in:
Alex Root Junior 2021-10-11 01:30:19 +03:00
parent 8c4d4ef30a
commit 7c6cf3c122
10 changed files with 213 additions and 160 deletions

View file

@ -37,5 +37,5 @@ __all__ = (
"md",
)
__version__ = "3.0.0a17"
__version__ = "3.0.0a18"
__api_version__ = "5.3"

View file

@ -1,44 +1,33 @@
from typing import Any, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
class FSMContext:
def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
def __init__(self, bot: Bot, storage: BaseStorage, key: StorageKey) -> None:
self.bot = bot
self.storage = storage
self.chat_id = chat_id
self.user_id = user_id
self.key = key
async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
)
await self.storage.set_state(bot=self.bot, key=self.key, state=state)
async def get_state(self) -> Optional[str]:
return await self.storage.get_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
return await self.storage.get_state(bot=self.bot, key=self.key)
async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
)
await self.storage.set_data(bot=self.bot, key=self.key, data=data)
async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
return await self.storage.get_data(bot=self.bot, key=self.key)
async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]:
if data:
kwargs.update(data)
return await self.storage.update_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
)
return await self.storage.update_data(bot=self.bot, key=self.key, data=kwargs)
async def clear(self) -> None:
await self.set_state(state=None)

View file

@ -2,7 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StorageKey
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import TelegramObject
@ -31,21 +31,28 @@ class FSMContextMiddleware(BaseMiddleware):
if context:
data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events:
async with self.storage.lock(
bot=bot, chat_id=context.chat_id, user_id=context.user_id
):
async with self.storage.lock(bot=bot, key=context.key):
return await handler(event, data)
return await handler(event, data)
def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
def resolve_event_context(
self,
bot: Bot,
data: Dict[str, Any],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
chat_id = chat.id if chat else None
user_id = user.id if user else None
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
def resolve_context(
self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
self,
bot: Bot,
chat_id: Optional[int],
user_id: Optional[int],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
@ -54,8 +61,23 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return None
def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)
def get_context(
self,
bot: Bot,
chat_id: int,
user_id: int,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(
bot=bot,
storage=self.storage,
key=StorageKey(
user_id=user_id,
chat_id=chat_id,
bot_id=bot.id,
destiny=destiny,
),
)

View file

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, Union
from aiogram import Bot
@ -7,45 +8,43 @@ from aiogram.dispatcher.fsm.state import State
StateType = Optional[Union[str, State]]
DEFAULT_DESTINY = "default"
@dataclass(frozen=True)
class StorageKey:
bot_id: int
chat_id: int
user_id: int
destiny: str = DEFAULT_DESTINY
class BaseStorage(ABC):
@abstractmethod
@asynccontextmanager
async def lock(
self, bot: Bot, chat_id: int, user_id: int
) -> AsyncGenerator[None, None]: # pragma: no cover
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
yield None
@abstractmethod
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None: # pragma: no cover
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
pass
@abstractmethod
async def get_state(
self, bot: Bot, chat_id: int, user_id: int
) -> Optional[str]: # pragma: no cover
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
pass
@abstractmethod
async def set_data(
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> None: # pragma: no cover
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
pass
@abstractmethod
async def get_data(
self, bot: Bot, chat_id: int, user_id: int
) -> Dict[str, Any]: # pragma: no cover
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
pass
async def update_data(
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> Dict[str, Any]:
current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
async def update_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
current_data = await self.get_data(bot=bot, key=key)
current_data.update(data)
await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
await self.set_data(bot=bot, key=key, data=current_data)
return current_data.copy()
@abstractmethod

View file

@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
@dataclass
@ -18,30 +18,26 @@ class MemoryStorageRecord:
class MemoryStorage(BaseStorage):
def __init__(self) -> None:
self.storage: DefaultDict[
Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
MemoryStorageRecord
)
async def close(self) -> None:
pass
@asynccontextmanager
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
async with self.storage[bot][chat_id][user_id].lock:
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
async with self.storage[key].lock:
yield None
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
self.storage[bot][chat_id][user_id].state = (
state.state if isinstance(state, State) else state
)
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
self.storage[key].state = state.state if isinstance(state, State) else state
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
return self.storage[bot][chat_id][user_id].state
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
return self.storage[key].state
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
self.storage[bot][chat_id][user_id].data = data.copy()
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
self.storage[key].data = data.copy()
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
return self.storage[bot][chat_id][user_id].data.copy()
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
return self.storage[key].data.copy()

View file

@ -1,35 +1,67 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
from aioredis import ConnectionPool, Redis
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
PrefixFactoryType = Callable[[Bot], str]
STATE_KEY = "state"
STATE_DATA_KEY = "data"
STATE_LOCK_KEY = "lock"
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
class KeyBuilder(ABC):
"""
Base class for Redis key builder
"""
@abstractmethod
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
pass
class DefaultKeyBuilder(KeyBuilder):
"""
Simple Redis key builder with default prefix.
Generates a colon-joined string with prefix, chat_id, user_id,
optional bot_id and optional destiny.
"""
def __init__(
self, prefix: str = "fsm", with_bot_id: bool = False, with_destiny: bool = False
) -> None:
self.prefix = prefix
self.with_bot_id = with_bot_id
self.with_destiny = with_destiny
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.extend([str(key.chat_id), str(key.user_id)])
if self.with_destiny:
parts.append(key.destiny)
parts.append(part)
return ":".join(parts)
class RedisStorage(BaseStorage):
def __init__(
self,
redis: Redis,
prefix: str = "fsm",
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
key_builder: Optional[KeyBuilder] = None,
state_ttl: Optional[int] = None,
data_ttl: Optional[int] = None,
lock_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if key_builder is None:
key_builder = DefaultKeyBuilder()
if lock_kwargs is None:
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
self.redis = redis
self.prefix = prefix
self.prefix_bot = prefix_bot
self.key_builder = key_builder
self.state_ttl = state_ttl
self.data_ttl = data_ttl
self.lock_kwargs = lock_kwargs
@ -47,40 +79,28 @@ class RedisStorage(BaseStorage):
async def close(self) -> None:
await self.redis.close() # type: ignore
def generate_key(self, bot: Bot, *parts: Any) -> str:
prefix_parts = [self.prefix]
if self.prefix_bot:
if isinstance(self.prefix_bot, dict):
prefix_parts.append(self.prefix_bot[bot.id])
elif callable(self.prefix_bot):
prefix_parts.append(self.prefix_bot(bot))
else:
prefix_parts.append(str(bot.id))
prefix_parts.extend(parts)
return ":".join(map(str, prefix_parts))
@asynccontextmanager
async def lock(
self, bot: Bot, chat_id: int, user_id: int, state_lock_key: str = STATE_LOCK_KEY
self,
bot: Bot,
key: StorageKey,
) -> AsyncGenerator[None, None]:
key = self.generate_key(bot, chat_id, user_id, state_lock_key)
async with self.redis.lock(name=key, **self.lock_kwargs):
redis_key = self.key_builder.build(key, "lock")
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
yield None
async def set_state(
self,
bot: Bot,
chat_id: int,
user_id: int,
key: StorageKey,
state: StateType = None,
state_key: str = STATE_KEY,
) -> None:
key = self.generate_key(bot, chat_id, user_id, state_key)
redis_key = self.key_builder.build(key, "state")
if state is None:
await self.redis.delete(key)
await self.redis.delete(redis_key)
else:
await self.redis.set(
key,
redis_key,
state.state if isinstance(state, State) else state, # type: ignore[arg-type]
ex=self.state_ttl, # type: ignore[arg-type]
)
@ -88,12 +108,10 @@ class RedisStorage(BaseStorage):
async def get_state(
self,
bot: Bot,
chat_id: int,
user_id: int,
state_key: str = STATE_KEY,
key: StorageKey,
) -> Optional[str]:
key = self.generate_key(bot, chat_id, user_id, state_key)
value = await self.redis.get(key)
redis_key = self.key_builder.build(key, "state")
value = await self.redis.get(redis_key)
if isinstance(value, bytes):
return value.decode("utf-8")
return cast(Optional[str], value)
@ -101,27 +119,26 @@ class RedisStorage(BaseStorage):
async def set_data(
self,
bot: Bot,
chat_id: int,
user_id: int,
key: StorageKey,
data: Dict[str, Any],
state_data_key: str = STATE_DATA_KEY,
) -> None:
key = self.generate_key(bot, chat_id, user_id, state_data_key)
redis_key = self.key_builder.build(key, "data")
if not data:
await self.redis.delete(key)
await self.redis.delete(redis_key)
return
json_data = bot.session.json_dumps(data)
await self.redis.set(key, json_data, ex=self.data_ttl) # type: ignore[arg-type]
await self.redis.set(
redis_key,
bot.session.json_dumps(data),
ex=self.data_ttl, # type: ignore[arg-type]
)
async def get_data(
self,
bot: Bot,
chat_id: int,
user_id: int,
state_data_key: str = STATE_DATA_KEY,
key: StorageKey,
) -> Dict[str, Any]:
key = self.generate_key(bot, chat_id, user_id, state_data_key)
value = await self.redis.get(key)
redis_key = self.key_builder.build(key, "data")
value = await self.redis.get(redis_key)
if value is None:
return {}
if isinstance(value, bytes):

View file

@ -1,22 +1,33 @@
import pytest
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot
from aiogram.dispatcher.fsm.storage.base import StorageKey
from aiogram.dispatcher.fsm.storage.redis import DefaultKeyBuilder
pytestmark = pytest.mark.asyncio
PREFIX = "test"
BOT_ID = 42
CHAT_ID = -1
USER_ID = 2
FIELD = "data"
DESTINY = "testing"
@pytest.mark.redis
class TestRedisStorage:
class TestRedisDefaultKeyBuilder:
@pytest.mark.parametrize(
"prefix_bot,result",
"with_bot_id,with_destiny,result",
[
[False, "fsm:-1:2"],
[True, "fsm:42:-1:2"],
[{42: "kaboom"}, "fsm:kaboom:-1:2"],
[lambda bot: "kaboom", "fsm:kaboom:-1:2"],
[False, False, f"{PREFIX}:{CHAT_ID}:{USER_ID}:{FIELD}"],
[True, False, f"{PREFIX}:{BOT_ID}:{CHAT_ID}:{USER_ID}:{FIELD}"],
[True, True, f"{PREFIX}:{BOT_ID}:{CHAT_ID}:{USER_ID}:{DESTINY}:{FIELD}"],
[False, True, f"{PREFIX}:{CHAT_ID}:{USER_ID}:{DESTINY}:{FIELD}"],
],
)
async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result):
storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot)
assert storage.generate_key(bot, -1, 2) == result
async def test_generate_key(self, with_bot_id: bool, with_destiny: bool, result: str):
key_builder = DefaultKeyBuilder(
prefix=PREFIX,
with_bot_id=with_bot_id,
with_destiny=with_destiny,
)
key = StorageKey(chat_id=CHAT_ID, user_id=USER_ID, bot_id=BOT_ID, destiny=DESTINY)
assert key_builder.build(key, FIELD) == result

View file

@ -1,46 +1,54 @@
import pytest
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StorageKey
from tests.mocked_bot import MockedBot
pytestmark = pytest.mark.asyncio
@pytest.fixture(name="storage_key")
def create_storate_key(bot: MockedBot):
return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
@pytest.mark.parametrize(
"storage",
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
)
class TestStorages:
async def test_lock(self, bot: MockedBot, storage: BaseStorage):
async def test_lock(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
# TODO: ?!?
async with storage.lock(bot=bot, chat_id=-42, user_id=42):
async with storage.lock(bot=bot, key=storage_key):
assert True, "You are kidding me?"
async def test_set_state(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
async def test_set_state(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
assert await storage.get_state(bot=bot, key=storage_key) is None
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state")
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state"
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None)
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
await storage.set_state(bot=bot, key=storage_key, state="state")
assert await storage.get_state(bot=bot, key=storage_key) == "state"
await storage.set_state(bot=bot, key=storage_key, state=None)
assert await storage.get_state(bot=bot, key=storage_key) is None
async def test_set_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
async def test_set_data(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
assert await storage.get_data(bot=bot, key=storage_key) == {}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
await storage.set_data(bot=bot, key=storage_key, data={"foo": "bar"})
assert await storage.get_data(bot=bot, key=storage_key) == {"foo": "bar"}
await storage.set_data(bot=bot, key=storage_key, data={})
assert await storage.get_data(bot=bot, key=storage_key) == {}
async def test_update_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}
) == {"foo": "bar"}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"}
) == {"foo": "bar", "baz": "spam"}
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {
async def test_update_data(
self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey
):
assert await storage.get_data(bot=bot, key=storage_key) == {}
assert await storage.update_data(bot=bot, key=storage_key, data={"foo": "bar"}) == {
"foo": "bar"
}
assert await storage.update_data(bot=bot, key=storage_key, data={"baz": "spam"}) == {
"foo": "bar",
"baz": "spam",
}
assert await storage.get_data(bot=bot, key=storage_key) == {
"foo": "bar",
"baz": "spam",
}

View file

@ -1,6 +1,7 @@
import pytest
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import StorageKey
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from tests.mocked_bot import MockedBot
@ -10,21 +11,28 @@ pytestmark = pytest.mark.asyncio
@pytest.fixture()
def state(bot: MockedBot):
storage = MemoryStorage()
ctx = storage.storage[bot][-42][42]
key = StorageKey(user_id=42, chat_id=-42, bot_id=bot.id)
ctx = storage.storage[key]
ctx.state = "test"
ctx.data = {"foo": "bar"}
return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42)
return FSMContext(bot=bot, storage=storage, key=key)
class TestFSMContext:
async def test_address_mapping(self, bot: MockedBot):
storage = MemoryStorage()
ctx = storage.storage[bot][-42][42]
ctx = storage.storage[StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)]
ctx.state = "test"
ctx.data = {"foo": "bar"}
state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42)
state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42)
state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69)
state = FSMContext(
bot=bot, storage=storage, key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
)
state2 = FSMContext(
bot=bot, storage=storage, key=StorageKey(chat_id=42, user_id=42, bot_id=bot.id)
)
state3 = FSMContext(
bot=bot, storage=storage, key=StorageKey(chat_id=69, user_id=69, bot_id=bot.id)
)
assert await state.get_state() == "test"
assert await state2.get_state() is None

View file

@ -4,6 +4,7 @@ import pytest
from aiogram import Dispatcher
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import StorageKey
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from aiogram.types import Update, User
from aiogram.utils.i18n import ConstI18nMiddleware, FSMI18nMiddleware, I18n, SimpleI18nMiddleware
@ -131,7 +132,9 @@ class TestFSMI18nMiddleware:
async def test_middleware(self, i18n: I18n, bot: MockedBot):
middleware = FSMI18nMiddleware(i18n=i18n)
storage = MemoryStorage()
state = FSMContext(bot=bot, storage=storage, user_id=42, chat_id=42)
state = FSMContext(
bot=bot, storage=storage, key=StorageKey(user_id=42, chat_id=42, bot_id=bot.id)
)
data = {
"event_from_user": User(id=42, is_bot=False, language_code="it", first_name="Test"),
"state": state,