Updated rethinkdb.py for using rethinkdb ver.2.4.1 and higher

Removed a connection pool as an excess
 Removed useless variables and classes
Added rethinkdb>=2.4.1 in dev_requirements.txt
This commit is contained in:
Arslan 'Ars2014' Sakhapov 2019-03-31 19:17:32 +05:00
parent c9595ead28
commit 274969c6a4
2 changed files with 41 additions and 76 deletions

View file

@ -1,23 +1,18 @@
import asyncio
import contextlib
import typing
import weakref
import rethinkdb as r
import rethinkdb
from rethinkdb.asyncio_net.net_asyncio import Connection
from ...dispatcher.storage import BaseStorage
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed']
__all__ = ['RethinkDBStorage']
r = rethinkdb.RethinkDB()
r.set_loop_type('asyncio')
class ConnectionNotClosed(Exception):
"""
Indicates that DB connection wasn't closed.
"""
class RethinkDBStorage(BaseStorage):
"""
RethinkDB-based storage for FSM.
@ -37,8 +32,17 @@ class RethinkDBStorage(BaseStorage):
"""
def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None,
user=None, password=None, timeout=20, ssl=None, max_conn=10, loop=None):
def __init__(self,
host: str = 'localhost',
port: int = 28015,
db: str = 'aiogram',
table: str = 'aiogram',
auth_key: typing.Optional[str] = None,
user: typing.Optional[str] = None,
password: typing.Optional[str] = None,
timeout: int = 20,
ssl: typing.Optional[dict] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None):
self._host = host
self._port = port
self._db = db
@ -48,65 +52,37 @@ class RethinkDBStorage(BaseStorage):
self._password = password
self._timeout = timeout
self._ssl = ssl or {}
self._loop = loop
self._queue = asyncio.Queue(max_conn)
self._outstanding_connections = weakref.WeakSet()
self._loop = loop or asyncio.get_event_loop()
self._conn: typing.Union[Connection, None] = None
async def get_connection(self):
async def connect(self) -> Connection:
"""
Get or create connection.
Get or create a connection.
"""
try:
while True:
conn: r.Connection = self._queue.get_nowait()
if conn.is_open():
break
try:
await conn.close()
except r.ReqlError:
raise ConnectionNotClosed('Exception was caught while closing connection')
except asyncio.QueueEmpty:
if len(self._outstanding_connections) < self._queue.maxsize:
conn = await r.connect(host=self._host, port=self._port, db=self._db,
auth_key=self._auth_key, user=self._user, password=self._password,
timeout=self._timeout, ssl=self._ssl)
else:
conn = await self._queue.get()
self._outstanding_connections.add(conn)
return conn
async def put_connection(self, conn):
"""
Return connection to pool.
"""
self._queue.put_nowait(conn)
self._outstanding_connections.remove(conn)
if self._conn is None:
self._conn = await r.connect(host=self._host,
port=self._port,
db=self._db,
auth_key=self._auth_key,
user=self._user,
password=self._password,
timeout=self._timeout,
ssl=self._ssl,
io_loop=self._loop)
return self._conn
@contextlib.asynccontextmanager
async def connection(self):
conn = await self.get_connection()
conn = await self.connect()
yield conn
await self.put_connection(conn)
async def close(self):
"""
Close all connections.
Close a connection.
"""
while True:
try:
conn: r.Connection = self._queue.get_nowait()
except asyncio.QueueEmpty:
break
self._outstanding_connections.add(conn)
for conn in self._outstanding_connections:
try:
await conn.close()
except r.ReqlError:
raise ConnectionNotClosed('Exception was caught while closing connection')
self._conn.close()
self._conn = None
async def wait_closed(self):
"""
@ -118,24 +94,19 @@ class RethinkDBStorage(BaseStorage):
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
result = await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn)
return result
return await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn)
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
return result
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn)
await r.table(self._table).insert({'id': chat, user: {'state': state}}, conflict="update").run(conn)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
@ -151,10 +122,7 @@ class RethinkDBStorage(BaseStorage):
**kwargs):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
await r.table(self._table).insert({'id': chat, user: {'data': data}}, conflict="update").run(conn)
def has_bucket(self):
return True
@ -163,8 +131,7 @@ class RethinkDBStorage(BaseStorage):
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
return result
return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
@ -180,10 +147,7 @@ class RethinkDBStorage(BaseStorage):
**kwargs):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}, conflict="update").run(conn)
async def get_states_list(self) -> typing.List[typing.Tuple[int, int]]:
"""

View file

@ -16,3 +16,4 @@ sphinx-rtd-theme>=0.3.0
sphinxcontrib-programoutput>=0.11
aresponses>=1.0.0
aiohttp-socks>=0.1.5
rethinkdb>=2.4.1