fix redis pool connection closing

- add locks in closing methods of RedisStorage
- don't remove reference to the pool in the `close()` method
This commit is contained in:
Bachynin Ivan 2020-02-12 19:10:34 +02:00
parent 20ba5faf5c
commit cffec23371

View file

@ -44,19 +44,19 @@ class RedisStorage(BaseStorage):
self._loop = loop or asyncio.get_event_loop()
self._kwargs = kwargs
self._redis: aioredis.RedisConnection = None
self._redis: typing.Optional[aioredis.RedisConnection] = None
self._connection_lock = asyncio.Lock(loop=self._loop)
async def close(self):
if self._redis and not self._redis.closed:
self._redis.close()
del self._redis
self._redis = None
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()
async def wait_closed(self):
if self._redis:
return await self._redis.wait_closed()
return True
async with self._connection_lock:
if self._redis:
return await self._redis.wait_closed()
return True
async def redis(self) -> aioredis.RedisConnection:
"""
@ -64,7 +64,7 @@ class RedisStorage(BaseStorage):
"""
# Use thread-safe asyncio Lock because this method without that is not safe
async with self._connection_lock:
if self._redis is None:
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_connection((self._host, self._port),
db=self._db, password=self._password, ssl=self._ssl,
loop=self._loop,
@ -144,7 +144,7 @@ class RedisStorage(BaseStorage):
record_data.update(data, **kwargs)
await self.set_record(chat=chat, user=user, state=record['state'], data=record_data)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
"""
Get list of all stored chat's and user's
@ -220,11 +220,11 @@ class RedisStorage2(BaseStorage):
"""
def __init__(self, host: str = 'localhost', port=6379, db=None, password=None,
ssl=None, pool_size=10, loop=None, prefix='fsm',
state_ttl: int = 0,
data_ttl: int = 0,
bucket_ttl: int = 0,
**kwargs):
ssl=None, pool_size=10, loop=None, prefix='fsm',
state_ttl: int = 0,
data_ttl: int = 0,
bucket_ttl: int = 0,
**kwargs):
self._host = host
self._port = port
self._db = db
@ -239,7 +239,7 @@ class RedisStorage2(BaseStorage):
self._data_ttl = data_ttl
self._bucket_ttl = bucket_ttl
self._redis: aioredis.RedisConnection = None
self._redis: typing.Optional[aioredis.RedisConnection] = None
self._connection_lock = asyncio.Lock(loop=self._loop)
async def redis(self) -> aioredis.Redis:
@ -248,7 +248,7 @@ class RedisStorage2(BaseStorage):
"""
# Use thread-safe asyncio Lock because this method without that is not safe
async with self._connection_lock:
if self._redis is None:
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_redis_pool((self._host, self._port),
db=self._db, password=self._password, ssl=self._ssl,
minsize=1, maxsize=self._pool_size,
@ -262,8 +262,6 @@ class RedisStorage2(BaseStorage):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()
del self._redis
self._redis = None
async def wait_closed(self):
async with self._connection_lock:
@ -357,7 +355,7 @@ class RedisStorage2(BaseStorage):
keys = await conn.keys(self.generate_key('*'))
await conn.delete(*keys)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
"""
Get list of all stored chat's and user's