storage cleanup (#1144)

* storage cleanup

* storage cleanup
This commit is contained in:
RootShinobi 2023-04-08 18:01:11 +03:00 committed by GitHub
parent d8a977f357
commit dbaf6fabcb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 54 additions and 75 deletions

1
CHANGES/1144.misc.rst Normal file
View file

@ -0,0 +1 @@
Removed bot parameters from storages

View file

@ -1,33 +1,31 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from aiogram import Bot
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
class FSMContext: class FSMContext:
def __init__(self, bot: Bot, storage: BaseStorage, key: StorageKey) -> None: def __init__(self, storage: BaseStorage, key: StorageKey) -> None:
self.bot = bot
self.storage = storage self.storage = storage
self.key = key self.key = key
async def set_state(self, state: StateType = None) -> None: async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(bot=self.bot, key=self.key, state=state) await self.storage.set_state(key=self.key, state=state)
async def get_state(self) -> Optional[str]: async def get_state(self) -> Optional[str]:
return await self.storage.get_state(bot=self.bot, key=self.key) return await self.storage.get_state(key=self.key)
async def set_data(self, data: Dict[str, Any]) -> None: async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data(bot=self.bot, key=self.key, data=data) await self.storage.set_data(key=self.key, data=data)
async def get_data(self) -> Dict[str, Any]: async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data(bot=self.bot, key=self.key) return await self.storage.get_data(key=self.key)
async def update_data( async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if data: if data:
kwargs.update(data) kwargs.update(data)
return await self.storage.update_data(bot=self.bot, key=self.key, data=kwargs) return await self.storage.update_data(key=self.key, data=kwargs)
async def clear(self) -> None: async def clear(self) -> None:
await self.set_state(state=None) await self.set_state(state=None)

View file

@ -35,7 +35,7 @@ class FSMContextMiddleware(BaseMiddleware):
data["fsm_storage"] = self.storage data["fsm_storage"] = self.storage
if context: if context:
data.update({"state": context, "raw_state": await context.get_state()}) data.update({"state": context, "raw_state": await context.get_state()})
async with self.events_isolation.lock(bot=bot, key=context.key): async with self.events_isolation.lock(key=context.key):
return await handler(event, data) return await handler(event, data)
return await handler(event, data) return await handler(event, data)
@ -76,7 +76,6 @@ class FSMContextMiddleware(BaseMiddleware):
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> FSMContext: ) -> FSMContext:
return FSMContext( return FSMContext(
bot=bot,
storage=self.storage, storage=self.storage,
key=StorageKey( key=StorageKey(
user_id=user_id, user_id=user_id,

View file

@ -3,7 +3,6 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, Union from typing import Any, AsyncGenerator, Dict, Optional, Union
from aiogram import Bot
from aiogram.fsm.state import State from aiogram.fsm.state import State
StateType = Optional[Union[str, State]] StateType = Optional[Union[str, State]]
@ -25,61 +24,56 @@ class BaseStorage(ABC):
""" """
@abstractmethod @abstractmethod
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None: async def set_state(self, key: StorageKey, state: StateType = None) -> None:
""" """
Set state for specified key Set state for specified key
:param bot: instance of the current bot
:param key: storage key :param key: storage key
:param state: new state :param state: new state
""" """
pass pass
@abstractmethod @abstractmethod
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]: async def get_state(self, key: StorageKey) -> Optional[str]:
""" """
Get key state Get key state
:param bot: instance of the current bot
:param key: storage key :param key: storage key
:return: current state :return: current state
""" """
pass pass
@abstractmethod @abstractmethod
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None: async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
""" """
Write data (replace) Write data (replace)
:param bot: instance of the current bot
:param key: storage key :param key: storage key
:param data: new data :param data: new data
""" """
pass pass
@abstractmethod @abstractmethod
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]: async def get_data(self, key: StorageKey) -> Dict[str, Any]:
""" """
Get current data for key Get current data for key
:param bot: instance of the current bot
:param key: storage key :param key: storage key
:return: current data :return: current data
""" """
pass pass
async def update_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]: async def update_data(self, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
""" """
Update date in the storage for key (like dict.update) Update date in the storage for key (like dict.update)
:param bot: instance of the current bot
:param key: storage key :param key: storage key
:param data: partial data :param data: partial data
:return: new data :return: new data
""" """
current_data = await self.get_data(bot=bot, key=key) current_data = await self.get_data(key=key)
current_data.update(data) current_data.update(data)
await self.set_data(bot=bot, key=key, data=current_data) await self.set_data(key=key, data=current_data)
return current_data.copy() return current_data.copy()
@abstractmethod @abstractmethod
@ -93,12 +87,11 @@ class BaseStorage(ABC):
class BaseEventIsolation(ABC): class BaseEventIsolation(ABC):
@abstractmethod @abstractmethod
@asynccontextmanager @asynccontextmanager
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
""" """
Isolate events with lock. Isolate events with lock.
Will be used as context manager Will be used as context manager
:param bot: instance of the current bot
:param key: storage key :param key: storage key
:return: An async generator :return: An async generator
""" """

View file

@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, DefaultDict, Dict, Hashable, Optional from typing import Any, AsyncGenerator, DefaultDict, Dict, Hashable, Optional
from aiogram import Bot
from aiogram.fsm.state import State from aiogram.fsm.state import State
from aiogram.fsm.storage.base import ( from aiogram.fsm.storage.base import (
BaseEventIsolation, BaseEventIsolation,
@ -38,22 +37,22 @@ class MemoryStorage(BaseStorage):
async def close(self) -> None: async def close(self) -> None:
pass pass
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None: async def set_state(self, key: StorageKey, state: StateType = None) -> None:
self.storage[key].state = state.state if isinstance(state, State) else state self.storage[key].state = state.state if isinstance(state, State) else state
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]: async def get_state(self, key: StorageKey) -> Optional[str]:
return self.storage[key].state return self.storage[key].state
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None: async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
self.storage[key].data = data.copy() self.storage[key].data = data.copy()
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]: async def get_data(self, key: StorageKey) -> Dict[str, Any]:
return self.storage[key].data.copy() return self.storage[key].data.copy()
class DisabledEventIsolation(BaseEventIsolation): class DisabledEventIsolation(BaseEventIsolation):
@asynccontextmanager @asynccontextmanager
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
yield yield
async def close(self) -> None: async def close(self) -> None:
@ -66,7 +65,7 @@ class SimpleEventIsolation(BaseEventIsolation):
self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock) self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock)
@asynccontextmanager @asynccontextmanager
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
lock = self._locks[key] lock = self._locks[key]
async with lock: async with lock:
yield yield

View file

@ -1,13 +1,13 @@
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast from typing import Any, AsyncGenerator, Callable, Dict, Literal, Optional, cast
from redis.asyncio.client import Redis from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool from redis.asyncio.connection import ConnectionPool
from redis.asyncio.lock import Lock from redis.asyncio.lock import Lock
from redis.typing import ExpiryT from redis.typing import ExpiryT
from aiogram import Bot
from aiogram.fsm.state import State from aiogram.fsm.state import State
from aiogram.fsm.storage.base import ( from aiogram.fsm.storage.base import (
DEFAULT_DESTINY, DEFAULT_DESTINY,
@ -18,6 +18,8 @@ from aiogram.fsm.storage.base import (
) )
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60} DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
_JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]
class KeyBuilder(ABC): class KeyBuilder(ABC):
@ -93,13 +95,14 @@ class RedisStorage(BaseStorage):
key_builder: Optional[KeyBuilder] = None, key_builder: Optional[KeyBuilder] = None,
state_ttl: Optional[ExpiryT] = None, state_ttl: Optional[ExpiryT] = None,
data_ttl: Optional[ExpiryT] = None, data_ttl: Optional[ExpiryT] = None,
json_loads: _JsonLoads = json.loads,
json_dumps: _JsonDumps = json.dumps,
) -> None: ) -> None:
""" """
:param redis: Instance of Redis connection :param redis: Instance of Redis connection
:param key_builder: builder that helps to convert contextual key to string :param key_builder: builder that helps to convert contextual key to string
:param state_ttl: TTL for state records :param state_ttl: TTL for state records
:param data_ttl: TTL for data records :param data_ttl: TTL for data records
:param lock_kwargs: Custom arguments for Redis lock
""" """
if key_builder is None: if key_builder is None:
key_builder = DefaultKeyBuilder() key_builder = DefaultKeyBuilder()
@ -107,6 +110,8 @@ class RedisStorage(BaseStorage):
self.key_builder = key_builder self.key_builder = key_builder
self.state_ttl = state_ttl self.state_ttl = state_ttl
self.data_ttl = data_ttl self.data_ttl = data_ttl
self.json_loads = json_loads
self.json_dumps = json_dumps
@classmethod @classmethod
def from_url( def from_url(
@ -134,7 +139,6 @@ class RedisStorage(BaseStorage):
async def set_state( async def set_state(
self, self,
bot: Bot,
key: StorageKey, key: StorageKey,
state: StateType = None, state: StateType = None,
) -> None: ) -> None:
@ -150,7 +154,6 @@ class RedisStorage(BaseStorage):
async def get_state( async def get_state(
self, self,
bot: Bot,
key: StorageKey, key: StorageKey,
) -> Optional[str]: ) -> Optional[str]:
redis_key = self.key_builder.build(key, "state") redis_key = self.key_builder.build(key, "state")
@ -161,7 +164,6 @@ class RedisStorage(BaseStorage):
async def set_data( async def set_data(
self, self,
bot: Bot,
key: StorageKey, key: StorageKey,
data: Dict[str, Any], data: Dict[str, Any],
) -> None: ) -> None:
@ -171,13 +173,12 @@ class RedisStorage(BaseStorage):
return return
await self.redis.set( await self.redis.set(
redis_key, redis_key,
bot.session.json_dumps(data), self.json_dumps(data),
ex=self.data_ttl, ex=self.data_ttl,
) )
async def get_data( async def get_data(
self, self,
bot: Bot,
key: StorageKey, key: StorageKey,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
redis_key = self.key_builder.build(key, "data") redis_key = self.key_builder.build(key, "data")
@ -186,7 +187,7 @@ class RedisStorage(BaseStorage):
return {} return {}
if isinstance(value, bytes): if isinstance(value, bytes):
value = value.decode("utf-8") value = value.decode("utf-8")
return cast(Dict[str, Any], bot.session.json_loads(value)) return cast(Dict[str, Any], self.json_loads(value))
class RedisEventIsolation(BaseEventIsolation): class RedisEventIsolation(BaseEventIsolation):
@ -220,7 +221,6 @@ class RedisEventIsolation(BaseEventIsolation):
@asynccontextmanager @asynccontextmanager
async def lock( async def lock(
self, self,
bot: Bot,
key: StorageKey, key: StorageKey,
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
redis_key = self.key_builder.build(key, "lock") redis_key = self.key_builder.build(key, "lock")

View file

@ -5,7 +5,7 @@ from tests.mocked_bot import MockedBot
@pytest.fixture(name="storage_key") @pytest.fixture(name="storage_key")
def create_storate_key(bot: MockedBot): def create_storage_key(bot: MockedBot):
return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
@ -20,9 +20,8 @@ def create_storate_key(bot: MockedBot):
class TestIsolations: class TestIsolations:
async def test_lock( async def test_lock(
self, self,
bot: MockedBot,
isolation: BaseEventIsolation, isolation: BaseEventIsolation,
storage_key: StorageKey, storage_key: StorageKey,
): ):
async with isolation.lock(bot=bot, key=storage_key): async with isolation.lock(key=storage_key):
assert True, "You are kidding me?" assert True, "You are kidding me?"

View file

@ -5,7 +5,7 @@ from tests.mocked_bot import MockedBot
@pytest.fixture(name="storage_key") @pytest.fixture(name="storage_key")
def create_storate_key(bot: MockedBot): def create_storage_key(bot: MockedBot):
return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
@ -15,33 +15,31 @@ def create_storate_key(bot: MockedBot):
) )
class TestStorages: class TestStorages:
async def test_set_state(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey): async def test_set_state(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
assert await storage.get_state(bot=bot, key=storage_key) is None assert await storage.get_state(key=storage_key) is None
await storage.set_state(bot=bot, key=storage_key, state="state") await storage.set_state(key=storage_key, state="state")
assert await storage.get_state(bot=bot, key=storage_key) == "state" assert await storage.get_state(key=storage_key) == "state"
await storage.set_state(bot=bot, key=storage_key, state=None) await storage.set_state(key=storage_key, state=None)
assert await storage.get_state(bot=bot, key=storage_key) is None assert await storage.get_state(key=storage_key) is None
async def test_set_data(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey): async def test_set_data(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):
assert await storage.get_data(bot=bot, key=storage_key) == {} assert await storage.get_data(key=storage_key) == {}
await storage.set_data(bot=bot, key=storage_key, data={"foo": "bar"}) await storage.set_data(key=storage_key, data={"foo": "bar"})
assert await storage.get_data(bot=bot, key=storage_key) == {"foo": "bar"} assert await storage.get_data(key=storage_key) == {"foo": "bar"}
await storage.set_data(bot=bot, key=storage_key, data={}) await storage.set_data(key=storage_key, data={})
assert await storage.get_data(bot=bot, key=storage_key) == {} assert await storage.get_data(key=storage_key) == {}
async def test_update_data( async def test_update_data(
self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey
): ):
assert await storage.get_data(bot=bot, key=storage_key) == {} assert await storage.get_data(key=storage_key) == {}
assert await storage.update_data(bot=bot, key=storage_key, data={"foo": "bar"}) == { assert await storage.update_data(key=storage_key, data={"foo": "bar"}) == {"foo": "bar"}
"foo": "bar" assert await storage.update_data(key=storage_key, data={"baz": "spam"}) == {
}
assert await storage.update_data(bot=bot, key=storage_key, data={"baz": "spam"}) == {
"foo": "bar", "foo": "bar",
"baz": "spam", "baz": "spam",
} }
assert await storage.get_data(bot=bot, key=storage_key) == { assert await storage.get_data(key=storage_key) == {
"foo": "bar", "foo": "bar",
"baz": "spam", "baz": "spam",
} }

View file

@ -13,7 +13,7 @@ def state(bot: MockedBot):
ctx = storage.storage[key] ctx = storage.storage[key]
ctx.state = "test" ctx.state = "test"
ctx.data = {"foo": "bar"} ctx.data = {"foo": "bar"}
return FSMContext(bot=bot, storage=storage, key=key) return FSMContext(storage=storage, key=key)
class TestFSMContext: class TestFSMContext:
@ -22,15 +22,9 @@ class TestFSMContext:
ctx = storage.storage[StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)] ctx = storage.storage[StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)]
ctx.state = "test" ctx.state = "test"
ctx.data = {"foo": "bar"} ctx.data = {"foo": "bar"}
state = FSMContext( state = FSMContext(storage=storage, key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id))
bot=bot, storage=storage, key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) state2 = FSMContext(storage=storage, key=StorageKey(chat_id=42, user_id=42, bot_id=bot.id))
) state3 = FSMContext(storage=storage, key=StorageKey(chat_id=69, user_id=69, bot_id=bot.id))
state2 = FSMContext(
bot=bot, storage=storage, key=StorageKey(chat_id=42, user_id=42, bot_id=bot.id)
)
state3 = FSMContext(
bot=bot, storage=storage, key=StorageKey(chat_id=69, user_id=69, bot_id=bot.id)
)
assert await state.get_state() == "test" assert await state.get_state() == "test"
assert await state2.get_state() is None assert await state2.get_state() is None

View file

@ -178,9 +178,7 @@ class TestFSMI18nMiddleware:
async def test_middleware(self, i18n: I18n, bot: MockedBot, extra): async def test_middleware(self, i18n: I18n, bot: MockedBot, extra):
middleware = FSMI18nMiddleware(i18n=i18n) middleware = FSMI18nMiddleware(i18n=i18n)
storage = MemoryStorage() storage = MemoryStorage()
state = FSMContext( state = FSMContext(storage=storage, key=StorageKey(user_id=42, chat_id=42, bot_id=bot.id))
bot=bot, storage=storage, key=StorageKey(user_id=42, chat_id=42, bot_id=bot.id)
)
data = { data = {
"event_from_user": User(id=42, is_bot=False, language_code="it", first_name="Test"), "event_from_user": User(id=42, is_bot=False, language_code="it", first_name="Test"),
"state": state, "state": state,