mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-09 17:33:44 +00:00
Provide async storage.
This commit is contained in:
parent
e03f217aed
commit
1413baf4b8
1 changed files with 300 additions and 3 deletions
|
|
@ -7,9 +7,6 @@ from .handler import SkipHandler
|
||||||
log = logging.getLogger('aiogram.StateMachine')
|
log = logging.getLogger('aiogram.StateMachine')
|
||||||
|
|
||||||
|
|
||||||
# TODO: Provide async storage
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStorage:
|
class BaseStorage:
|
||||||
"""
|
"""
|
||||||
Skeleton for states storage
|
Skeleton for states storage
|
||||||
|
|
@ -132,6 +129,115 @@ class BaseStorage:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAsyncStorage(BaseStorage):
|
||||||
|
async def set_state(self, chat, user, state):
|
||||||
|
"""
|
||||||
|
Set state
|
||||||
|
|
||||||
|
:param chat: chat_id
|
||||||
|
:param user: user_id
|
||||||
|
:param state: value
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Get user state from
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def del_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Clear user state
|
||||||
|
:param chat: cha
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def all_states(self, chat=None, user=None, state=None):
|
||||||
|
"""
|
||||||
|
Yield all states (Can use filters)
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param state:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def set_value(self, chat, user, key, value):
|
||||||
|
"""
|
||||||
|
Set value for user in storage
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param key:
|
||||||
|
:param value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_value(self, chat, user, key, default=None):
|
||||||
|
"""
|
||||||
|
Get value from storage
|
||||||
|
|
||||||
|
By default, this method calls `(await self.get_data(chat, user)).get(key, default)`
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param key:
|
||||||
|
:param default:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return (await self.get_data(chat, user)).get(key, default)
|
||||||
|
|
||||||
|
async def del_value(self, chat, user, key):
|
||||||
|
"""
|
||||||
|
Delete value from storage
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param key:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_data(self, chat, user):
|
||||||
|
"""
|
||||||
|
Get all stored data for user
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return: dict
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def update_data(self, chat, user, data):
|
||||||
|
"""
|
||||||
|
Update data in storage
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def clear_data(self, chat, user, key):
|
||||||
|
"""
|
||||||
|
Clear data in storage
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param key:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class MemoryStorage(BaseStorage):
|
class MemoryStorage(BaseStorage):
|
||||||
"""
|
"""
|
||||||
Simple in-memory state storage
|
Simple in-memory state storage
|
||||||
|
|
@ -398,6 +504,116 @@ class Controller:
|
||||||
return f"{self._chat}:{self._user} - {self._state}"
|
return f"{self._chat}:{self._user} - {self._state}"
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncController:
|
||||||
|
"""
|
||||||
|
Storage controller
|
||||||
|
|
||||||
|
Make easy access from callback's
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_machine, chat, user, state):
|
||||||
|
self._state_machine = state_machine
|
||||||
|
self._chat = chat
|
||||||
|
self._user = user
|
||||||
|
self._state = state
|
||||||
|
|
||||||
|
async def set_state(self, value):
|
||||||
|
"""
|
||||||
|
Set state
|
||||||
|
|
||||||
|
:param value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self._state_machine.set_state(self._chat, self._user, value)
|
||||||
|
|
||||||
|
async def get_state(self):
|
||||||
|
"""
|
||||||
|
Get current state
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self._state_machine.get_state(self._chat, self._user)
|
||||||
|
|
||||||
|
async def clear(self):
|
||||||
|
"""
|
||||||
|
Reset state
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self._state_machine.del_state(self._chat, self._user)
|
||||||
|
|
||||||
|
async def get(self, key, default=None):
|
||||||
|
"""
|
||||||
|
Get value from storage
|
||||||
|
|
||||||
|
:param key:
|
||||||
|
:param default:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self._state_machine.storage.get_value(self._chat, self._user, key, default)
|
||||||
|
|
||||||
|
async def pop(self, key, default=None):
|
||||||
|
"""
|
||||||
|
Pop item from storage
|
||||||
|
|
||||||
|
:param key:
|
||||||
|
:param default:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result = await self.get(key, default)
|
||||||
|
await self.delete(key)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def set(self, key, value):
|
||||||
|
"""
|
||||||
|
Set new value in user storage
|
||||||
|
|
||||||
|
:param key:
|
||||||
|
:param value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self._state_machine.storage.set_value(self._chat, self._user, key, value)
|
||||||
|
|
||||||
|
async def delete(self, key):
|
||||||
|
"""
|
||||||
|
Delete key from user storage
|
||||||
|
|
||||||
|
:param key:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self._state_machine.storage.del_value(self._chat, self._user, key)
|
||||||
|
|
||||||
|
async def update(self, data):
|
||||||
|
"""
|
||||||
|
Update user storage
|
||||||
|
|
||||||
|
:param data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self._state_machine.storage.update_data(self._chat, self._user, data)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def data(self):
|
||||||
|
"""
|
||||||
|
User data
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self._state_machine.storage.get_data(self._chat, self._user)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
raise RuntimeError("Item assignment not allowed with async storage")
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
raise RuntimeError("Item assignment not allowed with async storage")
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
raise RuntimeError("Item assignment not allowed with async storage")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self._chat}:{self._user} - {self._state}"
|
||||||
|
|
||||||
|
|
||||||
class StateMachine:
|
class StateMachine:
|
||||||
"""
|
"""
|
||||||
Manage state
|
Manage state
|
||||||
|
|
@ -478,3 +694,84 @@ class StateMachine:
|
||||||
callback = self.steps[state]
|
callback = self.steps[state]
|
||||||
controller = Controller(self, chat_id, from_user_id, state)
|
controller = Controller(self, chat_id, from_user_id, state)
|
||||||
await callback(message, controller)
|
await callback(message, controller)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncStateMachine:
|
||||||
|
"""
|
||||||
|
Manage state
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dispatcher, states, storage=None):
|
||||||
|
assert isinstance(storage, BaseAsyncStorage)
|
||||||
|
|
||||||
|
self.steps = self._prepare_states(states)
|
||||||
|
self.storage = storage
|
||||||
|
|
||||||
|
dispatcher.message_handlers.register(self.process_message, index=0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_states(states):
|
||||||
|
if isinstance(states, dict):
|
||||||
|
return states
|
||||||
|
elif isinstance(states, (list, tuple, set)):
|
||||||
|
prepared_states = {}
|
||||||
|
for state in states:
|
||||||
|
if not callable(state):
|
||||||
|
raise TypeError('State must be an callable')
|
||||||
|
state_name = state.__name__
|
||||||
|
prepared_states[state_name] = state
|
||||||
|
return prepared_states
|
||||||
|
raise TypeError('States must be an dict or list!')
|
||||||
|
|
||||||
|
async def set_state(self, chat, user, state):
|
||||||
|
"""
|
||||||
|
Save state to storage
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param state:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
log.debug(f"Set state for {chat}:{user} to '{state}'")
|
||||||
|
await self.storage.set_state(chat, user, state)
|
||||||
|
|
||||||
|
async def get_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Get state from storage
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.storage.get_state(chat, user)
|
||||||
|
|
||||||
|
async def del_state(self, chat, user):
|
||||||
|
"""
|
||||||
|
Clear user state
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
log.debug(f"Reset state for {chat}:{user}")
|
||||||
|
await self.storage.del_state(chat, user)
|
||||||
|
|
||||||
|
async def process_message(self, message):
|
||||||
|
"""
|
||||||
|
Read message and process it
|
||||||
|
:param message:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
chat_id = message.chat.id
|
||||||
|
from_user_id = message.from_user.id
|
||||||
|
|
||||||
|
state = await self.get_state(chat_id, from_user_id)
|
||||||
|
if state is None:
|
||||||
|
raise SkipHandler()
|
||||||
|
|
||||||
|
if state not in self.steps:
|
||||||
|
log.warning(f"Found unknown state '{state}' for {chat_id}:{from_user_id}. Condition will be reset.")
|
||||||
|
await self.del_state(chat_id, from_user_id)
|
||||||
|
raise SkipHandler()
|
||||||
|
|
||||||
|
log.debug(f"Process state for {chat_id}:{from_user_id} - '{state}'")
|
||||||
|
callback = self.steps[state]
|
||||||
|
controller = AsyncController(self, chat_id, from_user_id, state)
|
||||||
|
await callback(message, controller)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue