diff --git a/aiogram/dispatcher/state.py b/aiogram/dispatcher/state.py index a35f6998..a2f087a5 100644 --- a/aiogram/dispatcher/state.py +++ b/aiogram/dispatcher/state.py @@ -1,4 +1,6 @@ +import json import logging +import os from .handler import SkipHandler @@ -6,7 +8,6 @@ log = logging.getLogger('aiogram.StateMachine') # TODO: Provide async storage -# TODO: Provide inline/callback and etc updates. class BaseStorage: @@ -130,58 +131,17 @@ class BaseStorage: """ raise NotImplementedError - def __setitem__(self, key, value): - """ - Here you can use key or slice-key - - >>> storage[chat:user] = "new state" - or - >>> storage[chat] = "new state" - :param key: key or slice - :param value: new state - """ - if isinstance(key, slice): - self.set_state(key.start, key.stop, value) - else: - self.set_state(key, key, value) - def __getitem__(self, key): - """ - Here you can use key or slice-key - - >>> storage[chat:user] - or - >>> storage[chat] - :param key: key or slice - :return: state - """ - if isinstance(key, slice): - return self.get_state(key.start, key.stop) - return self.get_state(key, key) - - def __delitem__(self, key): - """ - Reset state for user - :param key: - :return: - """ - if isinstance(key, slice): - self.del_state(key.start, key.stop) - else: - self.del_state(key, key) - - def __iter__(self): - yield from self.all_states() - - -class StateStorage(BaseStorage): +class MemoryStorage(BaseStorage): """ Simple in-memory state storage Based on builtin dict """ - def __init__(self): - self.storage = {} + def __init__(self, data=None): + if data is None: + data = {} + self.data = data def _prepare(self, chat, user): """ @@ -192,31 +152,42 @@ class StateStorage(BaseStorage): """ result = False - if chat not in self.storage: - self.storage[chat] = {} + chat = str(chat) + user = str(user) + + if chat not in self.data: + self.data[chat] = {} result = True - if user not in self.storage[chat]: - self.storage[chat][user] = {'state': None, 'data': {}} + if user not in self.data[chat]: + self.data[chat][user] = {'state': None, 'data': {}} result = True return result def set_state(self, chat, user, state): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - self.storage[chat][user]['state'] = self._prepare_state_name(state) + self.data[chat][user]['state'] = self._prepare_state_name(state) def get_state(self, chat, user): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - return self.storage[chat][user]['state'] + return self.data[chat][user]['state'] def del_state(self, chat, user): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - if self[chat:user] is not None: - self.storage[chat][user]['state'] = {'state': None, 'data': {}} + self.data[chat][user] = {'state': None, 'data': {}} def all_states(self, chat=None, user=None, state=None): - for chat_id, chat in self.storage.items(): + for chat_id, chat in self.data.items(): if chat is not None and chat != chat_id: continue for user_id, user_state in chat.items(): @@ -227,24 +198,95 @@ class StateStorage(BaseStorage): yield chat_id, user_id, user_state def set_value(self, chat, user, key, value): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - self.storage[chat][user]['data'][key] = value + self.data[chat][user]['data'][key] = value def del_value(self, chat, user, key): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - del self.storage[chat][user]['data'][key] + del self.data[chat][user]['data'][key] def get_data(self, chat, user): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - return self.storage[chat][user]['data'] + return self.data[chat][user]['data'] def update_data(self, chat, user, data): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - self.storage[chat][user]['data'].update(data) + self.data[chat][user]['data'].update(data) def clear_data(self, chat, user, key): + chat = str(chat) + user = str(user) + self._prepare(chat, user) - self.storage[chat][user]['data'].clear() + self.data[chat][user]['data'].clear() + + +class FileStorage(MemoryStorage): + """ + File-like storage for states. + """ + + def __init__(self, filename): + self.filename = filename + super(FileStorage, self).__init__(self.load(filename)) + + @staticmethod + def load(filename): + """ + Load data from file + + :param filename: + :return: dict + """ + if os.path.isfile(filename): + with open(filename, 'r') as file: + return json.load(file) + return {} + + def save(self): + """ + Write states to file + + :return: + """ + with open(self.filename, 'w') as file: + json.dump(self.data, file, indent=2) + + def set_state(self, chat, user, state): + super(FileStorage, self).set_state(chat, user, state) + self.save() + + def del_state(self, chat, user): + super(FileStorage, self).del_state(chat, user) + self.save() + + def set_value(self, chat, user, key, value): + super(FileStorage, self).set_value(chat, user, key, value) + self.save() + + def del_value(self, chat, user, key): + super(FileStorage, self).del_value(chat, user, key) + self.save() + + def update_data(self, chat, user, data): + super(FileStorage, self).update_data(chat, user, data) + self.save() + + def clear_data(self, chat, user, key): + super(FileStorage, self).clear_data(chat, user, key) + self.save() class Controller: @@ -267,7 +309,7 @@ class Controller: :param value: :return: """ - self._state_machine[self._chat:self._user] = value + self._state_machine.set_state(self._chat, self._user, value) def get_state(self): """ @@ -275,7 +317,7 @@ class Controller: :return: """ - return self._state_machine[self._chat:self._user] + return self._state_machine.get_state(self._chat, self._user) def clear(self): """ @@ -283,7 +325,7 @@ class Controller: :return: """ - del self._state_machine[self._chat:self._user] + self._state_machine.del_state(self._chat, self._user) def get(self, key, default=None): """ @@ -363,7 +405,7 @@ class StateMachine: def __init__(self, dispatcher, states, storage=None): if storage is None: - storage = StateStorage() + storage = MemoryStorage() self.steps = self._prepare_states(states) self.storage = storage @@ -393,7 +435,7 @@ class StateMachine: :return: """ log.debug(f"Set state for {chat}:{user} to '{state}'") - self.storage[chat:user] = state + self.storage.set_state(chat, user, state) def get_state(self, chat, user): """ @@ -402,7 +444,7 @@ class StateMachine: :param user: :return: """ - return self.storage[chat:user] + return self.storage.get_state(chat, user) def del_state(self, chat, user): """ @@ -412,7 +454,7 @@ class StateMachine: :return: """ log.debug(f"Reset state for {chat}:{user}") - del self.storage[chat:user] + self.storage.del_state(chat, user) async def process_message(self, message): """ @@ -436,43 +478,3 @@ class StateMachine: callback = self.steps[state] controller = Controller(self, chat_id, from_user_id, state) await callback(message, controller) - - def __setitem__(self, key, value): - """ - Here you can use key or slice-key - - >>> state[chat:user] = "new state" - or - >>> state[chat] = "new state" - :param key: key or slice - :param value: new state - """ - if isinstance(key, slice): - self.set_state(key.start, key.stop, value) - else: - self.set_state(key, key, value) - - def __getitem__(self, key): - """ - Here you can use key or slice-key - - >>> state[chat:user] - or - >>> state[chat] - :param key: key or slice - :return: state - """ - if isinstance(key, slice): - return self.get_state(key.start, key.stop) - return self.get_state(key, key) - - def __delitem__(self, key): - """ - Reset user state - :param key: - :return: - """ - if isinstance(key, slice): - self.del_state(key.start, key.stop) - else: - self.del_state(key, key)