Fix state filter

This commit is contained in:
Alex Root Junior 2018-06-25 03:44:50 +03:00
parent b4ecc421e4
commit 95ba0ae571
2 changed files with 7 additions and 7 deletions

View file

@ -921,9 +921,9 @@ class Dispatcher:
:return: :return:
""" """
if chat is None: if chat is None:
chat = types.Chat.current() chat = types.Chat.current().id
if user is None: if user is None:
user = types.User.current() user = types.User.current().id
return FSMContext(storage=self.storage, chat=chat, user=user) return FSMContext(storage=self.storage, chat=chat, user=user)

View file

@ -121,7 +121,7 @@ class StateFilter(BaseFilter):
def __init__(self, dispatcher, state): def __init__(self, dispatcher, state):
super().__init__(dispatcher) super().__init__(dispatcher)
if isinstance(state, str): if isinstance(state, str) or state is None:
state = (state,) state = (state,)
self.state = state self.state = state
@ -137,19 +137,19 @@ class StateFilter(BaseFilter):
async def check(self, obj): async def check(self, obj):
from ..dispatcher import Dispatcher from ..dispatcher import Dispatcher
if self.state == '*': if '*' in self.state:
return {'state': Dispatcher.current().current_state()} return {'state': Dispatcher.current().current_state()}
try: try:
if self.state == self.ctx_state.get(): if self.ctx_state.get() in self.state:
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state} return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
except LookupError: except LookupError:
chat, user = self.get_target(obj) chat, user = self.get_target(obj)
if chat or user: if chat or user:
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state state = await self.dispatcher.storage.get_state(chat=chat, user=user)
self.ctx_state.set(state) self.ctx_state.set(state)
if state == self.state: if state in self.state:
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state} return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
return False return False