Small changes.

This commit is contained in:
Alex Root Junior 2018-01-27 07:45:46 +02:00
parent f874310965
commit ba44ca67fa
2 changed files with 18 additions and 11 deletions

View file

@ -1,7 +1,9 @@
from . import api
from .base import BaseBot from .base import BaseBot
from .bot import Bot from .bot import Bot
__all__ = [ __all__ = [
'BaseBot', 'BaseBot',
'Bot' 'Bot',
'api'
] ]

View file

@ -1,6 +1,3 @@
# -*- coding:utf-8; -*-
__all__ = ['RethinkDBStorage']
import asyncio import asyncio
import typing import typing
@ -8,6 +5,7 @@ import rethinkdb as r
from ...dispatcher import BaseStorage from ...dispatcher import BaseStorage
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed']
r.set_loop_type('asyncio') r.set_loop_type('asyncio')
@ -36,6 +34,7 @@ class RethinkDBStorage(BaseStorage):
await storage.close() await storage.close()
""" """
def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None, def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None,
user=None, password=None, timeout=20, ssl=None, loop=None): user=None, password=None, timeout=20, ssl=None, loop=None):
self._host = host self._host = host
@ -58,8 +57,10 @@ class RethinkDBStorage(BaseStorage):
""" """
async with self._lock: # thread-safe async with self._lock: # thread-safe
if not self._connection: if not self._connection:
self._connection = await r.connect(host=self._host, port=self._port, db=self._db, auth_key=self._auth_key, user=self._user, self._connection = await r.connect(host=self._host, port=self._port, db=self._db,
password=self._password, timeout=self._timeout, ssl=self._ssl, io_loop=self._loop) auth_key=self._auth_key, user=self._user,
password=self._password, timeout=self._timeout, ssl=self._ssl,
io_loop=self._loop)
return self._connection return self._connection
async def close(self): async def close(self):
@ -99,7 +100,8 @@ class RethinkDBStorage(BaseStorage):
else: else:
await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn) await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
chat, user = map(str, self.check_address(chat=chat, user=user)) chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection() conn = await self.connection()
if await r.table(self._table).get(chat).run(conn): if await r.table(self._table).get(chat).run(conn):
@ -107,7 +109,8 @@ class RethinkDBStorage(BaseStorage):
else: else:
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs): **kwargs):
chat, user = map(str, self.check_address(chat=chat, user=user)) chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection() conn = await self.connection()
@ -125,7 +128,8 @@ class RethinkDBStorage(BaseStorage):
conn = await self.connection() conn = await self.connection()
return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
chat, user = map(str, self.check_address(chat=chat, user=user)) chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection() conn = await self.connection()
if await r.table(self._table).get(chat).run(conn): if await r.table(self._table).get(chat).run(conn):
@ -133,7 +137,8 @@ class RethinkDBStorage(BaseStorage):
else: else:
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None, bucket: typing.Dict = None,
**kwargs): **kwargs):
chat, user = map(str, self.check_address(chat=chat, user=user)) chat, user = map(str, self.check_address(chat=chat, user=user))
conn = await self.connection() conn = await self.connection()