diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index 0d2516ad..a16da13f 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -5,7 +5,6 @@ import asyncio import typing import rethinkdb as r -from rethinkdb.asyncio_net.net_asyncio import Connection from ...dispatcher import BaseStorage @@ -49,7 +48,7 @@ class RethinkDBStorage(BaseStorage): self._timeout = timeout self._ssl = ssl or {} - self._connection: Connection = None + self._connection: r.Connection = None self._loop = loop or asyncio.get_event_loop() self._lock = asyncio.Lock(loop=self._loop) @@ -72,7 +71,7 @@ class RethinkDBStorage(BaseStorage): del self._connection self._connection = None - async def wait_closed(self): + def wait_closed(self): """ Checks if connection is closed. """ @@ -104,9 +103,10 @@ class RethinkDBStorage(BaseStorage): async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() - await r.table(self._table).insert( - {'id': chat, user: {'data': r.literal(data)}}, - conflict='update').run(conn) + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): @@ -128,16 +128,17 @@ class RethinkDBStorage(BaseStorage): async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() - await r.table(self._table).insert( - {'id': chat, user: {'bucket': r.literal(bucket)}}, - conflict='update').run(conn) + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() await r.table(self._table).insert( - {'id': chat, user: {'data': bucket}}, + {'id': chat, user: {'bucket': bucket}}, conflict='update').run(conn) async def get_states_list(self) -> typing.List[typing.Tuple[int]]: @@ -149,12 +150,13 @@ class RethinkDBStorage(BaseStorage): conn = await self.connection() result = [] - items = await r.table(self._table).get_all().run(conn).items + items = (await r.table(self._table).run(conn)).items for item in items: chat = int(item.pop('id')) - users = int(item.keys()) - result.append((chat, users)) + for key in item.keys(): + user = int(key) + result.append((chat, user)) return result @@ -163,4 +165,4 @@ class RethinkDBStorage(BaseStorage): Reset states in DB """ conn = await self.connection() - await r.table(self._table).get_all().delete().run(conn) + await r.table(self._table).delete().run(conn)