mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-09 01:15:31 +00:00
Improve FSM.
This commit is contained in:
parent
fe6ae4863a
commit
ac50db075b
5 changed files with 152 additions and 52 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...dispatcher import BaseStorage
|
from ...dispatcher.storage import BaseStorage
|
||||||
|
|
||||||
|
|
||||||
class MemoryStorage(BaseStorage):
|
class MemoryStorage(BaseStorage):
|
||||||
|
|
@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage):
|
||||||
chat, user = self.check_address(chat=chat, user=user)
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
user = self._get_user(chat, user)
|
user = self._get_user(chat, user)
|
||||||
if data is None:
|
if data is None:
|
||||||
data = []
|
data = {}
|
||||||
user['data'].update(data, **kwargs)
|
user['data'].update(data, **kwargs)
|
||||||
|
|
||||||
async def set_state(self, *,
|
async def set_state(self, *,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon
|
||||||
from .webhook import BaseResponse
|
from .webhook import BaseResponse
|
||||||
from .. import types
|
from .. import types
|
||||||
from ..bot import Bot, bot
|
from ..bot import Bot, bot
|
||||||
from ..types.message import ContentType
|
|
||||||
from ..utils.exceptions import TelegramAPIError, Throttled
|
from ..utils.exceptions import TelegramAPIError, Throttled
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -509,6 +508,7 @@ class Dispatcher:
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
:return: decorated function
|
:return: decorated function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(callback):
|
def decorator(callback):
|
||||||
self.register_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp,
|
self.register_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp,
|
||||||
content_types=content_types, state=state, run_task=run_task, **kwargs)
|
content_types=content_types, state=state, run_task=run_task, **kwargs)
|
||||||
|
|
@ -699,6 +699,7 @@ class Dispatcher:
|
||||||
:param run_task: run callback in task (no wait results)
|
:param run_task: run callback in task (no wait results)
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(callback):
|
def decorator(callback):
|
||||||
self.register_callback_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
|
self.register_callback_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
|
||||||
return callback
|
return callback
|
||||||
|
|
@ -744,6 +745,7 @@ class Dispatcher:
|
||||||
:param run_task: run callback in task (no wait results)
|
:param run_task: run callback in task (no wait results)
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(callback):
|
def decorator(callback):
|
||||||
self.register_shipping_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
|
self.register_shipping_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
|
||||||
return callback
|
return callback
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import re
|
import re
|
||||||
import typing
|
|
||||||
from _contextvars import ContextVar
|
from _contextvars import ContextVar
|
||||||
|
|
||||||
from aiogram import types
|
from aiogram import types
|
||||||
|
|
@ -117,16 +116,26 @@ class StateFilter(BaseFilter):
|
||||||
ctx_state = ContextVar('user_state')
|
ctx_state = ContextVar('user_state')
|
||||||
|
|
||||||
def __init__(self, dispatcher, state):
|
def __init__(self, dispatcher, state):
|
||||||
|
from aiogram.dispatcher.filters.state import State
|
||||||
|
|
||||||
super().__init__(dispatcher)
|
super().__init__(dispatcher)
|
||||||
if isinstance(state, str) or state is None:
|
states = []
|
||||||
state = (state,)
|
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
||||||
self.state = state
|
state = [state, ]
|
||||||
|
for item in state:
|
||||||
|
if isinstance(item, State):
|
||||||
|
states.append(item.state)
|
||||||
|
elif hasattr(item, 'state_names'): # issubclass() cannot be used in this place
|
||||||
|
states.extend(item.state_names)
|
||||||
|
else:
|
||||||
|
states.append(item)
|
||||||
|
self.states = states
|
||||||
|
|
||||||
def get_target(self, obj):
|
def get_target(self, obj):
|
||||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||||
|
|
||||||
async def check(self, obj):
|
async def check(self, obj):
|
||||||
if '*' in self.state:
|
if '*' in self.states:
|
||||||
return {'state': self.dispatcher.current_state()}
|
return {'state': self.dispatcher.current_state()}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -137,11 +146,11 @@ class StateFilter(BaseFilter):
|
||||||
if chat or user:
|
if chat or user:
|
||||||
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
|
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
|
||||||
self.ctx_state.set(state)
|
self.ctx_state.set(state)
|
||||||
if state in self.state:
|
if state in self.states:
|
||||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if state in self.state:
|
if state in self.states:
|
||||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
95
aiogram/dispatcher/filters/state.py
Normal file
95
aiogram/dispatcher/filters/state.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
from ..dispatcher import Dispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
def __init__(self, state=None):
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def __set_name__(self, owner, name):
|
||||||
|
if self.state is None:
|
||||||
|
self.state = owner.__name__ + ':' + name
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<State '{self.state}>'"
|
||||||
|
|
||||||
|
__repr__ = __str__
|
||||||
|
|
||||||
|
async def set(self):
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
await state.set_state(self.state)
|
||||||
|
|
||||||
|
|
||||||
|
class MetaStatesGroup(type):
|
||||||
|
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||||
|
cls = super(MetaStatesGroup, mcs).__new__(mcs, name, bases, namespace)
|
||||||
|
|
||||||
|
states = []
|
||||||
|
for name, prop in ((name, prop) for name, prop in namespace.items() if isinstance(prop, State)):
|
||||||
|
states.append(prop)
|
||||||
|
|
||||||
|
cls._states = tuple(states)
|
||||||
|
cls._state_names = tuple(state.state for state in states)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@property
|
||||||
|
def states(cls) -> tuple:
|
||||||
|
return cls._states
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_names(cls) -> tuple:
|
||||||
|
return cls._state_names
|
||||||
|
|
||||||
|
|
||||||
|
class StatesGroup(metaclass=MetaStatesGroup):
|
||||||
|
@classmethod
|
||||||
|
async def next(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
state_name = await state.get_state()
|
||||||
|
|
||||||
|
try:
|
||||||
|
next_step = cls.state_names.index(state_name) + 1
|
||||||
|
except ValueError:
|
||||||
|
next_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
next_state_name = cls.states[next_step].state
|
||||||
|
except IndexError:
|
||||||
|
next_state_name = None
|
||||||
|
|
||||||
|
await state.set_state(next_state_name)
|
||||||
|
return next_state_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def previous(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
state_name = await state.get_state()
|
||||||
|
|
||||||
|
try:
|
||||||
|
previous_step = cls.state_names.index(state_name) - 1
|
||||||
|
except ValueError:
|
||||||
|
previous_step = 0
|
||||||
|
|
||||||
|
if previous_step < 0:
|
||||||
|
previous_state_name = None
|
||||||
|
else:
|
||||||
|
previous_state_name = cls.states[previous_step].state
|
||||||
|
|
||||||
|
await state.set_state(previous_state_name)
|
||||||
|
return previous_state_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def first(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
first_step_name = cls.states[0].state
|
||||||
|
|
||||||
|
await state.set_state(first_step_name)
|
||||||
|
return first_step_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def last(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
last_step_name = cls.states[-1].state
|
||||||
|
|
||||||
|
await state.set_state(last_step_name)
|
||||||
|
return last_step_name
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from aiogram import Bot, types
|
import aiogram.utils.markdown as md
|
||||||
|
from aiogram import Bot, Dispatcher, types
|
||||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
from aiogram.dispatcher import Dispatcher
|
from aiogram.dispatcher import FSMContext
|
||||||
|
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||||
from aiogram.types import ParseMode
|
from aiogram.types import ParseMode
|
||||||
from aiogram.utils import executor
|
from aiogram.utils import executor
|
||||||
from aiogram.utils.markdown import text, bold
|
|
||||||
|
|
||||||
API_TOKEN = 'BOT TOKEN HERE'
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
|
@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop)
|
||||||
storage = MemoryStorage()
|
storage = MemoryStorage()
|
||||||
dp = Dispatcher(bot, storage=storage)
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
|
||||||
|
|
||||||
# States
|
# States
|
||||||
AGE = 'process_age'
|
class Form(StatesGroup):
|
||||||
NAME = 'process_name'
|
name = State() # Will be represented in storage as 'Form:name'
|
||||||
GENDER = 'process_gender'
|
age = State() # Will be represented in storage as 'Form:age'
|
||||||
|
gender = State() # Will be represented in storage as 'Form:gender'
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(commands=['start'])
|
@dp.message_handler(commands=['start'])
|
||||||
|
|
@ -28,48 +32,41 @@ async def cmd_start(message: types.Message):
|
||||||
"""
|
"""
|
||||||
Conversation's entry point
|
Conversation's entry point
|
||||||
"""
|
"""
|
||||||
# Get current state
|
# Set state
|
||||||
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
|
await Form.name.set()
|
||||||
# Update user's state
|
|
||||||
await state.set_state(NAME)
|
|
||||||
|
|
||||||
await message.reply("Hi there! What's your name?")
|
await message.reply("Hi there! What's your name?")
|
||||||
|
|
||||||
|
|
||||||
# You can use state '*' if you need to handle all states
|
# You can use state '*' if you need to handle all states
|
||||||
@dp.message_handler(state='*', commands=['cancel'])
|
@dp.message_handler(state='*', commands=['cancel'])
|
||||||
@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel')
|
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
|
||||||
async def cancel_handler(message: types.Message):
|
async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Allow user to cancel any action
|
Allow user to cancel any action
|
||||||
"""
|
"""
|
||||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
if raw_state is None:
|
||||||
# Ignore command if user is not in any (defined) state
|
return
|
||||||
if await state.get_state() is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Otherwise cancel state and inform user about it
|
# Cancel state and inform user about it
|
||||||
# And remove keyboard (just in case)
|
await state.finish()
|
||||||
await state.reset_state(with_data=True)
|
# And remove keyboard (just in case)
|
||||||
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=NAME)
|
@dp.message_handler(state=Form.name)
|
||||||
async def process_name(message: types.Message):
|
async def process_name(message: types.Message, state: FSMContext):
|
||||||
"""
|
"""
|
||||||
Process user name
|
Process user name
|
||||||
"""
|
"""
|
||||||
# Save name to storage and go to next step
|
await Form.next()
|
||||||
# You can use context manager
|
await state.update_data(name=message.text)
|
||||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
|
||||||
await state.update_data(name=message.text)
|
|
||||||
await state.set_state(AGE)
|
|
||||||
|
|
||||||
await message.reply("How old are you?")
|
await message.reply("How old are you?")
|
||||||
|
|
||||||
|
|
||||||
# Check age. Age gotta be digit
|
# Check age. Age gotta be digit
|
||||||
@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit())
|
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
|
||||||
async def failed_process_age(message: types.Message):
|
async def failed_process_age(message: types.Message):
|
||||||
"""
|
"""
|
||||||
If age is invalid
|
If age is invalid
|
||||||
|
|
@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message):
|
||||||
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit())
|
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
|
||||||
async def process_age(message: types.Message):
|
async def process_age(message: types.Message, state: FSMContext):
|
||||||
# Update state and data
|
# Update state and data
|
||||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
await Form.next()
|
||||||
await state.set_state(GENDER)
|
await state.update_data(age=int(message.text))
|
||||||
await state.update_data(age=int(message.text))
|
|
||||||
|
|
||||||
# Configure ReplyKeyboardMarkup
|
# Configure ReplyKeyboardMarkup
|
||||||
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
||||||
|
|
@ -92,7 +88,7 @@ async def process_age(message: types.Message):
|
||||||
await message.reply("What is your gender?", reply_markup=markup)
|
await message.reply("What is your gender?", reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"])
|
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
|
||||||
async def failed_process_gender(message: types.Message):
|
async def failed_process_gender(message: types.Message):
|
||||||
"""
|
"""
|
||||||
In this example gender has to be one of: Male, Female, Other.
|
In this example gender has to be one of: Male, Female, Other.
|
||||||
|
|
@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message):
|
||||||
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=GENDER)
|
@dp.message_handler(state=Form.gender)
|
||||||
async def process_gender(message: types.Message):
|
async def process_gender(message: types.Message, state: FSMContext):
|
||||||
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
|
|
||||||
|
|
||||||
data = await state.get_data()
|
data = await state.get_data()
|
||||||
data['gender'] = message.text
|
data['gender'] = message.text
|
||||||
|
|
||||||
|
|
@ -111,10 +105,10 @@ async def process_gender(message: types.Message):
|
||||||
markup = types.ReplyKeyboardRemove()
|
markup = types.ReplyKeyboardRemove()
|
||||||
|
|
||||||
# And send message
|
# And send message
|
||||||
await bot.send_message(message.chat.id, text(
|
await bot.send_message(message.chat.id, md.text(
|
||||||
text('Hi! Nice to meet you,', bold(data['name'])),
|
md.text('Hi! Nice to meet you,', md.bold(data['name'])),
|
||||||
text('Age:', data['age']),
|
md.text('Age:', data['age']),
|
||||||
text('Gender:', data['gender']),
|
md.text('Gender:', data['gender']),
|
||||||
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
|
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
|
||||||
|
|
||||||
# Finish conversation
|
# Finish conversation
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue