From f321c69f469f3affb825f316f145c6e3660e6365 Mon Sep 17 00:00:00 2001 From: Arslan 'Ars2014' Sakhapov Date: Wed, 10 Jan 2018 14:33:27 +0500 Subject: [PATCH 1/5] Add RethinkDB-based storage for FSM --- aiogram/contrib/fsm_storage/redis.py | 6 +- aiogram/contrib/fsm_storage/rethinkdb.py | 142 +++++++++++++++++++++++ dev_requirements.txt | 1 + 3 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 aiogram/contrib/fsm_storage/rethinkdb.py diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index ecf81afa..58d44127 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -377,8 +377,10 @@ async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorag :param storage2: instance of RedisStorage2 :return: """ - assert isinstance(storage1, RedisStorage) - assert isinstance(storage2, RedisStorage2) + if not isinstance(storage1, RedisStorage): # better than assertion + raise TypeError(f'{type(storage1)} is not RedisStorage instance.') + if not isinstance(storage2, RedisStorage): + raise TypeError(f'{type(storage2)} is not RedisStorage instance.') log = logging.getLogger('aiogram.RedisStorage') diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py new file mode 100644 index 00000000..f27fb4ae --- /dev/null +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -0,0 +1,142 @@ +# -*- coding:utf-8; -*- +__all__ = ['RethinkDBStorage'] + +import asyncio +import logging +import typing + +import rethinkdb as r +from rethinkdb.asyncio_net.net_asyncio import Connection + +from ...dispatcher import BaseStorage + + +r.set_loop_type('asyncio') + + +class ConnectionNotClosed(Exception): + """ + Indicates that DB connection wasn't closed. + """ + + +class RethinkDBStorage(BaseStorage): + """ + RethinkDB-based storage for FSM. + + Usage: + + ..code-block:: python3 + + storage = RethinkDBStorage(db='aiogram', table='aiogram', user='aiogram', password='aiogram_secret') + dispatcher = Dispatcher(bot, storage=storage) + + And need to close connection when shutdown + + ..code-clock:: python3 + + await storage.close() + + """ + def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None, + user=None, password=None, timeout=20, ssl=None, loop=None): + self._host = host + self._port = port + self._db = db + self._table = table + self._auth_key = auth_key + self._user = user + self._password = password + self._timeout = timeout + self._ssl = ssl or {} + + self._connection: Connection = None + self._loop = loop or asyncio.get_event_loop() + self._lock = asyncio.Lock(loop=self._loop) + + async def connection(self): + """ + Get or create connection. + """ + if not self._connection: + async with self._lock: # thread-safe + self._connection = await r.connect(host=self._host, port=self._port, db=self._db, auth_key=self._auth_key, user=self._user, + password=self._password, timeout=self._timeout, ssl=self._ssl, io_loop=self._loop) + return self._connection + + async def close(self): + """ + Close connection. + """ + if self._connection and self._connection.is_open(): + await self._connection.close() + del self._connection + self._connection = None + + async def wait_closed(self): + """ + Checks if connection is closed. + """ + if self._connection: + raise ConnectionNotClosed + return True + + async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Optional[str]: + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + return await r.table(self._table).get(chat)[user]['state'].default(default or '').run(conn) + + async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Dict: + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn) + + async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + state: typing.Optional[typing.AnyStr] = None): + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + # https://stackoverflow.com/questions/24306933/how-to-make-a-rethinkdb-atomic-update-if-document-exists-insert-otherwise + await r.table(self._table).insert( + {'id': chat, user: {'state': state}}, + conflict='update').run(conn) + + async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + await r.table(self._table).insert( + {'id': chat, user: {'data': r.literal(data)}}, + conflict='update').run(conn) + + async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, + **kwargs): + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + await r.table(self._table).insert( + {'id': chat, user: {'data': data}}, + conflict='update').run(conn) + + def has_bucket(self): + return True + + async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[dict] = None) -> typing.Dict: + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) + + async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + await r.table(self._table).insert( + {'id': chat, user: {'bucket': r.literal(bucket)}}, + conflict='update').run(conn) + + async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, + **kwargs): + chat, user = map(str, self.check_address(chat=chat, user=user)) + conn = await self.connection() + await r.table(self._table).insert( + {'id': chat, user: {'data': bucket}}, + conflict='update').run(conn) diff --git a/dev_requirements.txt b/dev_requirements.txt index 7fd66b33..92eaaa4b 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -5,3 +5,4 @@ pytest pytest-asyncio uvloop aioredis +rethinkdb From afde20eccdb2b2b7ddfc955cb1e68658979b8ac7 Mon Sep 17 00:00:00 2001 From: Arslan 'Ars2014' Sakhapov Date: Wed, 10 Jan 2018 14:53:14 +0500 Subject: [PATCH 2/5] Add 'get_states_list' and 'reset_all' methods --- aiogram/contrib/fsm_storage/rethinkdb.py | 26 +++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index f27fb4ae..0d2516ad 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -2,7 +2,6 @@ __all__ = ['RethinkDBStorage'] import asyncio -import logging import typing import rethinkdb as r @@ -140,3 +139,28 @@ class RethinkDBStorage(BaseStorage): await r.table(self._table).insert( {'id': chat, user: {'data': bucket}}, conflict='update').run(conn) + + async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + """ + Get list of all stored chat's and user's + + :return: list of tuples where first element is chat id and second is user id + """ + conn = await self.connection() + result = [] + + items = await r.table(self._table).get_all().run(conn).items + + for item in items: + chat = int(item.pop('id')) + users = int(item.keys()) + result.append((chat, users)) + + return result + + async def reset_all(self): + """ + Reset states in DB + """ + conn = await self.connection() + await r.table(self._table).get_all().delete().run(conn) From efeac6b923a23b5088afc2179c2e02d3f827ef70 Mon Sep 17 00:00:00 2001 From: Arslan 'Ars2014' Sakhapov Date: Wed, 10 Jan 2018 15:54:15 +0500 Subject: [PATCH 3/5] Fix errors --- aiogram/contrib/fsm_storage/rethinkdb.py | 30 +++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index 0d2516ad..a16da13f 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -5,7 +5,6 @@ import asyncio import typing import rethinkdb as r -from rethinkdb.asyncio_net.net_asyncio import Connection from ...dispatcher import BaseStorage @@ -49,7 +48,7 @@ class RethinkDBStorage(BaseStorage): self._timeout = timeout self._ssl = ssl or {} - self._connection: Connection = None + self._connection: r.Connection = None self._loop = loop or asyncio.get_event_loop() self._lock = asyncio.Lock(loop=self._loop) @@ -72,7 +71,7 @@ class RethinkDBStorage(BaseStorage): del self._connection self._connection = None - async def wait_closed(self): + def wait_closed(self): """ Checks if connection is closed. """ @@ -104,9 +103,10 @@ class RethinkDBStorage(BaseStorage): async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() - await r.table(self._table).insert( - {'id': chat, user: {'data': r.literal(data)}}, - conflict='update').run(conn) + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): @@ -128,16 +128,17 @@ class RethinkDBStorage(BaseStorage): async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() - await r.table(self._table).insert( - {'id': chat, user: {'bucket': r.literal(bucket)}}, - conflict='update').run(conn) + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() await r.table(self._table).insert( - {'id': chat, user: {'data': bucket}}, + {'id': chat, user: {'bucket': bucket}}, conflict='update').run(conn) async def get_states_list(self) -> typing.List[typing.Tuple[int]]: @@ -149,12 +150,13 @@ class RethinkDBStorage(BaseStorage): conn = await self.connection() result = [] - items = await r.table(self._table).get_all().run(conn).items + items = (await r.table(self._table).run(conn)).items for item in items: chat = int(item.pop('id')) - users = int(item.keys()) - result.append((chat, users)) + for key in item.keys(): + user = int(key) + result.append((chat, user)) return result @@ -163,4 +165,4 @@ class RethinkDBStorage(BaseStorage): Reset states in DB """ conn = await self.connection() - await r.table(self._table).get_all().delete().run(conn) + await r.table(self._table).delete().run(conn) From df74cecdef5fe5b0c02ed6bb5af4a8707df7c7a9 Mon Sep 17 00:00:00 2001 From: Arslan 'Ars2014' Sakhapov Date: Thu, 11 Jan 2018 00:33:33 +0500 Subject: [PATCH 4/5] Some improvements --- aiogram/contrib/fsm_storage/rethinkdb.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index a16da13f..c1021612 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -56,8 +56,8 @@ class RethinkDBStorage(BaseStorage): """ Get or create connection. """ - if not self._connection: - async with self._lock: # thread-safe + async with self._lock: # thread-safe + if not self._connection: self._connection = await r.connect(host=self._host, port=self._port, db=self._db, auth_key=self._auth_key, user=self._user, password=self._password, timeout=self._timeout, ssl=self._ssl, io_loop=self._loop) return self._connection @@ -68,7 +68,6 @@ class RethinkDBStorage(BaseStorage): """ if self._connection and self._connection.is_open(): await self._connection.close() - del self._connection self._connection = None def wait_closed(self): From 539e67d394d374e25edc35ecbde1edd2c976c814 Mon Sep 17 00:00:00 2001 From: Arslan 'Ars2014' Sakhapov Date: Thu, 11 Jan 2018 01:03:05 +0500 Subject: [PATCH 5/5] Some improvements v2 --- aiogram/contrib/fsm_storage/rethinkdb.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index c1021612..ed8733b1 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -94,10 +94,10 @@ class RethinkDBStorage(BaseStorage): state: typing.Optional[typing.AnyStr] = None): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() - # https://stackoverflow.com/questions/24306933/how-to-make-a-rethinkdb-atomic-update-if-document-exists-insert-otherwise - await r.table(self._table).insert( - {'id': chat, user: {'state': state}}, - conflict='update').run(conn) + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = map(str, self.check_address(chat=chat, user=user)) @@ -111,9 +111,10 @@ class RethinkDBStorage(BaseStorage): **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() - await r.table(self._table).insert( - {'id': chat, user: {'data': data}}, - conflict='update').run(conn) + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) def has_bucket(self): return True @@ -136,9 +137,10 @@ class RethinkDBStorage(BaseStorage): **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) conn = await self.connection() - await r.table(self._table).insert( - {'id': chat, user: {'bucket': bucket}}, - conflict='update').run(conn) + if await r.table(self._table).get(chat).run(conn): + await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn) + else: + await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) async def get_states_list(self) -> typing.List[typing.Tuple[int]]: """