diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 30caa6e7..ecf81afa 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -3,6 +3,7 @@ This module has redis storage for finite-state machine based on `aioredis aioredis.Redis: + """ + Get Redis connection + + This property is awaitable. + """ + # Use thread-safe asyncio Lock because this method without that is not safe + async with self._connection_lock: + if self._redis is None: + self._redis = await aioredis.create_redis_pool((self._host, self._port), + db=self._db, password=self._password, ssl=self._ssl, + minsize=1, maxsize=self._pool_size, + loop=self._loop, **self._kwargs) + return self._redis + + def generate_key(self, *parts): + return ':'.join(self._prefix + tuple(map(str, parts))) + + async def close(self): + async with self._connection_lock: + if self._redis and not self._redis.closed: + self._redis.close() + del self._redis + self._redis = None + + async def wait_closed(self): + async with self._connection_lock: + if self._redis: + return await self._redis.wait_closed() + return True + + async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Optional[str]: + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_KEY) + redis = await self.redis + return await redis.get(key, encoding='utf8') or None + + async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[dict] = None) -> typing.Dict: + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_DATA_KEY) + redis = await self.redis + raw_result = await redis.get(key, encoding='utf8') + if raw_result: + return json.loads(raw_result) + return default or {} + + async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + state: typing.Optional[typing.AnyStr] = None): + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_KEY) + redis = await self.redis + if state is None: + await redis.delete(key) + else: + await redis.set(key, state) + + async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + data: typing.Dict = None): + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_DATA_KEY) + redis = await self.redis + await redis.set(key, json.dumps(data)) + + async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + data: typing.Dict = None, **kwargs): + temp_data = await self.get_data(chat=chat, user=user, default={}) + temp_data.update(data, **kwargs) + await self.set_data(chat=chat, user=user, data=temp_data) + + def has_bucket(self): + return True + + async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[dict] = None) -> typing.Dict: + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_BUCKET_KEY) + redis = await self.redis + raw_result = await redis.get(key, encoding='utf8') + if raw_result: + return json.loads(raw_result) + return default or {} + + async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None): + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_BUCKET_KEY) + redis = await self.redis + await redis.set(key, json.dumps(bucket)) + + async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None, **kwargs): + temp_bucket = await self.get_data(chat=chat, user=user) + temp_bucket.update(bucket, **kwargs) + await self.set_data(chat=chat, user=user, data=temp_bucket) + + async def reset_all(self, full=True): + """ + Reset states in DB + + :param full: clean DB or clean only states + :return: + """ + conn = await self.redis + + if full: + conn.flushdb() + else: + keys = await conn.keys(self.generate_key('*')) + conn.delete(*keys) + + async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + """ + Get list of all stored chat's and user's + + :return: list of tuples where first element is chat id and second is user id + """ + conn = await self.redis + result = [] + + keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8') + for item in keys: + *_, chat, user, _ = item.split(':') + result.append((chat, user)) + + return result + + async def import_redis1(self, redis1): + await migrate_redis1_to_redis2(redis1, self) + + +async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2): + """ + Helper for migrating from RedisStorage to RedisStorage2 + + :param storage1: instance of RedisStorage + :param storage2: instance of RedisStorage2 + :return: + """ + assert isinstance(storage1, RedisStorage) + assert isinstance(storage2, RedisStorage2) + + log = logging.getLogger('aiogram.RedisStorage') + + for chat, user in await storage1.get_states_list(): + state = await storage1.get_state(chat=chat, user=user) + await storage2.set_state(chat=chat, user=user, state=state) + + data = await storage1.get_data(chat=chat, user=user) + await storage2.set_data(chat=chat, user=user, data=data) + + bucket = await storage1.get_bucket(chat=chat, user=user) + await storage2.set_bucket(chat=chat, user=user, bucket=bucket) + + log.info(f"Migrated user {user} in chat {chat}")