From 6bbe330fdd5fb708fcb4c475d72a305dc4506cd1 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Tue, 20 Jun 2017 04:13:12 +0300 Subject: [PATCH] State machine. --- aiogram/dispatcher/handler.py | 25 ++- aiogram/dispatcher/state.py | 312 ++++++++++++++++++++++++++++++++++ examples/state_machine.py | 72 ++++++++ 3 files changed, 405 insertions(+), 4 deletions(-) create mode 100644 aiogram/dispatcher/state.py create mode 100644 examples/state_machine.py diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index a1ff7b12..b7d7efe8 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -4,6 +4,14 @@ from .filters import check_filters, CancelFilter from .. import types +class SkipHandler(BaseException): + pass + + +class CancelHandler(BaseException): + pass + + class Handler: def __init__(self, dispatcher, once=True): self.dispatcher = dispatcher @@ -11,10 +19,14 @@ class Handler: self.handlers = [] - def register(self, handler, filters=None): + def register(self, handler, filters=None, index=None): if filters and not isinstance(filters, (list, tuple, set)): filters = [filters] - self.handlers.append((filters, handler)) + record = (filters, handler) + if index is None: + self.handlers.append(record) + else: + self.handlers.insert(index, record) def unregister(self, handler): for handler_with_filters in self.handlers: @@ -27,8 +39,13 @@ class Handler: async def notify(self, *args, **kwargs): for filters, handler in self.handlers: if await check_filters(filters, args, kwargs): - await handler(*args, **kwargs) - if self.once: + try: + await handler(*args, **kwargs) + if self.once: + break + except SkipHandler: + continue + except CancelHandler: break diff --git a/aiogram/dispatcher/state.py b/aiogram/dispatcher/state.py new file mode 100644 index 00000000..28793634 --- /dev/null +++ b/aiogram/dispatcher/state.py @@ -0,0 +1,312 @@ +import logging + +from aiogram.dispatcher.handler import SkipHandler + +log = logging.getLogger('aiogram.StateMachine') + + +class BaseStorage: + """ + Skeleton for states storage + """ + + @staticmethod + def _prepare_state_name(value): + if callable(value): + if hasattr(value, '__name__'): + return value.__name__ + else: + return value.__class__.__name__ + return value + + def set_state(self, chat, user, state): + """ + Set state + :param chat: chat_id + :param user: user_id + :param state: value + """ + raise NotImplementedError + + def get_state(self, chat, user): + """ + Get user state from + :param chat: + :param user: + :return: + """ + raise NotImplementedError + + def del_state(self, chat, user): + """ + Clear user state + :param chat: cha + :param user: + :return: + """ + raise NotImplementedError + + def all_states(self, chat=None, user=None, state=None): + """ + Yield all states (Can use filters) + :param chat: + :param user: + :return: + """ + raise NotImplementedError + + def __setitem__(self, key, value): + """ + Here you can use key or slice-key + + >>> storage[chat:user] = "new state" + or + >>> storage[chat] = "new state" + :param key: key or slice + :param value: new state + """ + if isinstance(key, slice): + self.set_state(key.start, key.stop, value) + else: + self.set_state(key, key, value) + + def __getitem__(self, key): + """ + Here you can use key or slice-key + + >>> storage[chat:user] + or + >>> storage[chat] + :param key: key or slice + :return: state + """ + if isinstance(key, slice): + return self.get_state(key.start, key.stop) + return self.get_state(key, key) + + def __delitem__(self, key): + if isinstance(key, slice): + self.del_state(key.start, key.stop) + else: + self.del_state(key, key) + + def __iter__(self): + yield from self.all_states() + + +class StateStorage(BaseStorage): + """ + Simple in-memory state storage + Based on builtin dict + """ + + def __init__(self): + self.storage = {} + + def _prepare(self, chat, user): + """ + Add chat and user to storage if they are not exist + :param chat: + :param user: + :return: + """ + result = False + + if chat not in self.storage: + self.storage[chat] = {} + result = True + + if user not in self.storage[chat]: + self.storage[chat][user] = None + result = True + + return result + + def set_state(self, chat, user, state): + self._prepare(chat, user) + self.storage[chat][user] = self._prepare_state_name(state) + + def get_state(self, chat, user): + self._prepare(chat, user) + return self.storage[chat][user] + + def del_state(self, chat, user): + self._prepare(chat, user) + if self[chat:user] is not None: + self[chat:user] = None + + def all_states(self, chat=None, user=None, state=None): + for chat_id, chat in self.storage.items(): + if chat is not None and chat != chat_id: + continue + for user_id, user_state in chat.items(): + if user is not None and user != user_id: + continue + if state is not None and user_state == state: + continue + yield chat_id, user_id, user_state + + +class Controller: + """ + Storage controller + + Make easy access from callback's + """ + + def __init__(self, state_machine, chat, user, state): + self._state_machine = state_machine + self._chat = chat + self._user = user + self._state = state + + def set(self, value): + """ + Set state + :param value: + :return: + """ + self._state_machine[self._chat:self._user] = value + + def get(self): + """ + Get current state + :return: + """ + return self._state_machine[self._chat:self._user] + + def clear(self): + """ + Reset state + :return: + """ + del self._state_machine[self._chat:self._user] + + def __str__(self): + return f"{self._chat}:{self._user} - {self._state}" + + +class StateMachine: + """ + Manage state + """ + + def __init__(self, dispatcher, states, storage=None): + if storage is None: + storage = StateStorage() + + self.steps = self._prepare_states(states) + self.storage = storage + + dispatcher.message_handlers.register(self.process_message, index=0) + + @staticmethod + def _prepare_states(states): + if isinstance(states, dict): + return states + elif isinstance(states, (list, tuple, set)): + prepared_states = {} + for state in states: + if not callable(state): + raise TypeError('State must be an callable') + state_name = state.__name__ + prepared_states[state_name] = state + return prepared_states + raise TypeError('States must be an dict or list!') + + def set_state(self, chat, user, state): + """ + Save state to storage + :param chat: + :param user: + :param state: + :return: + """ + log.debug(f"Set state for {chat}:{user} to '{state}'") + self.storage[chat:user] = state + + def get_state(self, chat, user): + """ + Get state from storage + :param chat: + :param user: + :return: + """ + return self.storage[chat:user] + + def del_state(self, chat, user): + """ + Clear user state + :param chat: + :param user: + :return: + """ + log.debug(f"Reset state for {chat}:{user}") + del self.storage[chat:user] + + async def process_message(self, message): + """ + Read message and process it + :param message: + :return: + """ + chat_id = message.chat.id + from_user_id = message.from_user.id + + state = self.get_state(chat_id, from_user_id) + if state is None: + raise SkipHandler() + + if state not in self.steps: + log.warning(f"Found unknown state '{state}' for {chat_id}:{from_user_id}. Condition will be reset.") + self.del_state(chat_id, from_user_id) + raise SkipHandler() + + callback = self.steps[state] + controller = Controller(self, chat_id, from_user_id, state) + log.debug(f"Process state for {chat_id}:{from_user_id} - '{state}'") + result = await callback(message, controller) + # if result is True: + # controller.clear() + # elif isinstance(result, str): + # controller.set(result) + + def __setitem__(self, key, value): + """ + Here you can use key or slice-key + + >>> state[chat:user] = "new state" + or + >>> state[chat] = "new state" + :param key: key or slice + :param value: new state + """ + if isinstance(key, slice): + self.set_state(key.start, key.stop, value) + else: + self.set_state(key, key, value) + + def __getitem__(self, key): + """ + Here you can use key or slice-key + + >>> state[chat:user] + or + >>> state[chat] + :param key: key or slice + :return: state + """ + if isinstance(key, slice): + return self.get_state(key.start, key.stop) + return self.get_state(key, key) + + def __delitem__(self, key): + """ + Reset user state + :param key: + :return: + """ + if isinstance(key, slice): + self.del_state(key.start, key.stop) + else: + self.del_state(key, key) diff --git a/examples/state_machine.py b/examples/state_machine.py new file mode 100644 index 00000000..a593b48e --- /dev/null +++ b/examples/state_machine.py @@ -0,0 +1,72 @@ +import asyncio +import logging + +from aiogram import Bot, types +from aiogram.dispatcher import Dispatcher +from aiogram.dispatcher.state import StateMachine + +API_TOKEN = 'BOT TOKEN HERE' +API_TOKEN = '380294876:AAFbdYYgq1hBi9hQDcxD3bj8QCNnVec5aHk' + +logging.basicConfig(level=logging.DEBUG) + +loop = asyncio.get_event_loop() +bot = Bot(token=API_TOKEN, loop=loop) +dp = Dispatcher(bot) + +users = {} + + +@dp.message_handler(commands=['start']) +async def send_welcome(message: types.Message): + await message.reply("Hi there! What's your name?") + state.set_state(message.chat.id, message.from_user.id, "name") + + +async def process_name(message, controller): + users[message.from_user.id] = {"name": message.text} + + await message.reply("How old are you?") + + controller.set('age') + + +async def process_age(message, controller): + if not message.text.isdigit(): + return await message.reply("Age should be a number.\nHow old are you?") + + users[message.from_user.id].update({"age": int(message.text)}) + + markup = types.ReplyKeyboardMarkup() + markup.add("Male", "Female") + markup.add("Other") + await message.reply("What is your gender?", reply_markup=markup) + controller.set("sex") + + +async def process_sex(message, controller): + if message.text not in ["Male", "Female", "Other"]: + return await message.reply("Bad gender name. Choose you gender from keyboard.") + + users[message.from_user.id].update({"sex": message.text}) + controller.clear() + + user = users[message.from_user.id] + + markup = types.ReplyKeyboardRemove() + await bot.send_message(message.chat.id, + f"Hi!\nNice to meet you, {user['name']}.\nAge: {user['age']}\nSex: {user['sex']}", + reply_markup=markup) + + +state = StateMachine(dp, { + "name": process_name, + "age": process_age, + "sex": process_sex +}) + +if __name__ == '__main__': + try: + loop.run_until_complete(dp.start_pooling()) + except KeyboardInterrupt: + loop.stop()