From cffec23371c937366f0fa122757cd68480881e2c Mon Sep 17 00:00:00 2001 From: Bachynin Ivan Date: Wed, 12 Feb 2020 19:10:34 +0200 Subject: [PATCH] fix redis pool connection closing - add locks in closing methods of RedisStorage - don't remove reference to the pool in the `close()` method --- aiogram/contrib/fsm_storage/redis.py | 38 +++++++++++++--------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 106a7b97..bf88eff7 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -44,19 +44,19 @@ class RedisStorage(BaseStorage): self._loop = loop or asyncio.get_event_loop() self._kwargs = kwargs - self._redis: aioredis.RedisConnection = None + self._redis: typing.Optional[aioredis.RedisConnection] = None self._connection_lock = asyncio.Lock(loop=self._loop) async def close(self): - if self._redis and not self._redis.closed: - self._redis.close() - del self._redis - self._redis = None + async with self._connection_lock: + if self._redis and not self._redis.closed: + self._redis.close() async def wait_closed(self): - if self._redis: - return await self._redis.wait_closed() - return True + async with self._connection_lock: + if self._redis: + return await self._redis.wait_closed() + return True async def redis(self) -> aioredis.RedisConnection: """ @@ -64,7 +64,7 @@ class RedisStorage(BaseStorage): """ # Use thread-safe asyncio Lock because this method without that is not safe async with self._connection_lock: - if self._redis is None: + if self._redis is None or self._redis.closed: self._redis = await aioredis.create_connection((self._host, self._port), db=self._db, password=self._password, ssl=self._ssl, loop=self._loop, @@ -144,7 +144,7 @@ class RedisStorage(BaseStorage): record_data.update(data, **kwargs) await self.set_record(chat=chat, user=user, state=record['state'], data=record_data) - async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ Get list of all stored chat's and user's @@ -220,11 +220,11 @@ class RedisStorage2(BaseStorage): """ def __init__(self, host: str = 'localhost', port=6379, db=None, password=None, - ssl=None, pool_size=10, loop=None, prefix='fsm', - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, - **kwargs): + ssl=None, pool_size=10, loop=None, prefix='fsm', + state_ttl: int = 0, + data_ttl: int = 0, + bucket_ttl: int = 0, + **kwargs): self._host = host self._port = port self._db = db @@ -239,7 +239,7 @@ class RedisStorage2(BaseStorage): self._data_ttl = data_ttl self._bucket_ttl = bucket_ttl - self._redis: aioredis.RedisConnection = None + self._redis: typing.Optional[aioredis.RedisConnection] = None self._connection_lock = asyncio.Lock(loop=self._loop) async def redis(self) -> aioredis.Redis: @@ -248,7 +248,7 @@ class RedisStorage2(BaseStorage): """ # Use thread-safe asyncio Lock because this method without that is not safe async with self._connection_lock: - if self._redis is None: + if self._redis is None or self._redis.closed: self._redis = await aioredis.create_redis_pool((self._host, self._port), db=self._db, password=self._password, ssl=self._ssl, minsize=1, maxsize=self._pool_size, @@ -262,8 +262,6 @@ class RedisStorage2(BaseStorage): async with self._connection_lock: if self._redis and not self._redis.closed: self._redis.close() - del self._redis - self._redis = None async def wait_closed(self): async with self._connection_lock: @@ -357,7 +355,7 @@ class RedisStorage2(BaseStorage): keys = await conn.keys(self.generate_key('*')) await conn.delete(*keys) - async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ Get list of all stored chat's and user's