mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-15 03:08:51 +00:00
Experimental: FSM Middleware.
This commit is contained in:
parent
5466f0c342
commit
acdcd1455f
3 changed files with 231 additions and 0 deletions
80
aiogram/contrib/middlewares/fsm.py
Normal file
80
aiogram/contrib/middlewares/fsm.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
import copy
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
|
||||||
|
from aiogram.dispatcher.storage import FSMContext
|
||||||
|
|
||||||
|
|
||||||
|
class FSMMiddleware(LifetimeControllerMiddleware):
|
||||||
|
skip_patterns = ['error', 'update']
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(FSMMiddleware, self).__init__()
|
||||||
|
self._proxies = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
async def pre_process(self, obj, data, *args):
|
||||||
|
proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state())
|
||||||
|
data['state_data'] = proxy
|
||||||
|
|
||||||
|
async def post_process(self, obj, data, *args):
|
||||||
|
proxy = data.get('state_data', None)
|
||||||
|
if isinstance(proxy, FSMSStorageProxy):
|
||||||
|
await proxy.save()
|
||||||
|
|
||||||
|
|
||||||
|
class FSMSStorageProxy(dict):
|
||||||
|
def __init__(self, fsm_context: FSMContext):
|
||||||
|
super(FSMSStorageProxy, self).__init__()
|
||||||
|
self.fsm_context = fsm_context
|
||||||
|
self._copy = {}
|
||||||
|
self._data = {}
|
||||||
|
self._state = None
|
||||||
|
self._is_dirty = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, fsm_context: FSMContext):
|
||||||
|
"""
|
||||||
|
:param fsm_context:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
proxy = cls(fsm_context)
|
||||||
|
await proxy.load()
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
async def load(self):
|
||||||
|
self.clear()
|
||||||
|
self._state = await self.fsm_context.get_state()
|
||||||
|
self.update(await self.fsm_context.get_data())
|
||||||
|
self._copy = copy.deepcopy(self)
|
||||||
|
self._is_dirty = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@state.setter
|
||||||
|
def state(self, value):
|
||||||
|
self._state = value
|
||||||
|
self._is_dirty = True
|
||||||
|
|
||||||
|
@state.deleter
|
||||||
|
def state(self):
|
||||||
|
self._state = None
|
||||||
|
self._is_dirty = True
|
||||||
|
|
||||||
|
async def save(self, force=False):
|
||||||
|
if self._copy != self or force:
|
||||||
|
await self.fsm_context.set_data(data=self)
|
||||||
|
if self._is_dirty or force:
|
||||||
|
await self.fsm_context.set_state(self.state)
|
||||||
|
self._is_dirty = False
|
||||||
|
self._copy = copy.deepcopy(self)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
s = super(FSMSStorageProxy, self).__str__()
|
||||||
|
readable_state = f"'{self.state}'" if self.state else "''"
|
||||||
|
return f"<{self.__class__.__name__}(state={readable_state}, data={s})>"
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
del self.state
|
||||||
|
return super(FSMSStorageProxy, self).clear()
|
||||||
|
|
@ -101,3 +101,28 @@ class BaseMiddleware:
|
||||||
if not handler:
|
if not handler:
|
||||||
return None
|
return None
|
||||||
await handler(*args)
|
await handler(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class LifetimeControllerMiddleware(BaseMiddleware):
|
||||||
|
# TODO: Rename class
|
||||||
|
|
||||||
|
skip_patterns = None
|
||||||
|
|
||||||
|
async def pre_process(self, obj, data, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def post_process(self, obj, data, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def trigger(self, action, args):
|
||||||
|
if self.skip_patterns is not None and any(item in action for item in self.skip_patterns):
|
||||||
|
return False
|
||||||
|
|
||||||
|
obj, *args, data = args
|
||||||
|
if action.startswith('pre_process_'):
|
||||||
|
await self.pre_process(obj, data, *args)
|
||||||
|
elif action.startswith('post_process_'):
|
||||||
|
await self.post_process(obj, data, *args)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
|
||||||
126
examples/finite_state_machine_example_2.py
Normal file
126
examples/finite_state_machine_example_2.py
Normal file
|
|
@ -0,0 +1,126 @@
|
||||||
|
"""
|
||||||
|
This example is equals with 'finite_state_machine_example.py' but with FSM Middleware
|
||||||
|
|
||||||
|
Note that FSM Middleware implements the more simple methods for working with storage.
|
||||||
|
|
||||||
|
With that middleware all data from storage will be loaded before event will be processed
|
||||||
|
and data will be stored after processing the event.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import aiogram.utils.markdown as md
|
||||||
|
from aiogram import Bot, Dispatcher, types
|
||||||
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
|
from aiogram.contrib.middlewares.fsm import FSMMiddleware, FSMSStorageProxy
|
||||||
|
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||||
|
from aiogram.utils import executor
|
||||||
|
|
||||||
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
bot = Bot(token=API_TOKEN, loop=loop)
|
||||||
|
|
||||||
|
# For example use simple MemoryStorage for Dispatcher.
|
||||||
|
storage = MemoryStorage()
|
||||||
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
dp.middleware.setup(FSMMiddleware())
|
||||||
|
|
||||||
|
|
||||||
|
# States
|
||||||
|
class Form(StatesGroup):
|
||||||
|
name = State() # Will be represented in storage as 'Form:name'
|
||||||
|
age = State() # Will be represented in storage as 'Form:age'
|
||||||
|
gender = State() # Will be represented in storage as 'Form:gender'
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands=['start'])
|
||||||
|
async def cmd_start(message: types.Message):
|
||||||
|
"""
|
||||||
|
Conversation's entry point
|
||||||
|
"""
|
||||||
|
# Set state
|
||||||
|
await Form.first()
|
||||||
|
|
||||||
|
await message.reply("Hi there! What's your name?")
|
||||||
|
|
||||||
|
|
||||||
|
# You can use state '*' if you need to handle all states
|
||||||
|
@dp.message_handler(state='*', commands=['cancel'])
|
||||||
|
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
|
||||||
|
async def cancel_handler(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
"""
|
||||||
|
Allow user to cancel any action
|
||||||
|
"""
|
||||||
|
if state_data.state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cancel state and inform user about it
|
||||||
|
del state_data.state
|
||||||
|
# And remove keyboard (just in case)
|
||||||
|
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(state=Form.name)
|
||||||
|
async def process_name(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
"""
|
||||||
|
Process user name
|
||||||
|
"""
|
||||||
|
state_data.state = Form.age
|
||||||
|
state_data['name'] = message.text
|
||||||
|
|
||||||
|
await message.reply("How old are you?")
|
||||||
|
|
||||||
|
|
||||||
|
# Check age. Age gotta be digit
|
||||||
|
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
|
||||||
|
async def failed_process_age(message: types.Message):
|
||||||
|
"""
|
||||||
|
If age is invalid
|
||||||
|
"""
|
||||||
|
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
|
||||||
|
async def process_age(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
# Update state and data
|
||||||
|
state_data.state = Form.gender
|
||||||
|
state_data['age'] = int(message.text)
|
||||||
|
|
||||||
|
# Configure ReplyKeyboardMarkup
|
||||||
|
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
||||||
|
markup.add("Male", "Female")
|
||||||
|
markup.add("Other")
|
||||||
|
|
||||||
|
await message.reply("What is your gender?", reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
|
||||||
|
async def failed_process_gender(message: types.Message):
|
||||||
|
"""
|
||||||
|
In this example gender has to be one of: Male, Female, Other.
|
||||||
|
"""
|
||||||
|
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(state=Form.gender)
|
||||||
|
async def process_gender(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
state_data['gender'] = message.text
|
||||||
|
|
||||||
|
# Remove keyboard
|
||||||
|
markup = types.ReplyKeyboardRemove()
|
||||||
|
|
||||||
|
# And send message
|
||||||
|
await bot.send_message(message.chat.id, md.text(
|
||||||
|
md.text('Hi! Nice to meet you,', md.bold(state_data['name'])),
|
||||||
|
md.text('Age:', state_data['age']),
|
||||||
|
md.text('Gender:', state_data['gender']),
|
||||||
|
sep='\n'), reply_markup=markup, parse_mode=types.ParseMode.MARKDOWN)
|
||||||
|
|
||||||
|
# Finish conversation
|
||||||
|
# WARNING! This method will destroy all data in storage for current user!
|
||||||
|
state_data.clear()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue