Fix errors

This commit is contained in:
Arslan 'Ars2014' Sakhapov 2018-01-10 15:54:15 +05:00
parent afde20eccd
commit efeac6b923

View file

@ -5,7 +5,6 @@ import asyncio
import typing import typing
import rethinkdb as r import rethinkdb as r
from rethinkdb.asyncio_net.net_asyncio import Connection
from ...dispatcher import BaseStorage from ...dispatcher import BaseStorage
@ -49,7 +48,7 @@ class RethinkDBStorage(BaseStorage):
self._timeout = timeout self._timeout = timeout
self._ssl = ssl or {} self._ssl = ssl or {}
self._connection: Connection = None self._connection: r.Connection = None
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._lock = asyncio.Lock(loop=self._loop) self._lock = asyncio.Lock(loop=self._loop)
@ -72,7 +71,7 @@ class RethinkDBStorage(BaseStorage):
del self._connection del self._connection
self._connection = None self._connection = None
async def wait_closed(self): def wait_closed(self):
""" """
Checks if connection is closed. Checks if connection is closed.
""" """
@ -104,9 +103,10 @@ class RethinkDBStorage(BaseStorage):
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None):
chat, user = map(str, self.check_address(chat=chat, user=user)) chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection() conn = await self.connection()
await r.table(self._table).insert( if await r.table(self._table).get(chat).run(conn):
{'id': chat, user: {'data': r.literal(data)}}, await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn)
conflict='update').run(conn) else:
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None,
**kwargs): **kwargs):
@ -128,16 +128,17 @@ class RethinkDBStorage(BaseStorage):
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None):
chat, user = map(str, self.check_address(chat=chat, user=user)) chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection() conn = await self.connection()
await r.table(self._table).insert( if await r.table(self._table).get(chat).run(conn):
{'id': chat, user: {'bucket': r.literal(bucket)}}, await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn)
conflict='update').run(conn) else:
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None,
**kwargs): **kwargs):
chat, user = map(str, self.check_address(chat=chat, user=user)) chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection() conn = await self.connection()
await r.table(self._table).insert( await r.table(self._table).insert(
{'id': chat, user: {'data': bucket}}, {'id': chat, user: {'bucket': bucket}},
conflict='update').run(conn) conflict='update').run(conn)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]: async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
@ -149,12 +150,13 @@ class RethinkDBStorage(BaseStorage):
conn = await self.connection() conn = await self.connection()
result = [] result = []
items = await r.table(self._table).get_all().run(conn).items items = (await r.table(self._table).run(conn)).items
for item in items: for item in items:
chat = int(item.pop('id')) chat = int(item.pop('id'))
users = int(item.keys()) for key in item.keys():
result.append((chat, users)) user = int(key)
result.append((chat, user))
return result return result
@ -163,4 +165,4 @@ class RethinkDBStorage(BaseStorage):
Reset states in DB Reset states in DB
""" """
conn = await self.connection() conn = await self.connection()
await r.table(self._table).get_all().delete().run(conn) await r.table(self._table).delete().run(conn)