diff --git a/aiogram/dispatcher/fsm/state.py b/aiogram/dispatcher/fsm/state.py index d4ea1974..ced9779a 100644 --- a/aiogram/dispatcher/fsm/state.py +++ b/aiogram/dispatcher/fsm/state.py @@ -111,8 +111,8 @@ class StatesGroupMeta(type): return item in cls.__all_states_names__ if isinstance(item, State): return item in cls.__all_states__ - # if isinstance(item, StatesGroup): - # return item in cls.__all_childs__ + if isinstance(item, StatesGroupMeta): + return item in cls.__all_childs__ return False def __str__(self) -> str: @@ -126,8 +126,11 @@ class StatesGroup(metaclass=StatesGroupMeta): return cls return cls.__parent__.get_root() - # def __call__(cls, event: TelegramObject, raw_state: Optional[str] = None) -> bool: - # return raw_state in cls.__all_states_names__ + def __call__(cls, event: TelegramObject, raw_state: Optional[str] = None) -> bool: + return raw_state in type(cls).__all_states_names__ + + def __str__(self) -> str: + return f"StatesGroup {type(self).__full_group_name__}" default_state = State() diff --git a/tests/test_dispatcher/test_fsm/test_state.py b/tests/test_dispatcher/test_fsm/test_state.py index 07037e86..04c8e448 100644 --- a/tests/test_dispatcher/test_fsm/test_state.py +++ b/tests/test_dispatcher/test_fsm/test_state.py @@ -139,8 +139,7 @@ class TestStatesGroup: assert MyGroup.state1 not in MyGroup.MyNestedGroup assert MyGroup.state1 in MyGroup - # Not working as well - # assert MyGroup.MyNestedGroup in MyGroup + assert MyGroup.MyNestedGroup in MyGroup assert "MyGroup.MyNestedGroup:state1" in MyGroup assert "MyGroup.MyNestedGroup:state1" in MyGroup.MyNestedGroup @@ -150,3 +149,36 @@ class TestStatesGroup: assert 42 not in MyGroup assert MyGroup.MyNestedGroup.get_root() is MyGroup + + def test_empty_filter(self): + class MyGroup(StatesGroup): + pass + + assert str(MyGroup()) == "StatesGroup MyGroup" + + def test_with_state_filter(self): + class MyGroup(StatesGroup): + state1 = State() + state2 = State() + + assert MyGroup()(None, "MyGroup:state1") + assert MyGroup()(None, "MyGroup:state2") + assert not MyGroup()(None, "MyGroup:state3") + + assert str(MyGroup()) == "StatesGroup MyGroup" + + def test_nested_group_filter(self): + class MyGroup(StatesGroup): + state1 = State() + + class MyNestedGroup(StatesGroup): + state1 = State() + + assert MyGroup()(None, "MyGroup:state1") + assert MyGroup()(None, "MyGroup.MyNestedGroup:state1") + assert not MyGroup()(None, "MyGroup:state2") + assert MyGroup.MyNestedGroup()(None, "MyGroup.MyNestedGroup:state1") + assert not MyGroup.MyNestedGroup()(None, "MyGroup:state1") + + assert str(MyGroup()) == "StatesGroup MyGroup" + assert str(MyGroup.MyNestedGroup()) == "StatesGroup MyGroup.MyNestedGroup"