mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Implemented new RedisStorage2. More faster and with different data structure.
This commit is contained in:
parent
250a636d6e
commit
683afc714e
1 changed files with 202 additions and 1 deletions
|
|
@ -3,6 +3,7 @@ This module has redis storage for finite-state machine based on `aioredis <https
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import aioredis
|
import aioredis
|
||||||
|
|
@ -10,6 +11,10 @@ import aioredis
|
||||||
from ...dispatcher.storage import BaseStorage
|
from ...dispatcher.storage import BaseStorage
|
||||||
from ...utils import json
|
from ...utils import json
|
||||||
|
|
||||||
|
STATE_KEY = 'state'
|
||||||
|
STATE_DATA_KEY = 'data'
|
||||||
|
STATE_BUCKET_KEY = 'bucket'
|
||||||
|
|
||||||
|
|
||||||
class RedisStorage(BaseStorage):
|
class RedisStorage(BaseStorage):
|
||||||
"""
|
"""
|
||||||
|
|
@ -191,4 +196,200 @@ class RedisStorage(BaseStorage):
|
||||||
record = await self.get_record(chat=chat, user=user)
|
record = await self.get_record(chat=chat, user=user)
|
||||||
record_bucket = record.get('bucket', {})
|
record_bucket = record.get('bucket', {})
|
||||||
record_bucket.update(bucket, **kwargs)
|
record_bucket.update(bucket, **kwargs)
|
||||||
await self.set_record(chat=chat, user=user, state=record['state'], data=record_data, bucket=bucket)
|
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisStorage2(BaseStorage):
|
||||||
|
"""
|
||||||
|
Busted Redis-base storage for FSM.
|
||||||
|
Works with Redis connection pool and customizable keys prefix.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
storage = RedisStorage('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key')
|
||||||
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
|
||||||
|
And need to close Redis connection when shutdown
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
await dp.storage.close()
|
||||||
|
await dp.storage.wait_closed()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None,
|
||||||
|
pool_size=10, loop=None, prefix='fsm', **kwargs):
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
self._db = db
|
||||||
|
self._password = password
|
||||||
|
self._ssl = ssl
|
||||||
|
self._pool_size = pool_size
|
||||||
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._prefix = (prefix,)
|
||||||
|
|
||||||
|
self._redis: aioredis.RedisConnection = None
|
||||||
|
self._connection_lock = asyncio.Lock(loop=self._loop)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def redis(self) -> aioredis.Redis:
|
||||||
|
"""
|
||||||
|
Get Redis connection
|
||||||
|
|
||||||
|
This property is awaitable.
|
||||||
|
"""
|
||||||
|
# Use thread-safe asyncio Lock because this method without that is not safe
|
||||||
|
async with self._connection_lock:
|
||||||
|
if self._redis is None:
|
||||||
|
self._redis = await aioredis.create_redis_pool((self._host, self._port),
|
||||||
|
db=self._db, password=self._password, ssl=self._ssl,
|
||||||
|
minsize=1, maxsize=self._pool_size,
|
||||||
|
loop=self._loop, **self._kwargs)
|
||||||
|
return self._redis
|
||||||
|
|
||||||
|
def generate_key(self, *parts):
|
||||||
|
return ':'.join(self._prefix + tuple(map(str, parts)))
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
async with self._connection_lock:
|
||||||
|
if self._redis and not self._redis.closed:
|
||||||
|
self._redis.close()
|
||||||
|
del self._redis
|
||||||
|
self._redis = None
|
||||||
|
|
||||||
|
async def wait_closed(self):
|
||||||
|
async with self._connection_lock:
|
||||||
|
if self._redis:
|
||||||
|
return await self._redis.wait_closed()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
return await redis.get(key, encoding='utf8') or None
|
||||||
|
|
||||||
|
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
raw_result = await redis.get(key, encoding='utf8')
|
||||||
|
if raw_result:
|
||||||
|
return json.loads(raw_result)
|
||||||
|
return default or {}
|
||||||
|
|
||||||
|
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
state: typing.Optional[typing.AnyStr] = None):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
if state is None:
|
||||||
|
await redis.delete(key)
|
||||||
|
else:
|
||||||
|
await redis.set(key, state)
|
||||||
|
|
||||||
|
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
data: typing.Dict = None):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
await redis.set(key, json.dumps(data))
|
||||||
|
|
||||||
|
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
data: typing.Dict = None, **kwargs):
|
||||||
|
temp_data = await self.get_data(chat=chat, user=user, default={})
|
||||||
|
temp_data.update(data, **kwargs)
|
||||||
|
await self.set_data(chat=chat, user=user, data=temp_data)
|
||||||
|
|
||||||
|
def has_bucket(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
raw_result = await redis.get(key, encoding='utf8')
|
||||||
|
if raw_result:
|
||||||
|
return json.loads(raw_result)
|
||||||
|
return default or {}
|
||||||
|
|
||||||
|
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
await redis.set(key, json.dumps(bucket))
|
||||||
|
|
||||||
|
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None, **kwargs):
|
||||||
|
temp_bucket = await self.get_data(chat=chat, user=user)
|
||||||
|
temp_bucket.update(bucket, **kwargs)
|
||||||
|
await self.set_data(chat=chat, user=user, data=temp_bucket)
|
||||||
|
|
||||||
|
async def reset_all(self, full=True):
|
||||||
|
"""
|
||||||
|
Reset states in DB
|
||||||
|
|
||||||
|
:param full: clean DB or clean only states
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
conn = await self.redis
|
||||||
|
|
||||||
|
if full:
|
||||||
|
conn.flushdb()
|
||||||
|
else:
|
||||||
|
keys = await conn.keys(self.generate_key('*'))
|
||||||
|
conn.delete(*keys)
|
||||||
|
|
||||||
|
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
|
||||||
|
"""
|
||||||
|
Get list of all stored chat's and user's
|
||||||
|
|
||||||
|
:return: list of tuples where first element is chat id and second is user id
|
||||||
|
"""
|
||||||
|
conn = await self.redis
|
||||||
|
result = []
|
||||||
|
|
||||||
|
keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8')
|
||||||
|
for item in keys:
|
||||||
|
*_, chat, user, _ = item.split(':')
|
||||||
|
result.append((chat, user))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def import_redis1(self, redis1):
|
||||||
|
await migrate_redis1_to_redis2(redis1, self)
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2):
|
||||||
|
"""
|
||||||
|
Helper for migrating from RedisStorage to RedisStorage2
|
||||||
|
|
||||||
|
:param storage1: instance of RedisStorage
|
||||||
|
:param storage2: instance of RedisStorage2
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert isinstance(storage1, RedisStorage)
|
||||||
|
assert isinstance(storage2, RedisStorage2)
|
||||||
|
|
||||||
|
log = logging.getLogger('aiogram.RedisStorage')
|
||||||
|
|
||||||
|
for chat, user in await storage1.get_states_list():
|
||||||
|
state = await storage1.get_state(chat=chat, user=user)
|
||||||
|
await storage2.set_state(chat=chat, user=user, state=state)
|
||||||
|
|
||||||
|
data = await storage1.get_data(chat=chat, user=user)
|
||||||
|
await storage2.set_data(chat=chat, user=user, data=data)
|
||||||
|
|
||||||
|
bucket = await storage1.get_bucket(chat=chat, user=user)
|
||||||
|
await storage2.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||||
|
|
||||||
|
log.info(f"Migrated user {user} in chat {chat}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue