Using connection pool in RethinkDB driver

This commit is contained in:
Arslan 'Ars2014' Sakhapov 2018-02-17 22:25:23 +05:00
parent 4a9533ded5
commit 99f5a89f70

View file

@ -1,5 +1,6 @@
import asyncio
import typing
import weakref
import rethinkdb as r
@ -47,77 +48,110 @@ class RethinkDBStorage(BaseStorage):
self._timeout = timeout
self._ssl = ssl or {}
self._connection: r.Connection = None
self._queue = asyncio.Queue()
self._outstanding_connections = weakref.WeakSet()
self._loop = loop or asyncio.get_event_loop()
self._lock = asyncio.Lock(loop=self._loop)
async def connection(self):
async def get_connection(self):
"""
Get or create connection.
"""
async with self._lock: # thread-safe
if not self._connection:
self._connection = await r.connect(host=self._host, port=self._port, db=self._db,
auth_key=self._auth_key, user=self._user,
password=self._password, timeout=self._timeout, ssl=self._ssl,
io_loop=self._loop)
return self._connection
try:
while True:
conn: r.Connection = self._queue.get_nowait()
if conn.is_open():
break
try:
await conn.close()
except r.ReqlError:
raise ConnectionNotClosed('Exception was caught while closing connection')
except asyncio.QueueEmpty:
conn = await r.connect(host=self._host, port=self._port, db=self._db,
auth_key=self._auth_key, user=self._user, password=self._password, timeout=self._timeout,
ssl=self._ssl)
self._outstanding_connections.add(conn)
return conn
async def put_connection(self, conn):
"""
Return connection to pool.
"""
self._queue.put_nowait(conn)
self._outstanding_connections.remove(conn)
async def close(self):
"""
Close connection.
Close all connections.
"""
if self._connection and self._connection.is_open():
await self._connection.close()
self._connection = None
while True:
try:
conn: r.Connection = self._queue.get_nowait()
except asyncio.QueueEmpty:
break
self._outstanding_connections.add(conn)
for conn in self._outstanding_connections:
try:
await conn.close()
except r.ReqlError:
raise ConnectionNotClosed('Exception was caught while closing connection')
def wait_closed(self):
"""
Checks if connection is closed.
"""
if self._connection:
if len(self._outstanding_connections) != 0 and self._queue.qsize() != 0:
raise ConnectionNotClosed
return True
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
return await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn)
conn = await self.get_connection()
result = await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn)
await self.put_connection(conn)
return result
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
conn = await self.get_connection()
result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
await self.put_connection(conn)
return result
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
conn = await self.get_connection()
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn)
await self.put_connection(conn)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
conn = await self.get_connection()
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
await self.put_connection(conn)
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs):
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
conn = await self.get_connection()
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
await self.put_connection(conn)
def has_bucket(self):
return True
@ -125,27 +159,31 @@ class RethinkDBStorage(BaseStorage):
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
conn = await self.get_connection()
result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
await self.put_connection(conn)
return result
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
conn = await self.get_connection()
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
await self.put_connection(conn)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None, bucket: typing.Dict = None,
**kwargs):
chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection()
conn = await self.get_connection()
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
await self.put_connection(conn)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
"""
@ -153,7 +191,7 @@ class RethinkDBStorage(BaseStorage):
:return: list of tuples where first element is chat id and second is user id
"""
conn = await self.connection()
conn = await self.get_connection()
result = []
items = (await r.table(self._table).run(conn)).items
@ -164,11 +202,14 @@ class RethinkDBStorage(BaseStorage):
user = int(key)
result.append((chat, user))
await self.put_connection(conn)
return result
async def reset_all(self):
"""
Reset states in DB
"""
conn = await self.connection()
conn = await self.get_connection()
await r.table(self._table).delete().run(conn)
await self.put_connection(conn)