Update FSM

This commit is contained in:
Alex Root Junior 2017-11-15 18:46:19 +02:00
parent 68a4a7a4aa
commit 2987369925
3 changed files with 61 additions and 33 deletions

View file

@ -26,12 +26,12 @@ class RedisStorage(BaseStorage):
.. code-block:: python3 .. code-block:: python3
dp.storage.close() await dp.storage.close()
await dp.storage.wait_closed() await dp.storage.wait_closed()
""" """
def __init__(self, host, port, db=None, password=None, ssl=None, loop=None, **kwargs): def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, loop=None, **kwargs):
self._host = host self._host = host
self._port = port self._port = port
self._db = db self._db = db

View file

@ -117,51 +117,62 @@ class Dispatcher:
context.set_value(UPDATE_OBJECT, update) context.set_value(UPDATE_OBJECT, update)
if update.message: if update.message:
if has_context: if has_context:
state = self.storage.get_state(chat=update.message.chat.id, state = await self.storage.get_state(chat=update.message.chat.id,
user=update.message.from_user.id) user=update.message.from_user.id)
context.set_value(USER_STATE, await state) context.update_state(chat=update.message.chat.id,
user=update.message.from_user.id,
state=state)
return await self.message_handlers.notify(update.message) return await self.message_handlers.notify(update.message)
if update.edited_message: if update.edited_message:
if has_context: if has_context:
state = self.storage.get_state(chat=update.edited_message.chat.id, state = await self.storage.get_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id) user=update.edited_message.from_user.id)
context.set_value(USER_STATE, await state) context.update_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id,
state=state)
return await self.edited_message_handlers.notify(update.edited_message) return await self.edited_message_handlers.notify(update.edited_message)
if update.channel_post: if update.channel_post:
if has_context: if has_context:
state = self.storage.get_state(chat=update.channel_post.chat.id) state = await self.storage.get_state(chat=update.channel_post.chat.id)
context.set_value(USER_STATE, await state) context.update_state(chat=update.channel_post.chat.id,
state=state)
return await self.channel_post_handlers.notify(update.channel_post) return await self.channel_post_handlers.notify(update.channel_post)
if update.edited_channel_post: if update.edited_channel_post:
if has_context: if has_context:
state = self.storage.get_state(chat=update.edited_channel_post.chat.id) state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
context.set_value(USER_STATE, await state) context.update_state(chat=update.edited_channel_post.chat.id,
state=state)
return await self.edited_channel_post_handlers.notify(update.edited_channel_post) return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
if update.inline_query: if update.inline_query:
if has_context: if has_context:
state = self.storage.get_state(user=update.inline_query.from_user.id) state = await self.storage.get_state(user=update.inline_query.from_user.id)
context.set_value(USER_STATE, await state) context.update_state(user=update.inline_query.from_user.id,
state=state)
return await self.inline_query_handlers.notify(update.inline_query) return await self.inline_query_handlers.notify(update.inline_query)
if update.chosen_inline_result: if update.chosen_inline_result:
if has_context: if has_context:
state = self.storage.get_state(user=update.chosen_inline_result.from_user.id) state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
context.set_value(USER_STATE, await state) context.update_state(user=update.chosen_inline_result.from_user.id,
state=state)
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
if update.callback_query: if update.callback_query:
if has_context: if has_context:
state = self.storage.get_state(chat=update.callback_query.message.chat.id, state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
user=update.callback_query.from_user.id) user=update.callback_query.from_user.id)
context.set_value(USER_STATE, await state) context.update_state(user=update.callback_query.from_user.id,
state=state)
return await self.callback_query_handlers.notify(update.callback_query) return await self.callback_query_handlers.notify(update.callback_query)
if update.shipping_query: if update.shipping_query:
if has_context: if has_context:
state = self.storage.get_state(user=update.shipping_query.from_user.id) state = await self.storage.get_state(user=update.shipping_query.from_user.id)
context.set_value(USER_STATE, await state) context.update_state(user=update.shipping_query.from_user.id,
state=state)
return await self.shipping_query_handlers.notify(update.shipping_query) return await self.shipping_query_handlers.notify(update.shipping_query)
if update.pre_checkout_query: if update.pre_checkout_query:
if has_context: if has_context:
state = self.storage.get_state(user=update.pre_checkout_query.from_user.id) state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
context.set_value(USER_STATE, await state) context.update_state(user=update.pre_checkout_query.from_user.id,
state=state)
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
except Exception as e: except Exception as e:
err = await self.errors_handlers.notify(self, update, e) err = await self.errors_handlers.notify(self, update, e)
@ -824,6 +835,27 @@ class Dispatcher:
def current_state(self, *, def current_state(self, *,
chat: typing.Union[str, int, None] = None, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None) -> FSMContext: user: typing.Union[str, int, None] = None) -> FSMContext:
"""
Get current state for user in chat as context
.. code-block:: python3
with dp.current_state(chat=message.chat.id, user=message.user.id) as state:
pass
state = dp.current_state()
state.set_state('my_state')
:param chat:
:param user:
:return:
"""
if chat is None:
from .ctx import get_chat
chat = get_chat()
if user is None:
from .ctx import get_user
user = get_user()
return FSMContext(storage=self.storage, chat=chat, user=user) return FSMContext(storage=self.storage, chat=chat, user=user)
def async_task(self, func): def async_task(self, func):

View file

@ -4,6 +4,7 @@ from aiogram import Bot, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import Dispatcher from aiogram.dispatcher import Dispatcher
from aiogram.types import ParseMode from aiogram.types import ParseMode
from aiogram.utils import executor
from aiogram.utils.markdown import text, bold from aiogram.utils.markdown import text, bold
API_TOKEN = 'BOT TOKEN HERE' API_TOKEN = 'BOT TOKEN HERE'
@ -117,19 +118,14 @@ async def process_sex(message: types.Message):
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN) sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
# Finish conversation # Finish conversation
# WARNING! This method will destroy all data in storage for current user!
await state.finish() await state.finish()
async def main(): async def shutdown(dispatcher: Dispatcher):
# Skip old updates await dispatcher.storage.close()
count = await dp.skip_updates() await dispatcher.storage.wait_closed()
print(f"Skipped {count} updates.")
await dp.start_pooling()
if __name__ == '__main__': if __name__ == '__main__':
try: executor.start_pooling(dp, loop=loop, skip_updates=True, on_shutdown=shutdown)
loop.run_until_complete(main())
except KeyboardInterrupt:
loop.stop()