mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 02:03:04 +00:00
Rework FSM storage key
This commit is contained in:
parent
8c4d4ef30a
commit
7c6cf3c122
10 changed files with 213 additions and 160 deletions
|
|
@ -37,5 +37,5 @@ __all__ = (
|
||||||
"md",
|
"md",
|
||||||
)
|
)
|
||||||
|
|
||||||
__version__ = "3.0.0a17"
|
__version__ = "3.0.0a18"
|
||||||
__api_version__ = "5.3"
|
__api_version__ = "5.3"
|
||||||
|
|
|
||||||
|
|
@ -1,44 +1,33 @@
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
|
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||||
|
|
||||||
|
|
||||||
class FSMContext:
|
class FSMContext:
|
||||||
def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
|
def __init__(self, bot: Bot, storage: BaseStorage, key: StorageKey) -> None:
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.chat_id = chat_id
|
self.key = key
|
||||||
self.user_id = user_id
|
|
||||||
|
|
||||||
async def set_state(self, state: StateType = None) -> None:
|
async def set_state(self, state: StateType = None) -> None:
|
||||||
await self.storage.set_state(
|
await self.storage.set_state(bot=self.bot, key=self.key, state=state)
|
||||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_state(self) -> Optional[str]:
|
async def get_state(self) -> Optional[str]:
|
||||||
return await self.storage.get_state(
|
return await self.storage.get_state(bot=self.bot, key=self.key)
|
||||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
async def set_data(self, data: Dict[str, Any]) -> None:
|
async def set_data(self, data: Dict[str, Any]) -> None:
|
||||||
await self.storage.set_data(
|
await self.storage.set_data(bot=self.bot, key=self.key, data=data)
|
||||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_data(self) -> Dict[str, Any]:
|
async def get_data(self) -> Dict[str, Any]:
|
||||||
return await self.storage.get_data(
|
return await self.storage.get_data(bot=self.bot, key=self.key)
|
||||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
async def update_data(
|
async def update_data(
|
||||||
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
|
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if data:
|
if data:
|
||||||
kwargs.update(data)
|
kwargs.update(data)
|
||||||
return await self.storage.update_data(
|
return await self.storage.update_data(bot=self.bot, key=self.key, data=kwargs)
|
||||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
async def clear(self) -> None:
|
async def clear(self) -> None:
|
||||||
await self.set_state(state=None)
|
await self.set_state(state=None)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.dispatcher.fsm.context import FSMContext
|
from aiogram.dispatcher.fsm.context import FSMContext
|
||||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage
|
from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StorageKey
|
||||||
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
|
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
|
||||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
from aiogram.types import TelegramObject
|
from aiogram.types import TelegramObject
|
||||||
|
|
@ -31,21 +31,28 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
if context:
|
if context:
|
||||||
data.update({"state": context, "raw_state": await context.get_state()})
|
data.update({"state": context, "raw_state": await context.get_state()})
|
||||||
if self.isolate_events:
|
if self.isolate_events:
|
||||||
async with self.storage.lock(
|
async with self.storage.lock(bot=bot, key=context.key):
|
||||||
bot=bot, chat_id=context.chat_id, user_id=context.user_id
|
|
||||||
):
|
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
|
|
||||||
def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
|
def resolve_event_context(
|
||||||
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
destiny: str = DEFAULT_DESTINY,
|
||||||
|
) -> Optional[FSMContext]:
|
||||||
user = data.get("event_from_user")
|
user = data.get("event_from_user")
|
||||||
chat = data.get("event_chat")
|
chat = data.get("event_chat")
|
||||||
chat_id = chat.id if chat else None
|
chat_id = chat.id if chat else None
|
||||||
user_id = user.id if user else None
|
user_id = user.id if user else None
|
||||||
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
|
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||||
|
|
||||||
def resolve_context(
|
def resolve_context(
|
||||||
self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
chat_id: Optional[int],
|
||||||
|
user_id: Optional[int],
|
||||||
|
destiny: str = DEFAULT_DESTINY,
|
||||||
) -> Optional[FSMContext]:
|
) -> Optional[FSMContext]:
|
||||||
if chat_id is None:
|
if chat_id is None:
|
||||||
chat_id = user_id
|
chat_id = user_id
|
||||||
|
|
@ -54,8 +61,23 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
chat_id, user_id = apply_strategy(
|
chat_id, user_id = apply_strategy(
|
||||||
chat_id=chat_id, user_id=user_id, strategy=self.strategy
|
chat_id=chat_id, user_id=user_id, strategy=self.strategy
|
||||||
)
|
)
|
||||||
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
|
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
|
def get_context(
|
||||||
return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
chat_id: int,
|
||||||
|
user_id: int,
|
||||||
|
destiny: str = DEFAULT_DESTINY,
|
||||||
|
) -> FSMContext:
|
||||||
|
return FSMContext(
|
||||||
|
bot=bot,
|
||||||
|
storage=self.storage,
|
||||||
|
key=StorageKey(
|
||||||
|
user_id=user_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
bot_id=bot.id,
|
||||||
|
destiny=destiny,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, Optional, Union
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
|
|
@ -7,45 +8,43 @@ from aiogram.dispatcher.fsm.state import State
|
||||||
|
|
||||||
StateType = Optional[Union[str, State]]
|
StateType = Optional[Union[str, State]]
|
||||||
|
|
||||||
|
DEFAULT_DESTINY = "default"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class StorageKey:
|
||||||
|
bot_id: int
|
||||||
|
chat_id: int
|
||||||
|
user_id: int
|
||||||
|
destiny: str = DEFAULT_DESTINY
|
||||||
|
|
||||||
|
|
||||||
class BaseStorage(ABC):
|
class BaseStorage(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lock(
|
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||||
self, bot: Bot, chat_id: int, user_id: int
|
|
||||||
) -> AsyncGenerator[None, None]: # pragma: no cover
|
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def set_state(
|
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||||
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
|
|
||||||
) -> None: # pragma: no cover
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_state(
|
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
|
||||||
self, bot: Bot, chat_id: int, user_id: int
|
|
||||||
) -> Optional[str]: # pragma: no cover
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def set_data(
|
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
|
||||||
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
|
|
||||||
) -> None: # pragma: no cover
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_data(
|
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
|
||||||
self, bot: Bot, chat_id: int, user_id: int
|
|
||||||
) -> Dict[str, Any]: # pragma: no cover
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def update_data(
|
async def update_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
|
current_data = await self.get_data(bot=bot, key=key)
|
||||||
) -> Dict[str, Any]:
|
|
||||||
current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
|
|
||||||
current_data.update(data)
|
current_data.update(data)
|
||||||
await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
|
await self.set_data(bot=bot, key=key, data=current_data)
|
||||||
return current_data.copy()
|
return current_data.copy()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.dispatcher.fsm.state import State
|
from aiogram.dispatcher.fsm.state import State
|
||||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
|
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -18,30 +18,26 @@ class MemoryStorageRecord:
|
||||||
|
|
||||||
class MemoryStorage(BaseStorage):
|
class MemoryStorage(BaseStorage):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.storage: DefaultDict[
|
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
|
||||||
Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
|
MemoryStorageRecord
|
||||||
] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
|
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||||
async with self.storage[bot][chat_id][user_id].lock:
|
async with self.storage[key].lock:
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
async def set_state(
|
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||||
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
|
self.storage[key].state = state.state if isinstance(state, State) else state
|
||||||
) -> None:
|
|
||||||
self.storage[bot][chat_id][user_id].state = (
|
|
||||||
state.state if isinstance(state, State) else state
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
|
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
|
||||||
return self.storage[bot][chat_id][user_id].state
|
return self.storage[key].state
|
||||||
|
|
||||||
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
|
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
|
||||||
self.storage[bot][chat_id][user_id].data = data.copy()
|
self.storage[key].data = data.copy()
|
||||||
|
|
||||||
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
|
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
|
||||||
return self.storage[bot][chat_id][user_id].data.copy()
|
return self.storage[key].data.copy()
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,67 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
|
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
|
||||||
|
|
||||||
from aioredis import ConnectionPool, Redis
|
from aioredis import ConnectionPool, Redis
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.dispatcher.fsm.state import State
|
from aiogram.dispatcher.fsm.state import State
|
||||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
|
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||||
|
|
||||||
PrefixFactoryType = Callable[[Bot], str]
|
|
||||||
STATE_KEY = "state"
|
|
||||||
STATE_DATA_KEY = "data"
|
|
||||||
STATE_LOCK_KEY = "lock"
|
|
||||||
|
|
||||||
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
|
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
|
||||||
|
|
||||||
|
|
||||||
|
class KeyBuilder(ABC):
|
||||||
|
"""
|
||||||
|
Base class for Redis key builder
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultKeyBuilder(KeyBuilder):
|
||||||
|
"""
|
||||||
|
Simple Redis key builder with default prefix.
|
||||||
|
|
||||||
|
Generates a colon-joined string with prefix, chat_id, user_id,
|
||||||
|
optional bot_id and optional destiny.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, prefix: str = "fsm", with_bot_id: bool = False, with_destiny: bool = False
|
||||||
|
) -> None:
|
||||||
|
self.prefix = prefix
|
||||||
|
self.with_bot_id = with_bot_id
|
||||||
|
self.with_destiny = with_destiny
|
||||||
|
|
||||||
|
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
||||||
|
parts = [self.prefix]
|
||||||
|
if self.with_bot_id:
|
||||||
|
parts.append(str(key.bot_id))
|
||||||
|
parts.extend([str(key.chat_id), str(key.user_id)])
|
||||||
|
if self.with_destiny:
|
||||||
|
parts.append(key.destiny)
|
||||||
|
parts.append(part)
|
||||||
|
return ":".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class RedisStorage(BaseStorage):
|
class RedisStorage(BaseStorage):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
prefix: str = "fsm",
|
key_builder: Optional[KeyBuilder] = None,
|
||||||
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
|
|
||||||
state_ttl: Optional[int] = None,
|
state_ttl: Optional[int] = None,
|
||||||
data_ttl: Optional[int] = None,
|
data_ttl: Optional[int] = None,
|
||||||
lock_kwargs: Optional[Dict[str, Any]] = None,
|
lock_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if key_builder is None:
|
||||||
|
key_builder = DefaultKeyBuilder()
|
||||||
if lock_kwargs is None:
|
if lock_kwargs is None:
|
||||||
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
|
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
self.prefix = prefix
|
self.key_builder = key_builder
|
||||||
self.prefix_bot = prefix_bot
|
|
||||||
self.state_ttl = state_ttl
|
self.state_ttl = state_ttl
|
||||||
self.data_ttl = data_ttl
|
self.data_ttl = data_ttl
|
||||||
self.lock_kwargs = lock_kwargs
|
self.lock_kwargs = lock_kwargs
|
||||||
|
|
@ -47,40 +79,28 @@ class RedisStorage(BaseStorage):
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self.redis.close() # type: ignore
|
await self.redis.close() # type: ignore
|
||||||
|
|
||||||
def generate_key(self, bot: Bot, *parts: Any) -> str:
|
|
||||||
prefix_parts = [self.prefix]
|
|
||||||
if self.prefix_bot:
|
|
||||||
if isinstance(self.prefix_bot, dict):
|
|
||||||
prefix_parts.append(self.prefix_bot[bot.id])
|
|
||||||
elif callable(self.prefix_bot):
|
|
||||||
prefix_parts.append(self.prefix_bot(bot))
|
|
||||||
else:
|
|
||||||
prefix_parts.append(str(bot.id))
|
|
||||||
prefix_parts.extend(parts)
|
|
||||||
return ":".join(map(str, prefix_parts))
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lock(
|
async def lock(
|
||||||
self, bot: Bot, chat_id: int, user_id: int, state_lock_key: str = STATE_LOCK_KEY
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
key: StorageKey,
|
||||||
) -> AsyncGenerator[None, None]:
|
) -> AsyncGenerator[None, None]:
|
||||||
key = self.generate_key(bot, chat_id, user_id, state_lock_key)
|
redis_key = self.key_builder.build(key, "lock")
|
||||||
async with self.redis.lock(name=key, **self.lock_kwargs):
|
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
async def set_state(
|
async def set_state(
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
chat_id: int,
|
key: StorageKey,
|
||||||
user_id: int,
|
|
||||||
state: StateType = None,
|
state: StateType = None,
|
||||||
state_key: str = STATE_KEY,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
key = self.generate_key(bot, chat_id, user_id, state_key)
|
redis_key = self.key_builder.build(key, "state")
|
||||||
if state is None:
|
if state is None:
|
||||||
await self.redis.delete(key)
|
await self.redis.delete(redis_key)
|
||||||
else:
|
else:
|
||||||
await self.redis.set(
|
await self.redis.set(
|
||||||
key,
|
redis_key,
|
||||||
state.state if isinstance(state, State) else state, # type: ignore[arg-type]
|
state.state if isinstance(state, State) else state, # type: ignore[arg-type]
|
||||||
ex=self.state_ttl, # type: ignore[arg-type]
|
ex=self.state_ttl, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
@ -88,12 +108,10 @@ class RedisStorage(BaseStorage):
|
||||||
async def get_state(
|
async def get_state(
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
chat_id: int,
|
key: StorageKey,
|
||||||
user_id: int,
|
|
||||||
state_key: str = STATE_KEY,
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
key = self.generate_key(bot, chat_id, user_id, state_key)
|
redis_key = self.key_builder.build(key, "state")
|
||||||
value = await self.redis.get(key)
|
value = await self.redis.get(redis_key)
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
return value.decode("utf-8")
|
return value.decode("utf-8")
|
||||||
return cast(Optional[str], value)
|
return cast(Optional[str], value)
|
||||||
|
|
@ -101,27 +119,26 @@ class RedisStorage(BaseStorage):
|
||||||
async def set_data(
|
async def set_data(
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
chat_id: int,
|
key: StorageKey,
|
||||||
user_id: int,
|
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
state_data_key: str = STATE_DATA_KEY,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
key = self.generate_key(bot, chat_id, user_id, state_data_key)
|
redis_key = self.key_builder.build(key, "data")
|
||||||
if not data:
|
if not data:
|
||||||
await self.redis.delete(key)
|
await self.redis.delete(redis_key)
|
||||||
return
|
return
|
||||||
json_data = bot.session.json_dumps(data)
|
await self.redis.set(
|
||||||
await self.redis.set(key, json_data, ex=self.data_ttl) # type: ignore[arg-type]
|
redis_key,
|
||||||
|
bot.session.json_dumps(data),
|
||||||
|
ex=self.data_ttl, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
async def get_data(
|
async def get_data(
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
chat_id: int,
|
key: StorageKey,
|
||||||
user_id: int,
|
|
||||||
state_data_key: str = STATE_DATA_KEY,
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
key = self.generate_key(bot, chat_id, user_id, state_data_key)
|
redis_key = self.key_builder.build(key, "data")
|
||||||
value = await self.redis.get(key)
|
value = await self.redis.get(redis_key)
|
||||||
if value is None:
|
if value is None:
|
||||||
return {}
|
return {}
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,33 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
|
from aiogram.dispatcher.fsm.storage.base import StorageKey
|
||||||
from tests.mocked_bot import MockedBot
|
from aiogram.dispatcher.fsm.storage.redis import DefaultKeyBuilder
|
||||||
|
|
||||||
pytestmark = pytest.mark.asyncio
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
PREFIX = "test"
|
||||||
|
BOT_ID = 42
|
||||||
|
CHAT_ID = -1
|
||||||
|
USER_ID = 2
|
||||||
|
FIELD = "data"
|
||||||
|
DESTINY = "testing"
|
||||||
|
|
||||||
@pytest.mark.redis
|
|
||||||
class TestRedisStorage:
|
class TestRedisDefaultKeyBuilder:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"prefix_bot,result",
|
"with_bot_id,with_destiny,result",
|
||||||
[
|
[
|
||||||
[False, "fsm:-1:2"],
|
[False, False, f"{PREFIX}:{CHAT_ID}:{USER_ID}:{FIELD}"],
|
||||||
[True, "fsm:42:-1:2"],
|
[True, False, f"{PREFIX}:{BOT_ID}:{CHAT_ID}:{USER_ID}:{FIELD}"],
|
||||||
[{42: "kaboom"}, "fsm:kaboom:-1:2"],
|
[True, True, f"{PREFIX}:{BOT_ID}:{CHAT_ID}:{USER_ID}:{DESTINY}:{FIELD}"],
|
||||||
[lambda bot: "kaboom", "fsm:kaboom:-1:2"],
|
[False, True, f"{PREFIX}:{CHAT_ID}:{USER_ID}:{DESTINY}:{FIELD}"],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result):
|
async def test_generate_key(self, with_bot_id: bool, with_destiny: bool, result: str):
|
||||||
storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot)
|
key_builder = DefaultKeyBuilder(
|
||||||
assert storage.generate_key(bot, -1, 2) == result
|
prefix=PREFIX,
|
||||||
|
with_bot_id=with_bot_id,
|
||||||
|
with_destiny=with_destiny,
|
||||||
|
)
|
||||||
|
key = StorageKey(chat_id=CHAT_ID, user_id=USER_ID, bot_id=BOT_ID, destiny=DESTINY)
|
||||||
|
assert key_builder.build(key, FIELD) == result
|
||||||
|
|
|
||||||
|
|
@ -1,46 +1,54 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage
|
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StorageKey
|
||||||
from tests.mocked_bot import MockedBot
|
from tests.mocked_bot import MockedBot
|
||||||
|
|
||||||
pytestmark = pytest.mark.asyncio
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="storage_key")
|
||||||
|
def create_storate_key(bot: MockedBot):
|
||||||
|
return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"storage",
|
"storage",
|
||||||
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
|
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
|
||||||
)
|
)
|
||||||
class TestStorages:
|
class TestStorages:
|
||||||
async def test_lock(self, bot: MockedBot, storage: BaseStorage):
|
async def test_lock(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
|
||||||
# TODO: ?!?
|
# TODO: ?!?
|
||||||
async with storage.lock(bot=bot, chat_id=-42, user_id=42):
|
async with storage.lock(bot=bot, key=storage_key):
|
||||||
assert True, "You are kidding me?"
|
assert True, "You are kidding me?"
|
||||||
|
|
||||||
async def test_set_state(self, bot: MockedBot, storage: BaseStorage):
|
async def test_set_state(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
|
||||||
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
|
assert await storage.get_state(bot=bot, key=storage_key) is None
|
||||||
|
|
||||||
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state")
|
await storage.set_state(bot=bot, key=storage_key, state="state")
|
||||||
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state"
|
assert await storage.get_state(bot=bot, key=storage_key) == "state"
|
||||||
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None)
|
await storage.set_state(bot=bot, key=storage_key, state=None)
|
||||||
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
|
assert await storage.get_state(bot=bot, key=storage_key) is None
|
||||||
|
|
||||||
async def test_set_data(self, bot: MockedBot, storage: BaseStorage):
|
async def test_set_data(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
|
||||||
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
|
assert await storage.get_data(bot=bot, key=storage_key) == {}
|
||||||
|
|
||||||
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"})
|
await storage.set_data(bot=bot, key=storage_key, data={"foo": "bar"})
|
||||||
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"}
|
assert await storage.get_data(bot=bot, key=storage_key) == {"foo": "bar"}
|
||||||
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={})
|
await storage.set_data(bot=bot, key=storage_key, data={})
|
||||||
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
|
assert await storage.get_data(bot=bot, key=storage_key) == {}
|
||||||
|
|
||||||
async def test_update_data(self, bot: MockedBot, storage: BaseStorage):
|
async def test_update_data(
|
||||||
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
|
self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey
|
||||||
assert await storage.update_data(
|
):
|
||||||
bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}
|
assert await storage.get_data(bot=bot, key=storage_key) == {}
|
||||||
) == {"foo": "bar"}
|
assert await storage.update_data(bot=bot, key=storage_key, data={"foo": "bar"}) == {
|
||||||
assert await storage.update_data(
|
"foo": "bar"
|
||||||
bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"}
|
}
|
||||||
) == {"foo": "bar", "baz": "spam"}
|
assert await storage.update_data(bot=bot, key=storage_key, data={"baz": "spam"}) == {
|
||||||
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {
|
"foo": "bar",
|
||||||
|
"baz": "spam",
|
||||||
|
}
|
||||||
|
assert await storage.get_data(bot=bot, key=storage_key) == {
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
"baz": "spam",
|
"baz": "spam",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aiogram.dispatcher.fsm.context import FSMContext
|
from aiogram.dispatcher.fsm.context import FSMContext
|
||||||
|
from aiogram.dispatcher.fsm.storage.base import StorageKey
|
||||||
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
|
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
|
||||||
from tests.mocked_bot import MockedBot
|
from tests.mocked_bot import MockedBot
|
||||||
|
|
||||||
|
|
@ -10,21 +11,28 @@ pytestmark = pytest.mark.asyncio
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def state(bot: MockedBot):
|
def state(bot: MockedBot):
|
||||||
storage = MemoryStorage()
|
storage = MemoryStorage()
|
||||||
ctx = storage.storage[bot][-42][42]
|
key = StorageKey(user_id=42, chat_id=-42, bot_id=bot.id)
|
||||||
|
ctx = storage.storage[key]
|
||||||
ctx.state = "test"
|
ctx.state = "test"
|
||||||
ctx.data = {"foo": "bar"}
|
ctx.data = {"foo": "bar"}
|
||||||
return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42)
|
return FSMContext(bot=bot, storage=storage, key=key)
|
||||||
|
|
||||||
|
|
||||||
class TestFSMContext:
|
class TestFSMContext:
|
||||||
async def test_address_mapping(self, bot: MockedBot):
|
async def test_address_mapping(self, bot: MockedBot):
|
||||||
storage = MemoryStorage()
|
storage = MemoryStorage()
|
||||||
ctx = storage.storage[bot][-42][42]
|
ctx = storage.storage[StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)]
|
||||||
ctx.state = "test"
|
ctx.state = "test"
|
||||||
ctx.data = {"foo": "bar"}
|
ctx.data = {"foo": "bar"}
|
||||||
state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42)
|
state = FSMContext(
|
||||||
state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42)
|
bot=bot, storage=storage, key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
|
||||||
state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69)
|
)
|
||||||
|
state2 = FSMContext(
|
||||||
|
bot=bot, storage=storage, key=StorageKey(chat_id=42, user_id=42, bot_id=bot.id)
|
||||||
|
)
|
||||||
|
state3 = FSMContext(
|
||||||
|
bot=bot, storage=storage, key=StorageKey(chat_id=69, user_id=69, bot_id=bot.id)
|
||||||
|
)
|
||||||
|
|
||||||
assert await state.get_state() == "test"
|
assert await state.get_state() == "test"
|
||||||
assert await state2.get_state() is None
|
assert await state2.get_state() is None
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import pytest
|
||||||
|
|
||||||
from aiogram import Dispatcher
|
from aiogram import Dispatcher
|
||||||
from aiogram.dispatcher.fsm.context import FSMContext
|
from aiogram.dispatcher.fsm.context import FSMContext
|
||||||
|
from aiogram.dispatcher.fsm.storage.base import StorageKey
|
||||||
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
|
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
|
||||||
from aiogram.types import Update, User
|
from aiogram.types import Update, User
|
||||||
from aiogram.utils.i18n import ConstI18nMiddleware, FSMI18nMiddleware, I18n, SimpleI18nMiddleware
|
from aiogram.utils.i18n import ConstI18nMiddleware, FSMI18nMiddleware, I18n, SimpleI18nMiddleware
|
||||||
|
|
@ -131,7 +132,9 @@ class TestFSMI18nMiddleware:
|
||||||
async def test_middleware(self, i18n: I18n, bot: MockedBot):
|
async def test_middleware(self, i18n: I18n, bot: MockedBot):
|
||||||
middleware = FSMI18nMiddleware(i18n=i18n)
|
middleware = FSMI18nMiddleware(i18n=i18n)
|
||||||
storage = MemoryStorage()
|
storage = MemoryStorage()
|
||||||
state = FSMContext(bot=bot, storage=storage, user_id=42, chat_id=42)
|
state = FSMContext(
|
||||||
|
bot=bot, storage=storage, key=StorageKey(user_id=42, chat_id=42, bot_id=bot.id)
|
||||||
|
)
|
||||||
data = {
|
data = {
|
||||||
"event_from_user": User(id=42, is_bot=False, language_code="it", first_name="Test"),
|
"event_from_user": User(id=42, is_bot=False, language_code="it", first_name="Test"),
|
||||||
"state": state,
|
"state": state,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue