mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-09 09:22:03 +00:00
State machine.
This commit is contained in:
parent
21714f7bcd
commit
6bbe330fdd
3 changed files with 405 additions and 4 deletions
|
|
@ -4,6 +4,14 @@ from .filters import check_filters, CancelFilter
|
||||||
from .. import types
|
from .. import types
|
||||||
|
|
||||||
|
|
||||||
|
class SkipHandler(BaseException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CancelHandler(BaseException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
def __init__(self, dispatcher, once=True):
|
def __init__(self, dispatcher, once=True):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
|
|
@ -11,10 +19,14 @@ class Handler:
|
||||||
|
|
||||||
self.handlers = []
|
self.handlers = []
|
||||||
|
|
||||||
def register(self, handler, filters=None):
|
def register(self, handler, filters=None, index=None):
|
||||||
if filters and not isinstance(filters, (list, tuple, set)):
|
if filters and not isinstance(filters, (list, tuple, set)):
|
||||||
filters = [filters]
|
filters = [filters]
|
||||||
self.handlers.append((filters, handler))
|
record = (filters, handler)
|
||||||
|
if index is None:
|
||||||
|
self.handlers.append(record)
|
||||||
|
else:
|
||||||
|
self.handlers.insert(index, record)
|
||||||
|
|
||||||
def unregister(self, handler):
|
def unregister(self, handler):
|
||||||
for handler_with_filters in self.handlers:
|
for handler_with_filters in self.handlers:
|
||||||
|
|
@ -27,8 +39,13 @@ class Handler:
|
||||||
async def notify(self, *args, **kwargs):
|
async def notify(self, *args, **kwargs):
|
||||||
for filters, handler in self.handlers:
|
for filters, handler in self.handlers:
|
||||||
if await check_filters(filters, args, kwargs):
|
if await check_filters(filters, args, kwargs):
|
||||||
await handler(*args, **kwargs)
|
try:
|
||||||
if self.once:
|
await handler(*args, **kwargs)
|
||||||
|
if self.once:
|
||||||
|
break
|
||||||
|
except SkipHandler:
|
||||||
|
continue
|
||||||
|
except CancelHandler:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
312
aiogram/dispatcher/state.py
Normal file
312
aiogram/dispatcher/state.py
Normal file
|
|
@ -0,0 +1,312 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiogram.dispatcher.handler import SkipHandler
|
||||||
|
|
||||||
|
log = logging.getLogger('aiogram.StateMachine')
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStorage:
|
||||||
|
"""
|
||||||
|
Skeleton for states storage
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_state_name(value):
|
||||||
|
if callable(value):
|
||||||
|
if hasattr(value, '__name__'):
|
||||||
|
return value.__name__
|
||||||
|
else:
|
||||||
|
return value.__class__.__name__
|
||||||
|
return value
|
||||||
|
|
||||||
|
def set_state(self, chat, user, state):
|
||||||
|
"""
|
||||||
|
Set state
|
||||||
|
:param chat: chat_id
|
||||||
|
:param user: user_id
|
||||||
|
:param state: value
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Get user state from
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def del_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Clear user state
|
||||||
|
:param chat: cha
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def all_states(self, chat=None, user=None, state=None):
|
||||||
|
"""
|
||||||
|
Yield all states (Can use filters)
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
"""
|
||||||
|
Here you can use key or slice-key
|
||||||
|
|
||||||
|
>>> storage[chat:user] = "new state"
|
||||||
|
or
|
||||||
|
>>> storage[chat] = "new state"
|
||||||
|
:param key: key or slice
|
||||||
|
:param value: new state
|
||||||
|
"""
|
||||||
|
if isinstance(key, slice):
|
||||||
|
self.set_state(key.start, key.stop, value)
|
||||||
|
else:
|
||||||
|
self.set_state(key, key, value)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""
|
||||||
|
Here you can use key or slice-key
|
||||||
|
|
||||||
|
>>> storage[chat:user]
|
||||||
|
or
|
||||||
|
>>> storage[chat]
|
||||||
|
:param key: key or slice
|
||||||
|
:return: state
|
||||||
|
"""
|
||||||
|
if isinstance(key, slice):
|
||||||
|
return self.get_state(key.start, key.stop)
|
||||||
|
return self.get_state(key, key)
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
if isinstance(key, slice):
|
||||||
|
self.del_state(key.start, key.stop)
|
||||||
|
else:
|
||||||
|
self.del_state(key, key)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
yield from self.all_states()
|
||||||
|
|
||||||
|
|
||||||
|
class StateStorage(BaseStorage):
|
||||||
|
"""
|
||||||
|
Simple in-memory state storage
|
||||||
|
Based on builtin dict
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.storage = {}
|
||||||
|
|
||||||
|
def _prepare(self, chat, user):
|
||||||
|
"""
|
||||||
|
Add chat and user to storage if they are not exist
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result = False
|
||||||
|
|
||||||
|
if chat not in self.storage:
|
||||||
|
self.storage[chat] = {}
|
||||||
|
result = True
|
||||||
|
|
||||||
|
if user not in self.storage[chat]:
|
||||||
|
self.storage[chat][user] = None
|
||||||
|
result = True
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def set_state(self, chat, user, state):
|
||||||
|
self._prepare(chat, user)
|
||||||
|
self.storage[chat][user] = self._prepare_state_name(state)
|
||||||
|
|
||||||
|
def get_state(self, chat, user):
|
||||||
|
self._prepare(chat, user)
|
||||||
|
return self.storage[chat][user]
|
||||||
|
|
||||||
|
def del_state(self, chat, user):
|
||||||
|
self._prepare(chat, user)
|
||||||
|
if self[chat:user] is not None:
|
||||||
|
self[chat:user] = None
|
||||||
|
|
||||||
|
def all_states(self, chat=None, user=None, state=None):
|
||||||
|
for chat_id, chat in self.storage.items():
|
||||||
|
if chat is not None and chat != chat_id:
|
||||||
|
continue
|
||||||
|
for user_id, user_state in chat.items():
|
||||||
|
if user is not None and user != user_id:
|
||||||
|
continue
|
||||||
|
if state is not None and user_state == state:
|
||||||
|
continue
|
||||||
|
yield chat_id, user_id, user_state
|
||||||
|
|
||||||
|
|
||||||
|
class Controller:
|
||||||
|
"""
|
||||||
|
Storage controller
|
||||||
|
|
||||||
|
Make easy access from callback's
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_machine, chat, user, state):
|
||||||
|
self._state_machine = state_machine
|
||||||
|
self._chat = chat
|
||||||
|
self._user = user
|
||||||
|
self._state = state
|
||||||
|
|
||||||
|
def set(self, value):
|
||||||
|
"""
|
||||||
|
Set state
|
||||||
|
:param value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self._state_machine[self._chat:self._user] = value
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
"""
|
||||||
|
Get current state
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return self._state_machine[self._chat:self._user]
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""
|
||||||
|
Reset state
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
del self._state_machine[self._chat:self._user]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self._chat}:{self._user} - {self._state}"
|
||||||
|
|
||||||
|
|
||||||
|
class StateMachine:
|
||||||
|
"""
|
||||||
|
Manage state
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dispatcher, states, storage=None):
|
||||||
|
if storage is None:
|
||||||
|
storage = StateStorage()
|
||||||
|
|
||||||
|
self.steps = self._prepare_states(states)
|
||||||
|
self.storage = storage
|
||||||
|
|
||||||
|
dispatcher.message_handlers.register(self.process_message, index=0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_states(states):
|
||||||
|
if isinstance(states, dict):
|
||||||
|
return states
|
||||||
|
elif isinstance(states, (list, tuple, set)):
|
||||||
|
prepared_states = {}
|
||||||
|
for state in states:
|
||||||
|
if not callable(state):
|
||||||
|
raise TypeError('State must be an callable')
|
||||||
|
state_name = state.__name__
|
||||||
|
prepared_states[state_name] = state
|
||||||
|
return prepared_states
|
||||||
|
raise TypeError('States must be an dict or list!')
|
||||||
|
|
||||||
|
def set_state(self, chat, user, state):
|
||||||
|
"""
|
||||||
|
Save state to storage
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param state:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
log.debug(f"Set state for {chat}:{user} to '{state}'")
|
||||||
|
self.storage[chat:user] = state
|
||||||
|
|
||||||
|
def get_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Get state from storage
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return self.storage[chat:user]
|
||||||
|
|
||||||
|
def del_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Clear user state
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
log.debug(f"Reset state for {chat}:{user}")
|
||||||
|
del self.storage[chat:user]
|
||||||
|
|
||||||
|
async def process_message(self, message):
|
||||||
|
"""
|
||||||
|
Read message and process it
|
||||||
|
:param message:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
chat_id = message.chat.id
|
||||||
|
from_user_id = message.from_user.id
|
||||||
|
|
||||||
|
state = self.get_state(chat_id, from_user_id)
|
||||||
|
if state is None:
|
||||||
|
raise SkipHandler()
|
||||||
|
|
||||||
|
if state not in self.steps:
|
||||||
|
log.warning(f"Found unknown state '{state}' for {chat_id}:{from_user_id}. Condition will be reset.")
|
||||||
|
self.del_state(chat_id, from_user_id)
|
||||||
|
raise SkipHandler()
|
||||||
|
|
||||||
|
callback = self.steps[state]
|
||||||
|
controller = Controller(self, chat_id, from_user_id, state)
|
||||||
|
log.debug(f"Process state for {chat_id}:{from_user_id} - '{state}'")
|
||||||
|
result = await callback(message, controller)
|
||||||
|
# if result is True:
|
||||||
|
# controller.clear()
|
||||||
|
# elif isinstance(result, str):
|
||||||
|
# controller.set(result)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
"""
|
||||||
|
Here you can use key or slice-key
|
||||||
|
|
||||||
|
>>> state[chat:user] = "new state"
|
||||||
|
or
|
||||||
|
>>> state[chat] = "new state"
|
||||||
|
:param key: key or slice
|
||||||
|
:param value: new state
|
||||||
|
"""
|
||||||
|
if isinstance(key, slice):
|
||||||
|
self.set_state(key.start, key.stop, value)
|
||||||
|
else:
|
||||||
|
self.set_state(key, key, value)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""
|
||||||
|
Here you can use key or slice-key
|
||||||
|
|
||||||
|
>>> state[chat:user]
|
||||||
|
or
|
||||||
|
>>> state[chat]
|
||||||
|
:param key: key or slice
|
||||||
|
:return: state
|
||||||
|
"""
|
||||||
|
if isinstance(key, slice):
|
||||||
|
return self.get_state(key.start, key.stop)
|
||||||
|
return self.get_state(key, key)
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
"""
|
||||||
|
Reset user state
|
||||||
|
:param key:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if isinstance(key, slice):
|
||||||
|
self.del_state(key.start, key.stop)
|
||||||
|
else:
|
||||||
|
self.del_state(key, key)
|
||||||
72
examples/state_machine.py
Normal file
72
examples/state_machine.py
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiogram import Bot, types
|
||||||
|
from aiogram.dispatcher import Dispatcher
|
||||||
|
from aiogram.dispatcher.state import StateMachine
|
||||||
|
|
||||||
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
API_TOKEN = '380294876:AAFbdYYgq1hBi9hQDcxD3bj8QCNnVec5aHk'
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
bot = Bot(token=API_TOKEN, loop=loop)
|
||||||
|
dp = Dispatcher(bot)
|
||||||
|
|
||||||
|
users = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands=['start'])
|
||||||
|
async def send_welcome(message: types.Message):
|
||||||
|
await message.reply("Hi there! What's your name?")
|
||||||
|
state.set_state(message.chat.id, message.from_user.id, "name")
|
||||||
|
|
||||||
|
|
||||||
|
async def process_name(message, controller):
|
||||||
|
users[message.from_user.id] = {"name": message.text}
|
||||||
|
|
||||||
|
await message.reply("How old are you?")
|
||||||
|
|
||||||
|
controller.set('age')
|
||||||
|
|
||||||
|
|
||||||
|
async def process_age(message, controller):
|
||||||
|
if not message.text.isdigit():
|
||||||
|
return await message.reply("Age should be a number.\nHow old are you?")
|
||||||
|
|
||||||
|
users[message.from_user.id].update({"age": int(message.text)})
|
||||||
|
|
||||||
|
markup = types.ReplyKeyboardMarkup()
|
||||||
|
markup.add("Male", "Female")
|
||||||
|
markup.add("Other")
|
||||||
|
await message.reply("What is your gender?", reply_markup=markup)
|
||||||
|
controller.set("sex")
|
||||||
|
|
||||||
|
|
||||||
|
async def process_sex(message, controller):
|
||||||
|
if message.text not in ["Male", "Female", "Other"]:
|
||||||
|
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
||||||
|
|
||||||
|
users[message.from_user.id].update({"sex": message.text})
|
||||||
|
controller.clear()
|
||||||
|
|
||||||
|
user = users[message.from_user.id]
|
||||||
|
|
||||||
|
markup = types.ReplyKeyboardRemove()
|
||||||
|
await bot.send_message(message.chat.id,
|
||||||
|
f"Hi!\nNice to meet you, {user['name']}.\nAge: {user['age']}\nSex: {user['sex']}",
|
||||||
|
reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
|
state = StateMachine(dp, {
|
||||||
|
"name": process_name,
|
||||||
|
"age": process_age,
|
||||||
|
"sex": process_sex
|
||||||
|
})
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(dp.start_pooling())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
loop.stop()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue