Implemented new RedisStorage2. More faster and with different data structure.

This commit is contained in:
Alex Root Junior 2017-12-02 22:42:05 +02:00
parent 250a636d6e
commit 683afc714e

View file

@ -3,6 +3,7 @@ This module has redis storage for finite-state machine based on `aioredis <https
"""
import asyncio
import logging
import typing
import aioredis
@ -10,6 +11,10 @@ import aioredis
from ...dispatcher.storage import BaseStorage
from ...utils import json
STATE_KEY = 'state'
STATE_DATA_KEY = 'data'
STATE_BUCKET_KEY = 'bucket'
class RedisStorage(BaseStorage):
"""
@ -191,4 +196,200 @@ class RedisStorage(BaseStorage):
record = await self.get_record(chat=chat, user=user)
record_bucket = record.get('bucket', {})
record_bucket.update(bucket, **kwargs)
await self.set_record(chat=chat, user=user, state=record['state'], data=record_data, bucket=bucket)
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)
class RedisStorage2(BaseStorage):
"""
Busted Redis-base storage for FSM.
Works with Redis connection pool and customizable keys prefix.
Usage:
.. code-block:: python3
storage = RedisStorage('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key')
dp = Dispatcher(bot, storage=storage)
And need to close Redis connection when shutdown
.. code-block:: python3
await dp.storage.close()
await dp.storage.wait_closed()
"""
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None,
pool_size=10, loop=None, prefix='fsm', **kwargs):
self._host = host
self._port = port
self._db = db
self._password = password
self._ssl = ssl
self._pool_size = pool_size
self._loop = loop or asyncio.get_event_loop()
self._kwargs = kwargs
self._prefix = (prefix,)
self._redis: aioredis.RedisConnection = None
self._connection_lock = asyncio.Lock(loop=self._loop)
@property
async def redis(self) -> aioredis.Redis:
"""
Get Redis connection
This property is awaitable.
"""
# Use thread-safe asyncio Lock because this method without that is not safe
async with self._connection_lock:
if self._redis is None:
self._redis = await aioredis.create_redis_pool((self._host, self._port),
db=self._db, password=self._password, ssl=self._ssl,
minsize=1, maxsize=self._pool_size,
loop=self._loop, **self._kwargs)
return self._redis
def generate_key(self, *parts):
return ':'.join(self._prefix + tuple(map(str, parts)))
async def close(self):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()
del self._redis
self._redis = None
async def wait_closed(self):
async with self._connection_lock:
if self._redis:
return await self._redis.wait_closed()
return True
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis
return await redis.get(key, encoding='utf8') or None
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self.redis
raw_result = await redis.get(key, encoding='utf8')
if raw_result:
return json.loads(raw_result)
return default or {}
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis
if state is None:
await redis.delete(key)
else:
await redis.set(key, state)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self.redis
await redis.set(key, json.dumps(data))
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
temp_data = await self.get_data(chat=chat, user=user, default={})
temp_data.update(data, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_data)
def has_bucket(self):
return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self.redis
raw_result = await redis.get(key, encoding='utf8')
if raw_result:
return json.loads(raw_result)
return default or {}
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self.redis
await redis.set(key, json.dumps(bucket))
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
temp_bucket = await self.get_data(chat=chat, user=user)
temp_bucket.update(bucket, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_bucket)
async def reset_all(self, full=True):
"""
Reset states in DB
:param full: clean DB or clean only states
:return:
"""
conn = await self.redis
if full:
conn.flushdb()
else:
keys = await conn.keys(self.generate_key('*'))
conn.delete(*keys)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
"""
Get list of all stored chat's and user's
:return: list of tuples where first element is chat id and second is user id
"""
conn = await self.redis
result = []
keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8')
for item in keys:
*_, chat, user, _ = item.split(':')
result.append((chat, user))
return result
async def import_redis1(self, redis1):
await migrate_redis1_to_redis2(redis1, self)
async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2):
"""
Helper for migrating from RedisStorage to RedisStorage2
:param storage1: instance of RedisStorage
:param storage2: instance of RedisStorage2
:return:
"""
assert isinstance(storage1, RedisStorage)
assert isinstance(storage2, RedisStorage2)
log = logging.getLogger('aiogram.RedisStorage')
for chat, user in await storage1.get_states_list():
state = await storage1.get_state(chat=chat, user=user)
await storage2.set_state(chat=chat, user=user, state=state)
data = await storage1.get_data(chat=chat, user=user)
await storage2.set_data(chat=chat, user=user, data=data)
bucket = await storage1.get_bucket(chat=chat, user=user)
await storage2.set_bucket(chat=chat, user=user, bucket=bucket)
log.info(f"Migrated user {user} in chat {chat}")