mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
Merge pull request #15 from Ars2014/dev-1.x
Using connection pool in RethinkDB driver
This commit is contained in:
commit
4dcccaae36
1 changed files with 74 additions and 30 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
|
import weakref
|
||||||
|
|
||||||
import rethinkdb as r
|
import rethinkdb as r
|
||||||
|
|
||||||
|
|
@ -36,7 +37,7 @@ class RethinkDBStorage(BaseStorage):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None,
|
def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None,
|
||||||
user=None, password=None, timeout=20, ssl=None, loop=None):
|
user=None, password=None, timeout=20, ssl=None, max_conn=10, loop=None):
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
self._db = db
|
self._db = db
|
||||||
|
|
@ -47,77 +48,113 @@ class RethinkDBStorage(BaseStorage):
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._ssl = ssl or {}
|
self._ssl = ssl or {}
|
||||||
|
|
||||||
self._connection: r.Connection = None
|
self._queue = asyncio.Queue(max_conn)
|
||||||
|
self._outstanding_connections = weakref.WeakSet()
|
||||||
self._loop = loop or asyncio.get_event_loop()
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
self._lock = asyncio.Lock(loop=self._loop)
|
|
||||||
|
|
||||||
async def connection(self):
|
async def get_connection(self):
|
||||||
"""
|
"""
|
||||||
Get or create connection.
|
Get or create connection.
|
||||||
"""
|
"""
|
||||||
async with self._lock: # thread-safe
|
try:
|
||||||
if not self._connection:
|
while True:
|
||||||
self._connection = await r.connect(host=self._host, port=self._port, db=self._db,
|
conn: r.Connection = self._queue.get_nowait()
|
||||||
auth_key=self._auth_key, user=self._user,
|
if conn.is_open():
|
||||||
password=self._password, timeout=self._timeout, ssl=self._ssl,
|
break
|
||||||
io_loop=self._loop)
|
try:
|
||||||
return self._connection
|
await conn.close()
|
||||||
|
except r.ReqlError:
|
||||||
|
raise ConnectionNotClosed('Exception was caught while closing connection')
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
if len(self._outstanding_connections) < self._queue.maxsize:
|
||||||
|
conn = await r.connect(host=self._host, port=self._port, db=self._db,
|
||||||
|
auth_key=self._auth_key, user=self._user, password=self._password,
|
||||||
|
timeout=self._timeout, ssl=self._ssl)
|
||||||
|
else:
|
||||||
|
conn = await self._queue.get()
|
||||||
|
|
||||||
|
self._outstanding_connections.add(conn)
|
||||||
|
return conn
|
||||||
|
|
||||||
|
async def put_connection(self, conn):
|
||||||
|
"""
|
||||||
|
Return connection to pool.
|
||||||
|
"""
|
||||||
|
self._queue.put_nowait(conn)
|
||||||
|
self._outstanding_connections.remove(conn)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""
|
"""
|
||||||
Close connection.
|
Close all connections.
|
||||||
"""
|
"""
|
||||||
if self._connection and self._connection.is_open():
|
while True:
|
||||||
await self._connection.close()
|
try:
|
||||||
self._connection = None
|
conn: r.Connection = self._queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
|
||||||
|
self._outstanding_connections.add(conn)
|
||||||
|
|
||||||
|
for conn in self._outstanding_connections:
|
||||||
|
try:
|
||||||
|
await conn.close()
|
||||||
|
except r.ReqlError:
|
||||||
|
raise ConnectionNotClosed('Exception was caught while closing connection')
|
||||||
|
|
||||||
def wait_closed(self):
|
def wait_closed(self):
|
||||||
"""
|
"""
|
||||||
Checks if connection is closed.
|
Checks if connection is closed.
|
||||||
"""
|
"""
|
||||||
if self._connection:
|
if len(self._outstanding_connections) != 0 and self._queue.qsize() != 0:
|
||||||
raise ConnectionNotClosed
|
raise ConnectionNotClosed
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
return await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn)
|
result = await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
return result
|
||||||
|
|
||||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
default: typing.Optional[str] = None) -> typing.Dict:
|
default: typing.Optional[str] = None) -> typing.Dict:
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
|
result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
return result
|
||||||
|
|
||||||
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
state: typing.Optional[typing.AnyStr] = None):
|
state: typing.Optional[typing.AnyStr] = None):
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
if await r.table(self._table).get(chat).run(conn):
|
if await r.table(self._table).get(chat).run(conn):
|
||||||
await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn)
|
await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn)
|
||||||
else:
|
else:
|
||||||
await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn)
|
await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
|
||||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
data: typing.Dict = None):
|
data: typing.Dict = None):
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
if await r.table(self._table).get(chat).run(conn):
|
if await r.table(self._table).get(chat).run(conn):
|
||||||
await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn)
|
await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn)
|
||||||
else:
|
else:
|
||||||
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
|
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
|
||||||
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
data: typing.Dict = None,
|
data: typing.Dict = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
if await r.table(self._table).get(chat).run(conn):
|
if await r.table(self._table).get(chat).run(conn):
|
||||||
await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn)
|
await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn)
|
||||||
else:
|
else:
|
||||||
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
|
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
|
||||||
def has_bucket(self):
|
def has_bucket(self):
|
||||||
return True
|
return True
|
||||||
|
|
@ -125,35 +162,39 @@ class RethinkDBStorage(BaseStorage):
|
||||||
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
|
result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
return result
|
||||||
|
|
||||||
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
bucket: typing.Dict = None):
|
bucket: typing.Dict = None):
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
if await r.table(self._table).get(chat).run(conn):
|
if await r.table(self._table).get(chat).run(conn):
|
||||||
await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn)
|
await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn)
|
||||||
else:
|
else:
|
||||||
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
|
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
|
||||||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||||
user: typing.Union[str, int, None] = None, bucket: typing.Dict = None,
|
user: typing.Union[str, int, None] = None, bucket: typing.Dict = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
if await r.table(self._table).get(chat).run(conn):
|
if await r.table(self._table).get(chat).run(conn):
|
||||||
await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn)
|
await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn)
|
||||||
else:
|
else:
|
||||||
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
|
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
|
||||||
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
|
async def get_states_list(self) -> typing.List[typing.Tuple[int, int]]:
|
||||||
"""
|
"""
|
||||||
Get list of all stored chat's and user's
|
Get list of all stored chat's and user's
|
||||||
|
|
||||||
:return: list of tuples where first element is chat id and second is user id
|
:return: list of tuples where first element is chat id and second is user id
|
||||||
"""
|
"""
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
items = (await r.table(self._table).run(conn)).items
|
items = (await r.table(self._table).run(conn)).items
|
||||||
|
|
@ -164,11 +205,14 @@ class RethinkDBStorage(BaseStorage):
|
||||||
user = int(key)
|
user = int(key)
|
||||||
result.append((chat, user))
|
result.append((chat, user))
|
||||||
|
|
||||||
|
await self.put_connection(conn)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def reset_all(self):
|
async def reset_all(self):
|
||||||
"""
|
"""
|
||||||
Reset states in DB
|
Reset states in DB
|
||||||
"""
|
"""
|
||||||
conn = await self.connection()
|
conn = await self.get_connection()
|
||||||
await r.table(self._table).delete().run(conn)
|
await r.table(self._table).delete().run(conn)
|
||||||
|
await self.put_connection(conn)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue