Provide list of states. (StatesListFilter)

This commit is contained in:
Alex Root Junior 2017-08-01 08:09:15 +03:00
parent f70d45c53b
commit 23dcf46a43

View file

@ -98,18 +98,29 @@ class StateFilter(AsyncFilter):
self.dispatcher = dispatcher self.dispatcher = dispatcher
self.state = state self.state = state
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj): async def check(self, obj):
if self.state == '*': if self.state == '*':
return True return True
chat = getattr(getattr(obj, 'chat', None), 'id', None) chat, user = self.get_target(obj)
user = getattr(getattr(obj, 'from_user', None), 'id', None)
if chat or user: if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
return False return False
class StatesListFilter(StateFilter):
async def check(self, obj):
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
def generate_default_filters(dispatcher, *args, **kwargs): def generate_default_filters(dispatcher, *args, **kwargs):
filters_set = [] filters_set = []
@ -128,7 +139,10 @@ def generate_default_filters(dispatcher, *args, **kwargs):
elif name == 'func': elif name == 'func':
filters_set.append(filter_) filters_set.append(filter_)
elif name == 'state': elif name == 'state':
filters_set.append(StateFilter(dispatcher, filter_)) if isinstance(filter_, (list, set, tuple)):
filters_set.append(filters_set.append(StatesListFilter(dispatcher, filter_)))
else:
filters_set.append(StateFilter(dispatcher, filter_))
elif isinstance(filter_, Filter): elif isinstance(filter_, Filter):
filters_set.append(filter_) filters_set.append(filter_)