diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index cb84a59f..8c6d24ae 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -1,5 +1,6 @@ import asyncio import typing +import weakref import rethinkdb as r @@ -36,7 +37,7 @@ class RethinkDBStorage(BaseStorage): """ def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None, - user=None, password=None, timeout=20, ssl=None, loop=None): + user=None, password=None, timeout=20, ssl=None, max_conn=10, loop=None): self._host = host self._port = port self._db = db @@ -47,77 +48,113 @@ class RethinkDBStorage(BaseStorage): self._timeout = timeout self._ssl = ssl or {} - self._connection: r.Connection = None + self._queue = asyncio.Queue(max_conn) + self._outstanding_connections = weakref.WeakSet() self._loop = loop or asyncio.get_event_loop() - self._lock = asyncio.Lock(loop=self._loop) - async def connection(self): + async def get_connection(self): """ Get or create connection. """ - async with self._lock: # thread-safe - if not self._connection: - self._connection = await r.connect(host=self._host, port=self._port, db=self._db, - auth_key=self._auth_key, user=self._user, - password=self._password, timeout=self._timeout, ssl=self._ssl, - io_loop=self._loop) - return self._connection + try: + while True: + conn: r.Connection = self._queue.get_nowait() + if conn.is_open(): + break + try: + await conn.close() + except r.ReqlError: + raise ConnectionNotClosed('Exception was caught while closing connection') + except asyncio.QueueEmpty: + if len(self._outstanding_connections) < self._queue.maxsize: + conn = await r.connect(host=self._host, port=self._port, db=self._db, + auth_key=self._auth_key, user=self._user, password=self._password, + timeout=self._timeout, ssl=self._ssl) + else: + conn = await self._queue.get() + + self._outstanding_connections.add(conn) + return conn + + async def put_connection(self, conn): + """ + Return connection to pool. + """ + self._queue.put_nowait(conn) + self._outstanding_connections.remove(conn) async def close(self): """ - Close connection. + Close all connections. """ - if self._connection and self._connection.is_open(): - await self._connection.close() - self._connection = None + while True: + try: + conn: r.Connection = self._queue.get_nowait() + except asyncio.QueueEmpty: + break + + self._outstanding_connections.add(conn) + + for conn in self._outstanding_connections: + try: + await conn.close() + except r.ReqlError: + raise ConnectionNotClosed('Exception was caught while closing connection') def wait_closed(self): """ Checks if connection is closed. """ - if self._connection: + if len(self._outstanding_connections) != 0 and self._queue.qsize() != 0: raise ConnectionNotClosed return True async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() - return await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn) + conn = await self.get_connection() + result = await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn) + await self.put_connection(conn) + return result async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Dict: chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() - return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn) + conn = await self.get_connection() + result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn) + await self.put_connection(conn) + return result async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, state: typing.Optional[typing.AnyStr] = None): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() + conn = await self.get_connection() if await r.table(self._table).get(chat).run(conn): await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn) else: await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn) + await self.put_connection(conn) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() + conn = await self.get_connection() if await r.table(self._table).get(chat).run(conn): await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn) else: await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) + await self.put_connection(conn) async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() + conn = await self.get_connection() if await r.table(self._table).get(chat).run(conn): await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn) else: await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) + await self.put_connection(conn) def has_bucket(self): return True @@ -125,35 +162,39 @@ class RethinkDBStorage(BaseStorage): async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() - return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) + conn = await self.get_connection() + result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) + await self.put_connection(conn) + return result async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() + conn = await self.get_connection() if await r.table(self._table).get(chat).run(conn): await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn) else: await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) + await self.put_connection(conn) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) - conn = await self.connection() + conn = await self.get_connection() if await r.table(self._table).get(chat).run(conn): await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn) else: await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) + await self.put_connection(conn) - async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + async def get_states_list(self) -> typing.List[typing.Tuple[int, int]]: """ Get list of all stored chat's and user's :return: list of tuples where first element is chat id and second is user id """ - conn = await self.connection() + conn = await self.get_connection() result = [] items = (await r.table(self._table).run(conn)).items @@ -164,11 +205,14 @@ class RethinkDBStorage(BaseStorage): user = int(key) result.append((chat, user)) + await self.put_connection(conn) + return result async def reset_all(self): """ Reset states in DB """ - conn = await self.connection() + conn = await self.get_connection() await r.table(self._table).delete().run(conn) + await self.put_connection(conn)