Improve FSM.

This commit is contained in:
Alex Root Junior 2018-06-27 06:50:08 +03:00
parent fe6ae4863a
commit ac50db075b
5 changed files with 152 additions and 52 deletions

View file

@ -1,6 +1,6 @@
import typing import typing
from ...dispatcher import BaseStorage from ...dispatcher.storage import BaseStorage
class MemoryStorage(BaseStorage): class MemoryStorage(BaseStorage):
@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage):
chat, user = self.check_address(chat=chat, user=user) chat, user = self.check_address(chat=chat, user=user)
user = self._get_user(chat, user) user = self._get_user(chat, user)
if data is None: if data is None:
data = [] data = {}
user['data'].update(data, **kwargs) user['data'].update(data, **kwargs)
async def set_state(self, *, async def set_state(self, *,

View file

@ -15,7 +15,6 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon
from .webhook import BaseResponse from .webhook import BaseResponse
from .. import types from .. import types
from ..bot import Bot, bot from ..bot import Bot, bot
from ..types.message import ContentType
from ..utils.exceptions import TelegramAPIError, Throttled from ..utils.exceptions import TelegramAPIError, Throttled
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -509,6 +508,7 @@ class Dispatcher:
:param kwargs: :param kwargs:
:return: decorated function :return: decorated function
""" """
def decorator(callback): def decorator(callback):
self.register_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp, self.register_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp,
content_types=content_types, state=state, run_task=run_task, **kwargs) content_types=content_types, state=state, run_task=run_task, **kwargs)
@ -699,6 +699,7 @@ class Dispatcher:
:param run_task: run callback in task (no wait results) :param run_task: run callback in task (no wait results)
:param kwargs: :param kwargs:
""" """
def decorator(callback): def decorator(callback):
self.register_callback_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) self.register_callback_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
return callback return callback
@ -744,6 +745,7 @@ class Dispatcher:
:param run_task: run callback in task (no wait results) :param run_task: run callback in task (no wait results)
:param kwargs: :param kwargs:
""" """
def decorator(callback): def decorator(callback):
self.register_shipping_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) self.register_shipping_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
return callback return callback

View file

@ -1,5 +1,4 @@
import re import re
import typing
from _contextvars import ContextVar from _contextvars import ContextVar
from aiogram import types from aiogram import types
@ -117,16 +116,26 @@ class StateFilter(BaseFilter):
ctx_state = ContextVar('user_state') ctx_state = ContextVar('user_state')
def __init__(self, dispatcher, state): def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State
super().__init__(dispatcher) super().__init__(dispatcher)
if isinstance(state, str) or state is None: states = []
state = (state,) if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
self.state = state state = [state, ]
for item in state:
if isinstance(item, State):
states.append(item.state)
elif hasattr(item, 'state_names'): # issubclass() cannot be used in this place
states.extend(item.state_names)
else:
states.append(item)
self.states = states
def get_target(self, obj): def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj): async def check(self, obj):
if '*' in self.state: if '*' in self.states:
return {'state': self.dispatcher.current_state()} return {'state': self.dispatcher.current_state()}
try: try:
@ -137,11 +146,11 @@ class StateFilter(BaseFilter):
if chat or user: if chat or user:
state = await self.dispatcher.storage.get_state(chat=chat, user=user) state = await self.dispatcher.storage.get_state(chat=chat, user=user)
self.ctx_state.set(state) self.ctx_state.set(state)
if state in self.state: if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state} return {'state': self.dispatcher.current_state(), 'raw_state': state}
else: else:
if state in self.state: if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state} return {'state': self.dispatcher.current_state(), 'raw_state': state}
return False return False

View file

@ -0,0 +1,95 @@
from ..dispatcher import Dispatcher
class State:
def __init__(self, state=None):
self.state = state
def __set_name__(self, owner, name):
if self.state is None:
self.state = owner.__name__ + ':' + name
def __str__(self):
return f"<State '{self.state}>'"
__repr__ = __str__
async def set(self):
state = Dispatcher.current().current_state()
await state.set_state(self.state)
class MetaStatesGroup(type):
def __new__(mcs, name, bases, namespace, **kwargs):
cls = super(MetaStatesGroup, mcs).__new__(mcs, name, bases, namespace)
states = []
for name, prop in ((name, prop) for name, prop in namespace.items() if isinstance(prop, State)):
states.append(prop)
cls._states = tuple(states)
cls._state_names = tuple(state.state for state in states)
return cls
@property
def states(cls) -> tuple:
return cls._states
@property
def state_names(cls) -> tuple:
return cls._state_names
class StatesGroup(metaclass=MetaStatesGroup):
@classmethod
async def next(cls) -> str:
state = Dispatcher.current().current_state()
state_name = await state.get_state()
try:
next_step = cls.state_names.index(state_name) + 1
except ValueError:
next_step = 0
try:
next_state_name = cls.states[next_step].state
except IndexError:
next_state_name = None
await state.set_state(next_state_name)
return next_state_name
@classmethod
async def previous(cls) -> str:
state = Dispatcher.current().current_state()
state_name = await state.get_state()
try:
previous_step = cls.state_names.index(state_name) - 1
except ValueError:
previous_step = 0
if previous_step < 0:
previous_state_name = None
else:
previous_state_name = cls.states[previous_step].state
await state.set_state(previous_state_name)
return previous_state_name
@classmethod
async def first(cls) -> str:
state = Dispatcher.current().current_state()
first_step_name = cls.states[0].state
await state.set_state(first_step_name)
return first_step_name
@classmethod
async def last(cls) -> str:
state = Dispatcher.current().current_state()
last_step_name = cls.states[-1].state
await state.set_state(last_step_name)
return last_step_name

View file

@ -1,11 +1,13 @@
import asyncio import asyncio
from typing import Optional
from aiogram import Bot, types import aiogram.utils.markdown as md
from aiogram import Bot, Dispatcher, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import Dispatcher from aiogram.dispatcher import FSMContext
from aiogram.dispatcher.filters.state import State, StatesGroup
from aiogram.types import ParseMode from aiogram.types import ParseMode
from aiogram.utils import executor from aiogram.utils import executor
from aiogram.utils.markdown import text, bold
API_TOKEN = 'BOT TOKEN HERE' API_TOKEN = 'BOT TOKEN HERE'
@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop)
storage = MemoryStorage() storage = MemoryStorage()
dp = Dispatcher(bot, storage=storage) dp = Dispatcher(bot, storage=storage)
# States # States
AGE = 'process_age' class Form(StatesGroup):
NAME = 'process_name' name = State() # Will be represented in storage as 'Form:name'
GENDER = 'process_gender' age = State() # Will be represented in storage as 'Form:age'
gender = State() # Will be represented in storage as 'Form:gender'
@dp.message_handler(commands=['start']) @dp.message_handler(commands=['start'])
@ -28,48 +32,41 @@ async def cmd_start(message: types.Message):
""" """
Conversation's entry point Conversation's entry point
""" """
# Get current state # Set state
state = dp.current_state(chat=message.chat.id, user=message.from_user.id) await Form.name.set()
# Update user's state
await state.set_state(NAME)
await message.reply("Hi there! What's your name?") await message.reply("Hi there! What's your name?")
# You can use state '*' if you need to handle all states # You can use state '*' if you need to handle all states
@dp.message_handler(state='*', commands=['cancel']) @dp.message_handler(state='*', commands=['cancel'])
@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel') @dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
async def cancel_handler(message: types.Message): async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None):
""" """
Allow user to cancel any action Allow user to cancel any action
""" """
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: if raw_state is None:
# Ignore command if user is not in any (defined) state return
if await state.get_state() is None:
return
# Otherwise cancel state and inform user about it # Cancel state and inform user about it
# And remove keyboard (just in case) await state.finish()
await state.reset_state(with_data=True) # And remove keyboard (just in case)
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
@dp.message_handler(state=NAME) @dp.message_handler(state=Form.name)
async def process_name(message: types.Message): async def process_name(message: types.Message, state: FSMContext):
""" """
Process user name Process user name
""" """
# Save name to storage and go to next step await Form.next()
# You can use context manager await state.update_data(name=message.text)
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
await state.update_data(name=message.text)
await state.set_state(AGE)
await message.reply("How old are you?") await message.reply("How old are you?")
# Check age. Age gotta be digit # Check age. Age gotta be digit
@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit()) @dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
async def failed_process_age(message: types.Message): async def failed_process_age(message: types.Message):
""" """
If age is invalid If age is invalid
@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message):
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit()) @dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
async def process_age(message: types.Message): async def process_age(message: types.Message, state: FSMContext):
# Update state and data # Update state and data
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: await Form.next()
await state.set_state(GENDER) await state.update_data(age=int(message.text))
await state.update_data(age=int(message.text))
# Configure ReplyKeyboardMarkup # Configure ReplyKeyboardMarkup
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
@ -92,7 +88,7 @@ async def process_age(message: types.Message):
await message.reply("What is your gender?", reply_markup=markup) await message.reply("What is your gender?", reply_markup=markup)
@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"]) @dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
async def failed_process_gender(message: types.Message): async def failed_process_gender(message: types.Message):
""" """
In this example gender has to be one of: Male, Female, Other. In this example gender has to be one of: Male, Female, Other.
@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message):
return await message.reply("Bad gender name. Choose you gender from keyboard.") return await message.reply("Bad gender name. Choose you gender from keyboard.")
@dp.message_handler(state=GENDER) @dp.message_handler(state=Form.gender)
async def process_gender(message: types.Message): async def process_gender(message: types.Message, state: FSMContext):
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
data = await state.get_data() data = await state.get_data()
data['gender'] = message.text data['gender'] = message.text
@ -111,10 +105,10 @@ async def process_gender(message: types.Message):
markup = types.ReplyKeyboardRemove() markup = types.ReplyKeyboardRemove()
# And send message # And send message
await bot.send_message(message.chat.id, text( await bot.send_message(message.chat.id, md.text(
text('Hi! Nice to meet you,', bold(data['name'])), md.text('Hi! Nice to meet you,', md.bold(data['name'])),
text('Age:', data['age']), md.text('Age:', data['age']),
text('Gender:', data['gender']), md.text('Gender:', data['gender']),
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN) sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
# Finish conversation # Finish conversation