From 95ba0ae571b8286dd8814880af402d6fd22ccaf4 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Mon, 25 Jun 2018 03:44:50 +0300 Subject: [PATCH] Fix state filter --- aiogram/dispatcher/dispatcher.py | 4 ++-- aiogram/dispatcher/filters/builtin.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index e926e772..1450451b 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -921,9 +921,9 @@ class Dispatcher: :return: """ if chat is None: - chat = types.Chat.current() + chat = types.Chat.current().id if user is None: - user = types.User.current() + user = types.User.current().id return FSMContext(storage=self.storage, chat=chat, user=user) diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 25aa2a05..495fbf5c 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -121,7 +121,7 @@ class StateFilter(BaseFilter): def __init__(self, dispatcher, state): super().__init__(dispatcher) - if isinstance(state, str): + if isinstance(state, str) or state is None: state = (state,) self.state = state @@ -137,19 +137,19 @@ class StateFilter(BaseFilter): async def check(self, obj): from ..dispatcher import Dispatcher - if self.state == '*': + if '*' in self.state: return {'state': Dispatcher.current().current_state()} try: - if self.state == self.ctx_state.get(): + if self.ctx_state.get() in self.state: return {'state': Dispatcher.current().current_state(), 'raw_state': self.state} except LookupError: chat, user = self.get_target(obj) if chat or user: - state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state + state = await self.dispatcher.storage.get_state(chat=chat, user=user) self.ctx_state.set(state) - if state == self.state: + if state in self.state: return {'state': Dispatcher.current().current_state(), 'raw_state': self.state} return False