mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-06 07:50:32 +00:00
Updated rethinkdb.py for using rethinkdb ver.2.4.1 and higher
Removed a connection pool as an excess Removed useless variables and classes Added rethinkdb>=2.4.1 in dev_requirements.txt
This commit is contained in:
parent
c9595ead28
commit
274969c6a4
2 changed files with 41 additions and 76 deletions
|
|
@ -1,23 +1,18 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import typing
|
||||
import weakref
|
||||
|
||||
import rethinkdb as r
|
||||
import rethinkdb
|
||||
from rethinkdb.asyncio_net.net_asyncio import Connection
|
||||
|
||||
from ...dispatcher.storage import BaseStorage
|
||||
|
||||
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed']
|
||||
__all__ = ['RethinkDBStorage']
|
||||
|
||||
r = rethinkdb.RethinkDB()
|
||||
r.set_loop_type('asyncio')
|
||||
|
||||
|
||||
class ConnectionNotClosed(Exception):
|
||||
"""
|
||||
Indicates that DB connection wasn't closed.
|
||||
"""
|
||||
|
||||
|
||||
class RethinkDBStorage(BaseStorage):
|
||||
"""
|
||||
RethinkDB-based storage for FSM.
|
||||
|
|
@ -37,8 +32,17 @@ class RethinkDBStorage(BaseStorage):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None,
|
||||
user=None, password=None, timeout=20, ssl=None, max_conn=10, loop=None):
|
||||
def __init__(self,
|
||||
host: str = 'localhost',
|
||||
port: int = 28015,
|
||||
db: str = 'aiogram',
|
||||
table: str = 'aiogram',
|
||||
auth_key: typing.Optional[str] = None,
|
||||
user: typing.Optional[str] = None,
|
||||
password: typing.Optional[str] = None,
|
||||
timeout: int = 20,
|
||||
ssl: typing.Optional[dict] = None,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._db = db
|
||||
|
|
@ -48,65 +52,37 @@ class RethinkDBStorage(BaseStorage):
|
|||
self._password = password
|
||||
self._timeout = timeout
|
||||
self._ssl = ssl or {}
|
||||
self._loop = loop
|
||||
|
||||
self._queue = asyncio.Queue(max_conn)
|
||||
self._outstanding_connections = weakref.WeakSet()
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._conn: typing.Union[Connection, None] = None
|
||||
|
||||
async def get_connection(self):
|
||||
async def connect(self) -> Connection:
|
||||
"""
|
||||
Get or create connection.
|
||||
Get or create a connection.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
conn: r.Connection = self._queue.get_nowait()
|
||||
if conn.is_open():
|
||||
break
|
||||
try:
|
||||
await conn.close()
|
||||
except r.ReqlError:
|
||||
raise ConnectionNotClosed('Exception was caught while closing connection')
|
||||
except asyncio.QueueEmpty:
|
||||
if len(self._outstanding_connections) < self._queue.maxsize:
|
||||
conn = await r.connect(host=self._host, port=self._port, db=self._db,
|
||||
auth_key=self._auth_key, user=self._user, password=self._password,
|
||||
timeout=self._timeout, ssl=self._ssl)
|
||||
else:
|
||||
conn = await self._queue.get()
|
||||
|
||||
self._outstanding_connections.add(conn)
|
||||
return conn
|
||||
|
||||
async def put_connection(self, conn):
|
||||
"""
|
||||
Return connection to pool.
|
||||
"""
|
||||
self._queue.put_nowait(conn)
|
||||
self._outstanding_connections.remove(conn)
|
||||
if self._conn is None:
|
||||
self._conn = await r.connect(host=self._host,
|
||||
port=self._port,
|
||||
db=self._db,
|
||||
auth_key=self._auth_key,
|
||||
user=self._user,
|
||||
password=self._password,
|
||||
timeout=self._timeout,
|
||||
ssl=self._ssl,
|
||||
io_loop=self._loop)
|
||||
return self._conn
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connection(self):
|
||||
conn = await self.get_connection()
|
||||
conn = await self.connect()
|
||||
yield conn
|
||||
await self.put_connection(conn)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Close all connections.
|
||||
Close a connection.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
conn: r.Connection = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
self._outstanding_connections.add(conn)
|
||||
|
||||
for conn in self._outstanding_connections:
|
||||
try:
|
||||
await conn.close()
|
||||
except r.ReqlError:
|
||||
raise ConnectionNotClosed('Exception was caught while closing connection')
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
async def wait_closed(self):
|
||||
"""
|
||||
|
|
@ -118,24 +94,19 @@ class RethinkDBStorage(BaseStorage):
|
|||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
result = await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn)
|
||||
return result
|
||||
return await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn)
|
||||
|
||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Dict:
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
|
||||
return result
|
||||
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
|
||||
|
||||
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
state: typing.Optional[typing.AnyStr] = None):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
if await r.table(self._table).get(chat).run(conn):
|
||||
await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn)
|
||||
else:
|
||||
await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn)
|
||||
await r.table(self._table).insert({'id': chat, user: {'state': state}}, conflict="update").run(conn)
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None):
|
||||
|
|
@ -151,10 +122,7 @@ class RethinkDBStorage(BaseStorage):
|
|||
**kwargs):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
if await r.table(self._table).get(chat).run(conn):
|
||||
await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn)
|
||||
else:
|
||||
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
|
||||
await r.table(self._table).insert({'id': chat, user: {'data': data}}, conflict="update").run(conn)
|
||||
|
||||
def has_bucket(self):
|
||||
return True
|
||||
|
|
@ -163,8 +131,7 @@ class RethinkDBStorage(BaseStorage):
|
|||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
|
||||
return result
|
||||
return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
|
||||
|
||||
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None):
|
||||
|
|
@ -180,10 +147,7 @@ class RethinkDBStorage(BaseStorage):
|
|||
**kwargs):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
if await r.table(self._table).get(chat).run(conn):
|
||||
await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn)
|
||||
else:
|
||||
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
|
||||
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}, conflict="update").run(conn)
|
||||
|
||||
async def get_states_list(self) -> typing.List[typing.Tuple[int, int]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -16,3 +16,4 @@ sphinx-rtd-theme>=0.3.0
|
|||
sphinxcontrib-programoutput>=0.11
|
||||
aresponses>=1.0.0
|
||||
aiohttp-socks>=0.1.5
|
||||
rethinkdb>=2.4.1
|
||||
Loading…
Add table
Add a link
Reference in a new issue