mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
add stategroup filter (#659)
This commit is contained in:
parent
71eb5fc44e
commit
04bbc8211c
2 changed files with 41 additions and 6 deletions
|
|
@ -111,8 +111,8 @@ class StatesGroupMeta(type):
|
||||||
return item in cls.__all_states_names__
|
return item in cls.__all_states_names__
|
||||||
if isinstance(item, State):
|
if isinstance(item, State):
|
||||||
return item in cls.__all_states__
|
return item in cls.__all_states__
|
||||||
# if isinstance(item, StatesGroup):
|
if isinstance(item, StatesGroupMeta):
|
||||||
# return item in cls.__all_childs__
|
return item in cls.__all_childs__
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
@ -126,8 +126,11 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
||||||
return cls
|
return cls
|
||||||
return cls.__parent__.get_root()
|
return cls.__parent__.get_root()
|
||||||
|
|
||||||
# def __call__(cls, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
|
def __call__(cls, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
|
||||||
# return raw_state in cls.__all_states_names__
|
return raw_state in type(cls).__all_states_names__
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"StatesGroup {type(self).__full_group_name__}"
|
||||||
|
|
||||||
|
|
||||||
default_state = State()
|
default_state = State()
|
||||||
|
|
|
||||||
|
|
@ -139,8 +139,7 @@ class TestStatesGroup:
|
||||||
assert MyGroup.state1 not in MyGroup.MyNestedGroup
|
assert MyGroup.state1 not in MyGroup.MyNestedGroup
|
||||||
assert MyGroup.state1 in MyGroup
|
assert MyGroup.state1 in MyGroup
|
||||||
|
|
||||||
# Not working as well
|
assert MyGroup.MyNestedGroup in MyGroup
|
||||||
# assert MyGroup.MyNestedGroup in MyGroup
|
|
||||||
|
|
||||||
assert "MyGroup.MyNestedGroup:state1" in MyGroup
|
assert "MyGroup.MyNestedGroup:state1" in MyGroup
|
||||||
assert "MyGroup.MyNestedGroup:state1" in MyGroup.MyNestedGroup
|
assert "MyGroup.MyNestedGroup:state1" in MyGroup.MyNestedGroup
|
||||||
|
|
@ -150,3 +149,36 @@ class TestStatesGroup:
|
||||||
assert 42 not in MyGroup
|
assert 42 not in MyGroup
|
||||||
|
|
||||||
assert MyGroup.MyNestedGroup.get_root() is MyGroup
|
assert MyGroup.MyNestedGroup.get_root() is MyGroup
|
||||||
|
|
||||||
|
def test_empty_filter(self):
|
||||||
|
class MyGroup(StatesGroup):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert str(MyGroup()) == "StatesGroup MyGroup"
|
||||||
|
|
||||||
|
def test_with_state_filter(self):
|
||||||
|
class MyGroup(StatesGroup):
|
||||||
|
state1 = State()
|
||||||
|
state2 = State()
|
||||||
|
|
||||||
|
assert MyGroup()(None, "MyGroup:state1")
|
||||||
|
assert MyGroup()(None, "MyGroup:state2")
|
||||||
|
assert not MyGroup()(None, "MyGroup:state3")
|
||||||
|
|
||||||
|
assert str(MyGroup()) == "StatesGroup MyGroup"
|
||||||
|
|
||||||
|
def test_nested_group_filter(self):
|
||||||
|
class MyGroup(StatesGroup):
|
||||||
|
state1 = State()
|
||||||
|
|
||||||
|
class MyNestedGroup(StatesGroup):
|
||||||
|
state1 = State()
|
||||||
|
|
||||||
|
assert MyGroup()(None, "MyGroup:state1")
|
||||||
|
assert MyGroup()(None, "MyGroup.MyNestedGroup:state1")
|
||||||
|
assert not MyGroup()(None, "MyGroup:state2")
|
||||||
|
assert MyGroup.MyNestedGroup()(None, "MyGroup.MyNestedGroup:state1")
|
||||||
|
assert not MyGroup.MyNestedGroup()(None, "MyGroup:state1")
|
||||||
|
|
||||||
|
assert str(MyGroup()) == "StatesGroup MyGroup"
|
||||||
|
assert str(MyGroup.MyNestedGroup()) == "StatesGroup MyGroup.MyNestedGroup"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue