diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index bbe98be1..8596bc2d 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -98,18 +98,29 @@ class StateFilter(AsyncFilter): self.dispatcher = dispatcher self.state = state + def get_target(self, obj): + return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) + async def check(self, obj): if self.state == '*': return True - chat = getattr(getattr(obj, 'chat', None), 'id', None) - user = getattr(getattr(obj, 'from_user', None), 'id', None) + chat, user = self.get_target(obj) if chat or user: return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state return False +class StatesListFilter(StateFilter): + async def check(self, obj): + chat, user = self.get_target(obj) + + if chat or user: + return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state + return False + + def generate_default_filters(dispatcher, *args, **kwargs): filters_set = [] @@ -128,7 +139,10 @@ def generate_default_filters(dispatcher, *args, **kwargs): elif name == 'func': filters_set.append(filter_) elif name == 'state': - filters_set.append(StateFilter(dispatcher, filter_)) + if isinstance(filter_, (list, set, tuple)): + filters_set.append(filters_set.append(StatesListFilter(dispatcher, filter_))) + else: + filters_set.append(StateFilter(dispatcher, filter_)) elif isinstance(filter_, Filter): filters_set.append(filter_)