mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-06 07:50:32 +00:00
Add RethinkDB-based storage for FSM
This commit is contained in:
parent
9de31422eb
commit
f321c69f46
3 changed files with 147 additions and 2 deletions
|
|
@ -377,8 +377,10 @@ async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorag
|
|||
:param storage2: instance of RedisStorage2
|
||||
:return:
|
||||
"""
|
||||
assert isinstance(storage1, RedisStorage)
|
||||
assert isinstance(storage2, RedisStorage2)
|
||||
if not isinstance(storage1, RedisStorage): # better than assertion
|
||||
raise TypeError(f'{type(storage1)} is not RedisStorage instance.')
|
||||
if not isinstance(storage2, RedisStorage):
|
||||
raise TypeError(f'{type(storage2)} is not RedisStorage instance.')
|
||||
|
||||
log = logging.getLogger('aiogram.RedisStorage')
|
||||
|
||||
|
|
|
|||
142
aiogram/contrib/fsm_storage/rethinkdb.py
Normal file
142
aiogram/contrib/fsm_storage/rethinkdb.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
# -*- coding:utf-8; -*-
|
||||
__all__ = ['RethinkDBStorage']
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import rethinkdb as r
|
||||
from rethinkdb.asyncio_net.net_asyncio import Connection
|
||||
|
||||
from ...dispatcher import BaseStorage
|
||||
|
||||
|
||||
r.set_loop_type('asyncio')
|
||||
|
||||
|
||||
class ConnectionNotClosed(Exception):
|
||||
"""
|
||||
Indicates that DB connection wasn't closed.
|
||||
"""
|
||||
|
||||
|
||||
class RethinkDBStorage(BaseStorage):
|
||||
"""
|
||||
RethinkDB-based storage for FSM.
|
||||
|
||||
Usage:
|
||||
|
||||
..code-block:: python3
|
||||
|
||||
storage = RethinkDBStorage(db='aiogram', table='aiogram', user='aiogram', password='aiogram_secret')
|
||||
dispatcher = Dispatcher(bot, storage=storage)
|
||||
|
||||
And need to close connection when shutdown
|
||||
|
||||
..code-clock:: python3
|
||||
|
||||
await storage.close()
|
||||
|
||||
"""
|
||||
def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None,
|
||||
user=None, password=None, timeout=20, ssl=None, loop=None):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._db = db
|
||||
self._table = table
|
||||
self._auth_key = auth_key
|
||||
self._user = user
|
||||
self._password = password
|
||||
self._timeout = timeout
|
||||
self._ssl = ssl or {}
|
||||
|
||||
self._connection: Connection = None
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._lock = asyncio.Lock(loop=self._loop)
|
||||
|
||||
async def connection(self):
|
||||
"""
|
||||
Get or create connection.
|
||||
"""
|
||||
if not self._connection:
|
||||
async with self._lock: # thread-safe
|
||||
self._connection = await r.connect(host=self._host, port=self._port, db=self._db, auth_key=self._auth_key, user=self._user,
|
||||
password=self._password, timeout=self._timeout, ssl=self._ssl, io_loop=self._loop)
|
||||
return self._connection
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Close connection.
|
||||
"""
|
||||
if self._connection and self._connection.is_open():
|
||||
await self._connection.close()
|
||||
del self._connection
|
||||
self._connection = None
|
||||
|
||||
async def wait_closed(self):
|
||||
"""
|
||||
Checks if connection is closed.
|
||||
"""
|
||||
if self._connection:
|
||||
raise ConnectionNotClosed
|
||||
return True
|
||||
|
||||
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
return await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn)
|
||||
|
||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Dict:
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
|
||||
|
||||
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
state: typing.Optional[typing.AnyStr] = None):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
# https://stackoverflow.com/questions/24306933/how-to-make-a-rethinkdb-atomic-update-if-document-exists-insert-otherwise
|
||||
await r.table(self._table).insert(
|
||||
{'id': chat, user: {'state': state}},
|
||||
conflict='update').run(conn)
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
await r.table(self._table).insert(
|
||||
{'id': chat, user: {'data': r.literal(data)}},
|
||||
conflict='update').run(conn)
|
||||
|
||||
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None,
|
||||
**kwargs):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
await r.table(self._table).insert(
|
||||
{'id': chat, user: {'data': data}},
|
||||
conflict='update').run(conn)
|
||||
|
||||
def has_bucket(self):
|
||||
return True
|
||||
|
||||
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
|
||||
|
||||
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
await r.table(self._table).insert(
|
||||
{'id': chat, user: {'bucket': r.literal(bucket)}},
|
||||
conflict='update').run(conn)
|
||||
|
||||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None,
|
||||
**kwargs):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
conn = await self.connection()
|
||||
await r.table(self._table).insert(
|
||||
{'id': chat, user: {'data': bucket}},
|
||||
conflict='update').run(conn)
|
||||
|
|
@ -5,3 +5,4 @@ pytest
|
|||
pytest-asyncio
|
||||
uvloop
|
||||
aioredis
|
||||
rethinkdb
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue