mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-06 07:50:32 +00:00
Migrate from aioredis to redis.asyncio (#1074)
* chore: migrate from aioredis to redis.asyncio * chore: add tests for RedisStorage2
This commit is contained in:
parent
87c0458d95
commit
ae534298e5
3 changed files with 195 additions and 288 deletions
|
|
@ -1,18 +1,18 @@
|
|||
"""
|
||||
This module has redis storage for finite-state machine based on `aioredis <https://github.com/aio-libs/aioredis>`_ driver
|
||||
This module has redis storage for finite-state machine based on `redis <https://pypi.org/project/redis/>`_ driver.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import aioredis
|
||||
|
||||
from ...dispatcher.storage import BaseStorage
|
||||
from ...utils import json
|
||||
from ...utils.deprecated import deprecated
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import aioredis
|
||||
|
||||
STATE_KEY = 'state'
|
||||
STATE_DATA_KEY = 'data'
|
||||
STATE_BUCKET_KEY = 'bucket'
|
||||
|
|
@ -67,6 +67,8 @@ class RedisStorage(BaseStorage):
|
|||
Get Redis connection
|
||||
"""
|
||||
# Use thread-safe asyncio Lock because this method without that is not safe
|
||||
import aioredis
|
||||
|
||||
async with self._connection_lock:
|
||||
if self._redis is None or self._redis.closed:
|
||||
self._redis = await aioredis.create_connection((self._host, self._port),
|
||||
|
|
@ -207,138 +209,6 @@ class RedisStorage(BaseStorage):
|
|||
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)
|
||||
|
||||
|
||||
class AioRedisAdapterBase(ABC):
|
||||
"""Base aioredis adapter class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
db: typing.Optional[int] = None,
|
||||
password: typing.Optional[str] = None,
|
||||
ssl: typing.Optional[bool] = None,
|
||||
pool_size: int = 10,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
prefix: str = "fsm",
|
||||
state_ttl: typing.Optional[int] = None,
|
||||
data_ttl: typing.Optional[int] = None,
|
||||
bucket_ttl: typing.Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._db = db
|
||||
self._password = password
|
||||
self._ssl = ssl
|
||||
self._pool_size = pool_size
|
||||
self._kwargs = kwargs
|
||||
self._prefix = (prefix,)
|
||||
|
||||
self._state_ttl = state_ttl
|
||||
self._data_ttl = data_ttl
|
||||
self._bucket_ttl = bucket_ttl
|
||||
|
||||
self._redis: typing.Optional["aioredis.Redis"] = None
|
||||
self._connection_lock = asyncio.Lock()
|
||||
|
||||
@abstractmethod
|
||||
async def get_redis(self) -> aioredis.Redis:
|
||||
"""Get Redis connection."""
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
"""Grace shutdown."""
|
||||
pass
|
||||
|
||||
async def wait_closed(self):
|
||||
"""Wait for grace shutdown finishes."""
|
||||
pass
|
||||
|
||||
async def set(self, name, value, ex=None, **kwargs):
|
||||
"""Set the value at key ``name`` to ``value``."""
|
||||
if ex == 0:
|
||||
ex = None
|
||||
return await self._redis.set(name, value, ex=ex, **kwargs)
|
||||
|
||||
async def get(self, name, **kwargs):
|
||||
"""Return the value at key ``name`` or None."""
|
||||
return await self._redis.get(name, **kwargs)
|
||||
|
||||
async def delete(self, *names):
|
||||
"""Delete one or more keys specified by ``names``"""
|
||||
return await self._redis.delete(*names)
|
||||
|
||||
async def keys(self, pattern, **kwargs):
|
||||
"""Returns a list of keys matching ``pattern``."""
|
||||
return await self._redis.keys(pattern, **kwargs)
|
||||
|
||||
async def flushdb(self):
|
||||
"""Delete all keys in the current database."""
|
||||
return await self._redis.flushdb()
|
||||
|
||||
|
||||
class AioRedisAdapterV1(AioRedisAdapterBase):
|
||||
"""Redis adapter for aioredis v1."""
|
||||
|
||||
async def get_redis(self) -> aioredis.Redis:
|
||||
"""Get Redis connection."""
|
||||
async with self._connection_lock: # to prevent race
|
||||
if self._redis is None or self._redis.closed:
|
||||
self._redis = await aioredis.create_redis_pool(
|
||||
(self._host, self._port),
|
||||
db=self._db,
|
||||
password=self._password,
|
||||
ssl=self._ssl,
|
||||
minsize=1,
|
||||
maxsize=self._pool_size,
|
||||
**self._kwargs,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
async def close(self):
|
||||
async with self._connection_lock:
|
||||
if self._redis and not self._redis.closed:
|
||||
self._redis.close()
|
||||
|
||||
async def wait_closed(self):
|
||||
async with self._connection_lock:
|
||||
if self._redis:
|
||||
return await self._redis.wait_closed()
|
||||
return True
|
||||
|
||||
async def get(self, name, **kwargs):
|
||||
return await self._redis.get(name, encoding="utf8", **kwargs)
|
||||
|
||||
async def set(self, name, value, ex=None, **kwargs):
|
||||
if ex == 0:
|
||||
ex = None
|
||||
return await self._redis.set(name, value, expire=ex, **kwargs)
|
||||
|
||||
async def keys(self, pattern, **kwargs):
|
||||
"""Returns a list of keys matching ``pattern``."""
|
||||
return await self._redis.keys(pattern, encoding="utf8", **kwargs)
|
||||
|
||||
|
||||
class AioRedisAdapterV2(AioRedisAdapterBase):
|
||||
"""Redis adapter for aioredis v2."""
|
||||
|
||||
async def get_redis(self) -> aioredis.Redis:
|
||||
"""Get Redis connection."""
|
||||
async with self._connection_lock: # to prevent race
|
||||
if self._redis is None:
|
||||
self._redis = aioredis.Redis(
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
db=self._db,
|
||||
password=self._password,
|
||||
ssl=self._ssl,
|
||||
max_connections=self._pool_size,
|
||||
decode_responses=True,
|
||||
**self._kwargs,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
|
||||
class RedisStorage2(BaseStorage):
|
||||
"""
|
||||
Busted Redis-base storage for FSM.
|
||||
|
|
@ -356,7 +226,6 @@ class RedisStorage2(BaseStorage):
|
|||
.. code-block:: python3
|
||||
|
||||
await dp.storage.close()
|
||||
await dp.storage.wait_closed()
|
||||
|
||||
"""
|
||||
|
||||
|
|
@ -375,75 +244,49 @@ class RedisStorage2(BaseStorage):
|
|||
bucket_ttl: typing.Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._db = db
|
||||
self._password = password
|
||||
self._ssl = ssl
|
||||
self._pool_size = pool_size
|
||||
self._kwargs = kwargs
|
||||
self._prefix = (prefix,)
|
||||
from redis.asyncio import Redis
|
||||
|
||||
self._redis: typing.Optional[Redis] = Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
ssl=ssl,
|
||||
max_connections=pool_size,
|
||||
decode_responses=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._prefix = (prefix,)
|
||||
self._state_ttl = state_ttl
|
||||
self._data_ttl = data_ttl
|
||||
self._bucket_ttl = bucket_ttl
|
||||
|
||||
self._redis: typing.Optional[AioRedisAdapterBase] = None
|
||||
self._connection_lock = asyncio.Lock()
|
||||
|
||||
@deprecated("This method will be removed in aiogram v3.0. "
|
||||
"You should use your own instance of Redis.", stacklevel=3)
|
||||
async def redis(self) -> aioredis.Redis:
|
||||
adapter = await self._get_adapter()
|
||||
return await adapter.get_redis()
|
||||
|
||||
async def _get_adapter(self) -> AioRedisAdapterBase:
|
||||
"""Get adapter based on aioredis version."""
|
||||
if self._redis is None:
|
||||
redis_version = int(aioredis.__version__.split(".")[0])
|
||||
connection_data = dict(
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
db=self._db,
|
||||
password=self._password,
|
||||
ssl=self._ssl,
|
||||
pool_size=self._pool_size,
|
||||
**self._kwargs,
|
||||
)
|
||||
if redis_version == 1:
|
||||
self._redis = AioRedisAdapterV1(**connection_data)
|
||||
elif redis_version == 2:
|
||||
self._redis = AioRedisAdapterV2(**connection_data)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported aioredis version: {redis_version}")
|
||||
await self._redis.get_redis()
|
||||
async def redis(self) -> "aioredis.Redis":
|
||||
return self._redis
|
||||
|
||||
def generate_key(self, *parts):
|
||||
return ':'.join(self._prefix + tuple(map(str, parts)))
|
||||
|
||||
async def close(self):
|
||||
if self._redis:
|
||||
return await self._redis.close()
|
||||
await self._redis.close()
|
||||
|
||||
async def wait_closed(self):
|
||||
if self._redis:
|
||||
await self._redis.wait_closed()
|
||||
self._redis = None
|
||||
pass
|
||||
|
||||
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_KEY)
|
||||
redis = await self._get_adapter()
|
||||
return await redis.get(key) or self.resolve_state(default)
|
||||
return await self._redis.get(key) or self.resolve_state(default)
|
||||
|
||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||
redis = await self._get_adapter()
|
||||
raw_result = await redis.get(key)
|
||||
raw_result = await self._redis.get(key)
|
||||
if raw_result:
|
||||
return json.loads(raw_result)
|
||||
return default or {}
|
||||
|
|
@ -452,21 +295,19 @@ class RedisStorage2(BaseStorage):
|
|||
state: typing.Optional[typing.AnyStr] = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_KEY)
|
||||
redis = await self._get_adapter()
|
||||
if state is None:
|
||||
await redis.delete(key)
|
||||
await self._redis.delete(key)
|
||||
else:
|
||||
await redis.set(key, self.resolve_state(state), ex=self._state_ttl)
|
||||
await self._redis.set(key, self.resolve_state(state), ex=self._state_ttl)
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||
redis = await self._get_adapter()
|
||||
if data:
|
||||
await redis.set(key, json.dumps(data), ex=self._data_ttl)
|
||||
await self._redis.set(key, json.dumps(data), ex=self._data_ttl)
|
||||
else:
|
||||
await redis.delete(key)
|
||||
await self._redis.delete(key)
|
||||
|
||||
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None, **kwargs):
|
||||
|
|
@ -483,8 +324,7 @@ class RedisStorage2(BaseStorage):
|
|||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||
redis = await self._get_adapter()
|
||||
raw_result = await redis.get(key)
|
||||
raw_result = await self._redis.get(key)
|
||||
if raw_result:
|
||||
return json.loads(raw_result)
|
||||
return default or {}
|
||||
|
|
@ -493,11 +333,10 @@ class RedisStorage2(BaseStorage):
|
|||
bucket: typing.Dict = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||
redis = await self._get_adapter()
|
||||
if bucket:
|
||||
await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
|
||||
await self._redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
|
||||
else:
|
||||
await redis.delete(key)
|
||||
await self._redis.delete(key)
|
||||
|
||||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
|
|
@ -515,13 +354,11 @@ class RedisStorage2(BaseStorage):
|
|||
:param full: clean DB or clean only states
|
||||
:return:
|
||||
"""
|
||||
redis = await self._get_adapter()
|
||||
|
||||
if full:
|
||||
await redis.flushdb()
|
||||
await self._redis.flushdb()
|
||||
else:
|
||||
keys = await redis.keys(self.generate_key('*'))
|
||||
await redis.delete(*keys)
|
||||
keys = await self._redis.keys(self.generate_key('*'))
|
||||
await self._redis.delete(*keys)
|
||||
|
||||
async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
"""
|
||||
|
|
@ -529,10 +366,9 @@ class RedisStorage2(BaseStorage):
|
|||
|
||||
:return: list of tuples where first element is chat id and second is user id
|
||||
"""
|
||||
redis = await self._get_adapter()
|
||||
result = []
|
||||
|
||||
keys = await redis.keys(self.generate_key('*', '*', STATE_KEY))
|
||||
keys = await self._redis.keys(self.generate_key('*', '*', STATE_KEY))
|
||||
for item in keys:
|
||||
*_, chat, user, _ = item.split(':')
|
||||
result.append((chat, user))
|
||||
|
|
|
|||
|
|
@ -1,89 +0,0 @@
|
|||
import aioredis
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from pytest_lazyfixture import lazy_fixture
|
||||
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
@pytest.mark.redis
|
||||
async def redis_store(redis_options):
|
||||
if int(aioredis.__version__.split(".")[0]) == 2:
|
||||
pytest.skip('aioredis v2 is not supported.')
|
||||
return
|
||||
s = RedisStorage(**redis_options)
|
||||
try:
|
||||
yield s
|
||||
finally:
|
||||
conn = await s.redis()
|
||||
await conn.execute('FLUSHDB')
|
||||
await s.close()
|
||||
await s.wait_closed()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
@pytest.mark.redis
|
||||
async def redis_store2(redis_options):
|
||||
s = RedisStorage2(**redis_options)
|
||||
try:
|
||||
yield s
|
||||
finally:
|
||||
conn = await s.redis()
|
||||
await conn.flushdb()
|
||||
await s.close()
|
||||
await s.wait_closed()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def memory_store():
|
||||
yield MemoryStorage()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"store", [
|
||||
lazy_fixture('redis_store'),
|
||||
lazy_fixture('redis_store2'),
|
||||
lazy_fixture('memory_store'),
|
||||
]
|
||||
)
|
||||
class TestStorage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_get(self, store):
|
||||
assert await store.get_data(chat='1234') == {}
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
assert await store.get_data(chat='1234') == {'foo': 'bar'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset(self, store):
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
await store.reset_data(chat='1234')
|
||||
assert await store.get_data(chat='1234') == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_empty(self, store):
|
||||
await store.reset_data(chat='1234')
|
||||
assert await store.get_data(chat='1234') == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"store", [
|
||||
lazy_fixture('redis_store'),
|
||||
lazy_fixture('redis_store2'),
|
||||
]
|
||||
)
|
||||
class TestRedisStorage2:
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_and_open_connection(self, store):
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
assert await store.get_data(chat='1234') == {'foo': 'bar'}
|
||||
pool_id = id(store._redis)
|
||||
await store.close()
|
||||
await store.wait_closed()
|
||||
|
||||
# new pool will be open at this point
|
||||
assert await store.get_data(chat='1234') == {
|
||||
'foo': 'bar',
|
||||
}
|
||||
assert id(store._redis) != pool_id
|
||||
160
tests/test_contrib/test_fsm_storage/test_storage.py
Normal file
160
tests/test_contrib/test_fsm_storage/test_storage.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
import aioredis
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from pytest_lazyfixture import lazy_fixture
|
||||
from redis.asyncio.connection import Connection, ConnectionPool
|
||||
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2
|
||||
from aiogram.types import Chat, User
|
||||
from tests.types.dataset import CHAT, USER
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
@pytest.mark.redis
|
||||
async def redis_store(redis_options):
|
||||
if int(aioredis.__version__.split(".")[0]) == 2:
|
||||
pytest.skip('aioredis v2 is not supported.')
|
||||
return
|
||||
s = RedisStorage(**redis_options)
|
||||
try:
|
||||
yield s
|
||||
finally:
|
||||
conn = await s.redis()
|
||||
await conn.execute('FLUSHDB')
|
||||
await s.close()
|
||||
await s.wait_closed()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
@pytest.mark.redis
|
||||
async def redis_store2(redis_options):
|
||||
s = RedisStorage2(**redis_options)
|
||||
try:
|
||||
yield s
|
||||
finally:
|
||||
conn = await s.redis()
|
||||
await conn.flushdb()
|
||||
await s.close()
|
||||
await s.wait_closed()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def memory_store():
|
||||
yield MemoryStorage()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"store", [
|
||||
lazy_fixture('redis_store'),
|
||||
lazy_fixture('redis_store2'),
|
||||
lazy_fixture('memory_store'),
|
||||
]
|
||||
)
|
||||
class TestStorage:
|
||||
async def test_set_get(self, store):
|
||||
assert await store.get_data(chat='1234') == {}
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
assert await store.get_data(chat='1234') == {'foo': 'bar'}
|
||||
|
||||
async def test_reset(self, store):
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
await store.reset_data(chat='1234')
|
||||
assert await store.get_data(chat='1234') == {}
|
||||
|
||||
async def test_reset_empty(self, store):
|
||||
await store.reset_data(chat='1234')
|
||||
assert await store.get_data(chat='1234') == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"store", [
|
||||
lazy_fixture('redis_store'),
|
||||
lazy_fixture('redis_store2'),
|
||||
]
|
||||
)
|
||||
class TestRedisStorage2:
|
||||
async def test_close_and_open_connection(self, store: RedisStorage2):
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
assert await store.get_data(chat='1234') == {'foo': 'bar'}
|
||||
await store.close()
|
||||
await store.wait_closed()
|
||||
|
||||
pool: ConnectionPool = store._redis.connection_pool
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
assert not pool._in_use_connections
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
if pool._available_connections:
|
||||
# noinspection PyUnresolvedReferences
|
||||
connection: Connection = pool._available_connections[0]
|
||||
assert connection.is_connected is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chat_id,user_id,state",
|
||||
[
|
||||
[12345, 54321, "foo"],
|
||||
[12345, 54321, None],
|
||||
[12345, None, "foo"],
|
||||
[None, 54321, "foo"],
|
||||
],
|
||||
)
|
||||
async def test_set_get_state(self, chat_id, user_id, state, store):
|
||||
await store.reset_state(chat=chat_id, user=user_id, with_data=False)
|
||||
|
||||
await store.set_state(chat=chat_id, user=user_id, state=state)
|
||||
s = await store.get_state(chat=chat_id, user=user_id)
|
||||
assert s == state
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chat_id,user_id,data,new_data",
|
||||
[
|
||||
[12345, 54321, {"foo": "bar"}, {"bar": "foo"}],
|
||||
[12345, 54321, None, None],
|
||||
[12345, 54321, {"foo": "bar"}, None],
|
||||
[12345, 54321, None, {"bar": "foo"}],
|
||||
[12345, None, {"foo": "bar"}, {"bar": "foo"}],
|
||||
[None, 54321, {"foo": "bar"}, {"bar": "foo"}],
|
||||
],
|
||||
)
|
||||
async def test_set_get_update_data(self, chat_id, user_id, data, new_data, store):
|
||||
await store.reset_state(chat=chat_id, user=user_id, with_data=True)
|
||||
|
||||
await store.set_data(chat=chat_id, user=user_id, data=data)
|
||||
d = await store.get_data(chat=chat_id, user=user_id)
|
||||
assert d == (data or {})
|
||||
|
||||
await store.update_data(chat=chat_id, user=user_id, data=new_data)
|
||||
d = await store.get_data(chat=chat_id, user=user_id)
|
||||
updated_data = (data or {})
|
||||
updated_data.update(new_data or {})
|
||||
assert d == updated_data
|
||||
|
||||
async def test_has_bucket(self, store):
|
||||
assert store.has_bucket()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chat_id,user_id,data,new_data",
|
||||
[
|
||||
[12345, 54321, {"foo": "bar"}, {"bar": "foo"}],
|
||||
[12345, 54321, None, None],
|
||||
[12345, 54321, {"foo": "bar"}, None],
|
||||
[12345, 54321, None, {"bar": "foo"}],
|
||||
[12345, None, {"foo": "bar"}, {"bar": "foo"}],
|
||||
[None, 54321, {"foo": "bar"}, {"bar": "foo"}],
|
||||
],
|
||||
)
|
||||
async def test_set_get_update_bucket(self, chat_id, user_id, data, new_data, store):
|
||||
await store.reset_state(chat=chat_id, user=user_id, with_data=True)
|
||||
|
||||
await store.set_bucket(chat=chat_id, user=user_id, bucket=data)
|
||||
d = await store.get_bucket(chat=chat_id, user=user_id)
|
||||
assert d == (data or {})
|
||||
|
||||
await store.update_bucket(chat=chat_id, user=user_id, bucket=new_data)
|
||||
d = await store.get_bucket(chat=chat_id, user=user_id)
|
||||
updated_bucket = (data or {})
|
||||
updated_bucket.update(new_data or {})
|
||||
assert d == updated_bucket
|
||||
Loading…
Add table
Add a link
Reference in a new issue