mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 01:54:53 +00:00
Provide list of states. (StatesListFilter)
This commit is contained in:
parent
f70d45c53b
commit
23dcf46a43
1 changed files with 17 additions and 3 deletions
|
|
@ -98,18 +98,29 @@ class StateFilter(AsyncFilter):
|
|||
self.dispatcher = dispatcher
|
||||
self.state = state
|
||||
|
||||
def get_target(self, obj):
|
||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
|
||||
async def check(self, obj):
|
||||
if self.state == '*':
|
||||
return True
|
||||
|
||||
chat = getattr(getattr(obj, 'chat', None), 'id', None)
|
||||
user = getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
|
||||
return False
|
||||
|
||||
|
||||
class StatesListFilter(StateFilter):
|
||||
async def check(self, obj):
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
return False
|
||||
|
||||
|
||||
def generate_default_filters(dispatcher, *args, **kwargs):
|
||||
filters_set = []
|
||||
|
||||
|
|
@ -128,6 +139,9 @@ def generate_default_filters(dispatcher, *args, **kwargs):
|
|||
elif name == 'func':
|
||||
filters_set.append(filter_)
|
||||
elif name == 'state':
|
||||
if isinstance(filter_, (list, set, tuple)):
|
||||
filters_set.append(filters_set.append(StatesListFilter(dispatcher, filter_)))
|
||||
else:
|
||||
filters_set.append(StateFilter(dispatcher, filter_))
|
||||
elif isinstance(filter_, Filter):
|
||||
filters_set.append(filter_)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue