From 2987369925ddc70e0074b1ffb40a691504f982a1 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Wed, 15 Nov 2017 18:46:19 +0200 Subject: [PATCH] Update FSM --- aiogram/contrib/fsm_storage/redis.py | 4 +- aiogram/dispatcher/__init__.py | 74 +++++++++++++++++------- examples/finite_state_machine_example.py | 16 ++--- 3 files changed, 61 insertions(+), 33 deletions(-) diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 9f54df5a..0104a65f 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -26,12 +26,12 @@ class RedisStorage(BaseStorage): .. code-block:: python3 - dp.storage.close() + await dp.storage.close() await dp.storage.wait_closed() """ - def __init__(self, host, port, db=None, password=None, ssl=None, loop=None, **kwargs): + def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, loop=None, **kwargs): self._host = host self._port = port self._db = db diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 0a07974b..f4e7f581 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -117,51 +117,62 @@ class Dispatcher: context.set_value(UPDATE_OBJECT, update) if update.message: if has_context: - state = self.storage.get_state(chat=update.message.chat.id, - user=update.message.from_user.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(chat=update.message.chat.id, + user=update.message.from_user.id) + context.update_state(chat=update.message.chat.id, + user=update.message.from_user.id, + state=state) return await self.message_handlers.notify(update.message) if update.edited_message: if has_context: - state = self.storage.get_state(chat=update.edited_message.chat.id, - user=update.edited_message.from_user.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(chat=update.edited_message.chat.id, + user=update.edited_message.from_user.id) + context.update_state(chat=update.edited_message.chat.id, + user=update.edited_message.from_user.id, + state=state) return await self.edited_message_handlers.notify(update.edited_message) if update.channel_post: if has_context: - state = self.storage.get_state(chat=update.channel_post.chat.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(chat=update.channel_post.chat.id) + context.update_state(chat=update.channel_post.chat.id, + state=state) return await self.channel_post_handlers.notify(update.channel_post) if update.edited_channel_post: if has_context: - state = self.storage.get_state(chat=update.edited_channel_post.chat.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(chat=update.edited_channel_post.chat.id) + context.update_state(chat=update.edited_channel_post.chat.id, + state=state) return await self.edited_channel_post_handlers.notify(update.edited_channel_post) if update.inline_query: if has_context: - state = self.storage.get_state(user=update.inline_query.from_user.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(user=update.inline_query.from_user.id) + context.update_state(user=update.inline_query.from_user.id, + state=state) return await self.inline_query_handlers.notify(update.inline_query) if update.chosen_inline_result: if has_context: - state = self.storage.get_state(user=update.chosen_inline_result.from_user.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id) + context.update_state(user=update.chosen_inline_result.from_user.id, + state=state) return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) if update.callback_query: if has_context: - state = self.storage.get_state(chat=update.callback_query.message.chat.id, - user=update.callback_query.from_user.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(chat=update.callback_query.message.chat.id, + user=update.callback_query.from_user.id) + context.update_state(user=update.callback_query.from_user.id, + state=state) return await self.callback_query_handlers.notify(update.callback_query) if update.shipping_query: if has_context: - state = self.storage.get_state(user=update.shipping_query.from_user.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(user=update.shipping_query.from_user.id) + context.update_state(user=update.shipping_query.from_user.id, + state=state) return await self.shipping_query_handlers.notify(update.shipping_query) if update.pre_checkout_query: if has_context: - state = self.storage.get_state(user=update.pre_checkout_query.from_user.id) - context.set_value(USER_STATE, await state) + state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id) + context.update_state(user=update.pre_checkout_query.from_user.id, + state=state) return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) except Exception as e: err = await self.errors_handlers.notify(self, update, e) @@ -824,6 +835,27 @@ class Dispatcher: def current_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None) -> FSMContext: + """ + Get current state for user in chat as context + + .. code-block:: python3 + with dp.current_state(chat=message.chat.id, user=message.user.id) as state: + pass + + state = dp.current_state() + state.set_state('my_state') + + :param chat: + :param user: + :return: + """ + if chat is None: + from .ctx import get_chat + chat = get_chat() + if user is None: + from .ctx import get_user + user = get_user() + return FSMContext(storage=self.storage, chat=chat, user=user) def async_task(self, func): diff --git a/examples/finite_state_machine_example.py b/examples/finite_state_machine_example.py index 9ead40e3..102b1ef3 100644 --- a/examples/finite_state_machine_example.py +++ b/examples/finite_state_machine_example.py @@ -4,6 +4,7 @@ from aiogram import Bot, types from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.dispatcher import Dispatcher from aiogram.types import ParseMode +from aiogram.utils import executor from aiogram.utils.markdown import text, bold API_TOKEN = 'BOT TOKEN HERE' @@ -117,19 +118,14 @@ async def process_sex(message: types.Message): sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN) # Finish conversation + # WARNING! This method will destroy all data in storage for current user! await state.finish() -async def main(): - # Skip old updates - count = await dp.skip_updates() - print(f"Skipped {count} updates.") - - await dp.start_pooling() +async def shutdown(dispatcher: Dispatcher): + await dispatcher.storage.close() + await dispatcher.storage.wait_closed() if __name__ == '__main__': - try: - loop.run_until_complete(main()) - except KeyboardInterrupt: - loop.stop() + executor.start_pooling(dp, loop=loop, skip_updates=True, on_shutdown=shutdown)