Provide list of states. (StatesListFilter)

This commit is contained in:
Alex Root Junior 2017-08-01 08:09:15 +03:00
parent f70d45c53b
commit 23dcf46a43

View file

@ -98,18 +98,29 @@ class StateFilter(AsyncFilter):
self.dispatcher = dispatcher
self.state = state
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
if self.state == '*':
return True
chat = getattr(getattr(obj, 'chat', None), 'id', None)
user = getattr(getattr(obj, 'from_user', None), 'id', None)
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
return False
class StatesListFilter(StateFilter):
async def check(self, obj):
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
def generate_default_filters(dispatcher, *args, **kwargs):
filters_set = []
@ -128,6 +139,9 @@ def generate_default_filters(dispatcher, *args, **kwargs):
elif name == 'func':
filters_set.append(filter_)
elif name == 'state':
if isinstance(filter_, (list, set, tuple)):
filters_set.append(filters_set.append(StatesListFilter(dispatcher, filter_)))
else:
filters_set.append(StateFilter(dispatcher, filter_))
elif isinstance(filter_, Filter):
filters_set.append(filter_)