From ac50db075b1b1d9b4ea26b1a4fe2c57039e3ec10 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Wed, 27 Jun 2018 06:50:08 +0300 Subject: [PATCH] Improve FSM. --- aiogram/contrib/fsm_storage/memory.py | 4 +- aiogram/dispatcher/dispatcher.py | 4 +- aiogram/dispatcher/filters/builtin.py | 23 ++++-- aiogram/dispatcher/filters/state.py | 95 ++++++++++++++++++++++++ examples/finite_state_machine_example.py | 78 +++++++++---------- 5 files changed, 152 insertions(+), 52 deletions(-) create mode 100644 aiogram/dispatcher/filters/state.py diff --git a/aiogram/contrib/fsm_storage/memory.py b/aiogram/contrib/fsm_storage/memory.py index f8670ec4..d526e90e 100644 --- a/aiogram/contrib/fsm_storage/memory.py +++ b/aiogram/contrib/fsm_storage/memory.py @@ -1,6 +1,6 @@ import typing -from ...dispatcher import BaseStorage +from ...dispatcher.storage import BaseStorage class MemoryStorage(BaseStorage): @@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage): chat, user = self.check_address(chat=chat, user=user) user = self._get_user(chat, user) if data is None: - data = [] + data = {} user['data'].update(data, **kwargs) async def set_state(self, *, diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index c90edc8d..2ccb4f78 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -15,7 +15,6 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon from .webhook import BaseResponse from .. import types from ..bot import Bot, bot -from ..types.message import ContentType from ..utils.exceptions import TelegramAPIError, Throttled log = logging.getLogger(__name__) @@ -509,6 +508,7 @@ class Dispatcher: :param kwargs: :return: decorated function """ + def decorator(callback): self.register_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp, content_types=content_types, state=state, run_task=run_task, **kwargs) @@ -699,6 +699,7 @@ class Dispatcher: :param run_task: run callback in task (no wait results) :param kwargs: """ + def decorator(callback): self.register_callback_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) return callback @@ -744,6 +745,7 @@ class Dispatcher: :param run_task: run callback in task (no wait results) :param kwargs: """ + def decorator(callback): self.register_shipping_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) return callback diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 47f266e8..7fd3866f 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -1,5 +1,4 @@ import re -import typing from _contextvars import ContextVar from aiogram import types @@ -117,16 +116,26 @@ class StateFilter(BaseFilter): ctx_state = ContextVar('user_state') def __init__(self, dispatcher, state): + from aiogram.dispatcher.filters.state import State + super().__init__(dispatcher) - if isinstance(state, str) or state is None: - state = (state,) - self.state = state + states = [] + if not isinstance(state, (list, set, tuple, frozenset)) or state is None: + state = [state, ] + for item in state: + if isinstance(item, State): + states.append(item.state) + elif hasattr(item, 'state_names'): # issubclass() cannot be used in this place + states.extend(item.state_names) + else: + states.append(item) + self.states = states def get_target(self, obj): return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) async def check(self, obj): - if '*' in self.state: + if '*' in self.states: return {'state': self.dispatcher.current_state()} try: @@ -137,11 +146,11 @@ class StateFilter(BaseFilter): if chat or user: state = await self.dispatcher.storage.get_state(chat=chat, user=user) self.ctx_state.set(state) - if state in self.state: + if state in self.states: return {'state': self.dispatcher.current_state(), 'raw_state': state} else: - if state in self.state: + if state in self.states: return {'state': self.dispatcher.current_state(), 'raw_state': state} return False diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py new file mode 100644 index 00000000..52e431da --- /dev/null +++ b/aiogram/dispatcher/filters/state.py @@ -0,0 +1,95 @@ +from ..dispatcher import Dispatcher + + +class State: + def __init__(self, state=None): + self.state = state + + def __set_name__(self, owner, name): + if self.state is None: + self.state = owner.__name__ + ':' + name + + def __str__(self): + return f"'" + + __repr__ = __str__ + + async def set(self): + state = Dispatcher.current().current_state() + await state.set_state(self.state) + + +class MetaStatesGroup(type): + def __new__(mcs, name, bases, namespace, **kwargs): + cls = super(MetaStatesGroup, mcs).__new__(mcs, name, bases, namespace) + + states = [] + for name, prop in ((name, prop) for name, prop in namespace.items() if isinstance(prop, State)): + states.append(prop) + + cls._states = tuple(states) + cls._state_names = tuple(state.state for state in states) + + return cls + + @property + def states(cls) -> tuple: + return cls._states + + @property + def state_names(cls) -> tuple: + return cls._state_names + + +class StatesGroup(metaclass=MetaStatesGroup): + @classmethod + async def next(cls) -> str: + state = Dispatcher.current().current_state() + state_name = await state.get_state() + + try: + next_step = cls.state_names.index(state_name) + 1 + except ValueError: + next_step = 0 + + try: + next_state_name = cls.states[next_step].state + except IndexError: + next_state_name = None + + await state.set_state(next_state_name) + return next_state_name + + @classmethod + async def previous(cls) -> str: + state = Dispatcher.current().current_state() + state_name = await state.get_state() + + try: + previous_step = cls.state_names.index(state_name) - 1 + except ValueError: + previous_step = 0 + + if previous_step < 0: + previous_state_name = None + else: + previous_state_name = cls.states[previous_step].state + + await state.set_state(previous_state_name) + return previous_state_name + + @classmethod + async def first(cls) -> str: + state = Dispatcher.current().current_state() + first_step_name = cls.states[0].state + + await state.set_state(first_step_name) + return first_step_name + + @classmethod + async def last(cls) -> str: + state = Dispatcher.current().current_state() + last_step_name = cls.states[-1].state + + await state.set_state(last_step_name) + return last_step_name diff --git a/examples/finite_state_machine_example.py b/examples/finite_state_machine_example.py index 45755d9a..7a989e5b 100644 --- a/examples/finite_state_machine_example.py +++ b/examples/finite_state_machine_example.py @@ -1,11 +1,13 @@ import asyncio +from typing import Optional -from aiogram import Bot, types +import aiogram.utils.markdown as md +from aiogram import Bot, Dispatcher, types from aiogram.contrib.fsm_storage.memory import MemoryStorage -from aiogram.dispatcher import Dispatcher +from aiogram.dispatcher import FSMContext +from aiogram.dispatcher.filters.state import State, StatesGroup from aiogram.types import ParseMode from aiogram.utils import executor -from aiogram.utils.markdown import text, bold API_TOKEN = 'BOT TOKEN HERE' @@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop) storage = MemoryStorage() dp = Dispatcher(bot, storage=storage) + # States -AGE = 'process_age' -NAME = 'process_name' -GENDER = 'process_gender' +class Form(StatesGroup): + name = State() # Will be represented in storage as 'Form:name' + age = State() # Will be represented in storage as 'Form:age' + gender = State() # Will be represented in storage as 'Form:gender' @dp.message_handler(commands=['start']) @@ -28,48 +32,41 @@ async def cmd_start(message: types.Message): """ Conversation's entry point """ - # Get current state - state = dp.current_state(chat=message.chat.id, user=message.from_user.id) - # Update user's state - await state.set_state(NAME) + # Set state + await Form.name.set() await message.reply("Hi there! What's your name?") # You can use state '*' if you need to handle all states @dp.message_handler(state='*', commands=['cancel']) -@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel') -async def cancel_handler(message: types.Message): +@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*') +async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None): """ Allow user to cancel any action """ - with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: - # Ignore command if user is not in any (defined) state - if await state.get_state() is None: - return + if raw_state is None: + return - # Otherwise cancel state and inform user about it - # And remove keyboard (just in case) - await state.reset_state(with_data=True) - await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) + # Cancel state and inform user about it + await state.finish() + # And remove keyboard (just in case) + await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) -@dp.message_handler(state=NAME) -async def process_name(message: types.Message): +@dp.message_handler(state=Form.name) +async def process_name(message: types.Message, state: FSMContext): """ Process user name """ - # Save name to storage and go to next step - # You can use context manager - with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: - await state.update_data(name=message.text) - await state.set_state(AGE) + await Form.next() + await state.update_data(name=message.text) await message.reply("How old are you?") # Check age. Age gotta be digit -@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit()) +@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age) async def failed_process_age(message: types.Message): """ If age is invalid @@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message): return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") -@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit()) -async def process_age(message: types.Message): +@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age) +async def process_age(message: types.Message, state: FSMContext): # Update state and data - with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: - await state.set_state(GENDER) - await state.update_data(age=int(message.text)) + await Form.next() + await state.update_data(age=int(message.text)) # Configure ReplyKeyboardMarkup markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) @@ -92,7 +88,7 @@ async def process_age(message: types.Message): await message.reply("What is your gender?", reply_markup=markup) -@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"]) +@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender) async def failed_process_gender(message: types.Message): """ In this example gender has to be one of: Male, Female, Other. @@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message): return await message.reply("Bad gender name. Choose you gender from keyboard.") -@dp.message_handler(state=GENDER) -async def process_gender(message: types.Message): - state = dp.current_state(chat=message.chat.id, user=message.from_user.id) - +@dp.message_handler(state=Form.gender) +async def process_gender(message: types.Message, state: FSMContext): data = await state.get_data() data['gender'] = message.text @@ -111,10 +105,10 @@ async def process_gender(message: types.Message): markup = types.ReplyKeyboardRemove() # And send message - await bot.send_message(message.chat.id, text( - text('Hi! Nice to meet you,', bold(data['name'])), - text('Age:', data['age']), - text('Gender:', data['gender']), + await bot.send_message(message.chat.id, md.text( + md.text('Hi! Nice to meet you,', md.bold(data['name'])), + md.text('Age:', data['age']), + md.text('Gender:', data['gender']), sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN) # Finish conversation