From f70d45c53b97a2b3d5a56669dbfbf4048eb8faa5 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Tue, 1 Aug 2017 07:39:19 +0300 Subject: [PATCH] Start of implementing new FSM. --- aiogram/dispatcher/__init__.py | 141 ++++++++++----- aiogram/dispatcher/filters.py | 28 ++- aiogram/dispatcher/storage.py | 313 +++++++++++++++++++++++++++++++++ 3 files changed, 439 insertions(+), 43 deletions(-) create mode 100644 aiogram/dispatcher/storage.py diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index e40a440b..d693756a 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -1,9 +1,11 @@ import asyncio import logging +import typing from aiogram.utils.deprecated import deprecated from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters from .handler import Handler, NextStepHandler +from .storage import MemoryStorage, DisabledStorage, BaseStorage, FSMContext from .. import types from ..bot import Bot from ..types.message import ContentType @@ -11,8 +13,6 @@ from ..types.message import ContentType log = logging.getLogger(__name__) -# TODO: Fix functions (functools.wraps(func)) - class Dispatcher: """ Simple Updates dispatcher @@ -22,12 +22,15 @@ class Dispatcher: Provide next step handler and etc. """ - def __init__(self, bot, loop=None): - self.bot: 'Bot' = bot + def __init__(self, bot, loop=None, storage=None): if loop is None: - loop = self.bot.loop + loop = bot.loop + if storage is None: + storage = DisabledStorage() + self.bot: 'Bot' = bot self.loop = loop + self.storage = storage self.last_update_id = 0 @@ -144,7 +147,38 @@ class Dispatcher: """ self._pooling = False - def message_handler(self, commands=None, regexp=None, content_types=None, func=None, custom_filters=None, **kwargs): + def register_message_handler(self, callback, commands=None, regexp=None, content_types=None, func=None, + custom_filters=None, state=None, **kwargs): + """ + You can register messages handler by this method + + :param callback: + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param func: custom any callable object + :param custom_filters: list of custom filters + :param kwargs: + :param state: + :return: decorated function + """ + if content_types is None: + content_types = ContentType.TEXT + if custom_filters is None: + custom_filters = [] + + filters_set = generate_default_filters(self, + *custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + func=func, + state=state, + **kwargs) + self.message_handlers.register(callback, filters_set) + + def message_handler(self, commands=None, regexp=None, content_types=None, func=None, custom_filters=None, + state=None, **kwargs): """ Decorator for messages handler @@ -177,7 +211,7 @@ class Dispatcher: Register multiple filters set for one handler: .. code-block:: python3 - p.messages_handler(commands=['command']) + @dp.messages_handler(commands=['command']) @dp.messages_handler(func=lambda message: demojize(message.text) == ':new_moon_with_face:') async def text_handler(message: types.Message): This handler will be called if the message starts with '/command' OR is some emoji @@ -190,39 +224,29 @@ class Dispatcher: :param func: custom any callable object :param custom_filters: list of custom filters :param kwargs: + :param state: :return: decorated function """ - if commands is None: - commands = [] - if content_types is None: - content_types = ContentType.TEXT - if custom_filters is None: - custom_filters = [] - filters_set = generate_default_filters(*custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - func=func, - **kwargs) - - def decorator(handler): - self.message_handlers.register(handler, filters_set) - return handler + def decorator(callback): + self.register_message_handler(callback, commands=commands, regexp=regexp, content_types=content_types, + func=func, custom_filters=custom_filters, state=state, **kwargs) + return callback return decorator - def edited_message_handler(self, commands=None, regexp=None, content_types=None, func=None, custom_filters=None, - **kwargs): + def register_edited_message_handler(self, callback, commands=None, regexp=None, content_types=None, func=None, + custom_filters=None, **kwargs): """ Analog of message_handler but only for edited messages You can use combination of different handlers .. code-block:: python3 - @dp.message_handler() - @dp.edited_message_handler() - async def msg_handler(message: types.Message): + @dp.message_handler() + @dp.edited_message_handler() + async def msg_handler(message: types.Message): + :param callback: :param commands: list of commands :param regexp: REGEXP :param content_types: List of content types. @@ -238,16 +262,40 @@ class Dispatcher: if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, commands=commands, regexp=regexp, content_types=content_types, func=func, **kwargs) + self.edited_message_handlers.register(callback, filters_set) - def decorator(handler): - self.edited_message_handlers.register(handler, filters_set) - return handler + def edited_message_handler(self, commands=None, regexp=None, content_types=None, func=None, custom_filters=None, + **kwargs): + """ + Analog of message_handler but only for edited messages + + You can use combination of different handlers + .. code-block:: python3 + @dp.message_handler() + @dp.edited_message_handler() + async def msg_handler(message: types.Message): + + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param func: custom any callable object + :param custom_filters: list of custom filters + :param kwargs: + :return: decorated function + """ + + def decorator(callback): + self.register_edited_message_handler(callback, commands=commands, regexp=regexp, + content_types=content_types, func=func, custom_filters=custom_filters, + **kwargs) + return callback return decorator @@ -271,7 +319,8 @@ class Dispatcher: if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, commands=commands, regexp=regexp, content_types=content_types, @@ -304,7 +353,8 @@ class Dispatcher: if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, commands=commands, regexp=regexp, content_types=content_types, @@ -333,7 +383,8 @@ class Dispatcher: """ if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, func=func, **kwargs) @@ -359,7 +410,8 @@ class Dispatcher: """ if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, func=func, **kwargs) @@ -384,7 +436,8 @@ class Dispatcher: """ if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, func=func, **kwargs) @@ -409,7 +462,8 @@ class Dispatcher: """ if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, func=func, **kwargs) @@ -434,7 +488,8 @@ class Dispatcher: """ if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, func=func, **kwargs) @@ -452,10 +507,16 @@ class Dispatcher: if custom_filters is None: custom_filters = [] - filters_set = generate_default_filters(*custom_filters, + filters_set = generate_default_filters(self, + *custom_filters, regexp=regexp, content_types=content_types, func=func, **kwargs) self.next_step_message_handlers.register(message, otherwise, once, include_cancel, filters_set) return await self.next_step_message_handlers.wait(message) + + async def current_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None) -> FSMContext: + return FSMContext(storage=self.storage, chat=chat, user=user) diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index 4d92a839..bbe98be1 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -79,7 +79,7 @@ class ContentTypeFilter(Filter): self.content_types = content_types def check(self, message): - return message.content_type in self.content_types + return message.content_type[0] in self.content_types class CancelFilter(Filter): @@ -93,11 +93,28 @@ class CancelFilter(Filter): return message.text.lower() in self.cancel_set -def generate_default_filters(*args, **kwargs): +class StateFilter(AsyncFilter): + def __init__(self, dispatcher, state): + self.dispatcher = dispatcher + self.state = state + + async def check(self, obj): + if self.state == '*': + return True + + chat = getattr(getattr(obj, 'chat', None), 'id', None) + user = getattr(getattr(obj, 'from_user', None), 'id', None) + + if chat or user: + return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state + return False + + +def generate_default_filters(dispatcher, *args, **kwargs): filters_set = [] for name, filter_ in kwargs.items(): - if not filter_: + if filter_ is None and name != 'state': continue if name == 'commands': if isinstance(filter_, str): @@ -110,6 +127,10 @@ def generate_default_filters(*args, **kwargs): filters_set.append(ContentTypeFilter(filter_)) elif name == 'func': filters_set.append(filter_) + elif name == 'state': + filters_set.append(StateFilter(dispatcher, filter_)) + elif isinstance(filter_, Filter): + filters_set.append(filter_) filters_set += list(args) @@ -123,3 +144,4 @@ class DefaultFilters(Helper): REGEXP = Item() # regexp CONTENT_TYPE = Item() # content_type FUNC = Item() # func + STATE = Item() # state diff --git a/aiogram/dispatcher/storage.py b/aiogram/dispatcher/storage.py new file mode 100644 index 00000000..62a0c9fd --- /dev/null +++ b/aiogram/dispatcher/storage.py @@ -0,0 +1,313 @@ +import typing + + +class BaseStorage: + """ + In states-storage you can save current user state and data for all steps + """ + + @classmethod + def check_address(cls, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None) -> (typing.Union[str, int], typing.Union[str, int]): + """ + In all methods of storage chat or user is always required. + If one of this is not presented, need set the missing value based on the presented. + + This method performs the above action. + + :param chat: + :param user: + :return: + """ + if chat is not None and user is not None: + return chat, user + elif user is None and chat is not None: + user = chat + return chat, user + elif user is not None and chat is None: + chat = user + return chat, user + raise ValueError('User or chat parameters is required but anyone is not presented!') + + async def get_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Optional[str]: + """ + Get current state of user in chat. Return value stored in `default` parameter if record is not found. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :param default: + :return: + """ + raise NotImplementedError + + async def get_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Dict: + """ + Get state-data for user in chat. Return `default` if data is not presented in storage. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :param default: + :return: + """ + raise NotImplementedError + + async def set_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + state: typing.Optional[typing.AnyStr] = None): + """ + Setup new state for user in chat + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :param state: + """ + raise NotImplementedError + + async def set_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + data: typing.Dict = None): + """ + Set data for user in chat + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :param data: + """ + raise NotImplementedError + + async def update_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + data: typing.Dict = None, + **kwargs): + """ + Update data for user in chat + + You can use data parameter or|and kwargs. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param data: + :param chat: + :param user: + :param kwargs: + :return: + """ + raise NotImplementedError + + async def reset_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None): + """ + Reset data dor user in chat. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :return: + """ + await self.set_data(chat=chat, user=user, data={}) + + async def reset_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + with_data: typing.Optional[bool] = True): + """ + Reset state for user in chat. You can use this method for finish conversations. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :param with_data: + :return: + """ + chat, user = self.check_address(chat=chat, user=user) + await self.set_state(chat=chat, user=user, state=None) + if with_data: + await self.set_data(chat=chat, user=user, data={}) + + async def finish(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None): + """ + Finish conversation for user in chat. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :return: + """ + await self.reset_state(chat=chat, user=user, with_data=True) + + +class FSMContext: + def __init__(self, storage, chat, user): + self.storage: BaseStorage = storage + self.chat, self.user = self.storage.check_address(chat=chat, user=user) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]: + return await self.storage.get_state(chat=self.chat, user=self.user, default=default) + + async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict: + return await self.storage.get_data(chat=self.chat, user=self.user, default=default) + + async def update_data(self, data: typing.Dict = None, **kwargs): + await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs) + + async def set_state(self, state: typing.AnyStr): + await self.storage.set_state(chat=self.chat, user=self.user, state=state) + + async def set_data(self, data: typing.Dict = None): + await self.storage.set_data(chat=self.chat, user=self.user, data=data) + + async def reset_state(self, with_data: typing.Optional[bool] = True): + await self.storage.reset_state(chat=self.chat, user=self.user, with_data=with_data) + + async def reset_data(self): + await self.storage.reset_data(chat=self.chat, user=self.user) + + async def finish(self): + await self.storage.finish(chat=self.chat, user=self.user) + + +class DisabledStorage(BaseStorage): + """ + Empty storage. Use it if you don't want to use Finite-State Machine + """ + + async def get_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Optional[str]: + return None + + async def get_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Dict: + return {} + + async def update_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + data: typing.Dict = None, **kwargs): + pass + + async def set_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + state: typing.AnyStr = None): + pass + + async def set_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + data: typing.Dict = None): + pass + + +class MemoryStorage(BaseStorage): + """ + In-memory based states storage. + + This type of storage is not recommended for usage in bots, because you will lost all states after restarting. + """ + + def __init__(self): + self.data = {} + + def _get_chat(self, chat_id): + chat_id = str(chat_id) + if chat_id not in self.data: + self.data[chat_id] = {} + return self.data[chat_id] + + def _get_user(self, chat_id, user_id): + chat = self._get_chat(chat_id) + chat_id = str(chat_id) + user_id = str(user_id) + if user_id not in self.data[chat_id]: + self.data[chat_id][user_id] = {'state': None, 'data': {}} + return self.data[chat_id][user_id] + + async def get_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Optional[str]: + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + return user['state'] + + async def get_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Dict: + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + return user['data'] + + async def update_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + data: typing.Dict = None, **kwargs): + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + user['data'].update(data, kwargs) + + async def set_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + state: typing.AnyStr = None): + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + user['state'] = state + + async def set_data(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + data: typing.Dict = None): + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + user['data'] = data + + async def reset_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + with_data: typing.Optional[bool] = True): + await self.set_state(chat=chat, user=user, state=None) + if with_data: + await self.set_data(chat=chat, user=user, data={})