From 6e39f9fada99906567d889c9618f65bc2f089c44 Mon Sep 17 00:00:00 2001 From: Daniil Kovalenko <40635760+WhiteMemory99@users.noreply.github.com> Date: Wed, 29 Dec 2021 08:39:28 +0700 Subject: [PATCH] Fix unexpected behavior of sequences in StateFilter (#791) * Fix sequence check behavior in StateFilter * Add sequence cases to StateFilter tests * Add the changelog --- CHANGES/791.bugfix | 1 + aiogram/dispatcher/filters/state.py | 9 +++++---- tests/test_dispatcher/test_filters/test_state.py | 3 +++ 3 files changed, 9 insertions(+), 4 deletions(-) create mode 100644 CHANGES/791.bugfix diff --git a/CHANGES/791.bugfix b/CHANGES/791.bugfix new file mode 100644 index 00000000..5a219d4a --- /dev/null +++ b/CHANGES/791.bugfix @@ -0,0 +1 @@ +Fixed unexpected behavior of sequences in the StateFilter. diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py index 978f65e4..316edcf5 100644 --- a/aiogram/dispatcher/filters/state.py +++ b/aiogram/dispatcher/filters/state.py @@ -37,11 +37,12 @@ class StateFilter(BaseFilter): allowed_states = cast(Sequence[StateType], self.state) for allowed_state in allowed_states: if isinstance(allowed_state, str) or allowed_state is None: - if allowed_state == "*": + if allowed_state == "*" or raw_state == allowed_state: return True - return raw_state == allowed_state elif isinstance(allowed_state, (State, StatesGroup)): - return allowed_state(event=obj, raw_state=raw_state) + if allowed_state(event=obj, raw_state=raw_state): + return True elif isclass(allowed_state) and issubclass(allowed_state, StatesGroup): - return allowed_state()(event=obj, raw_state=raw_state) + if allowed_state()(event=obj, raw_state=raw_state): + return True return False diff --git a/tests/test_dispatcher/test_filters/test_state.py b/tests/test_dispatcher/test_filters/test_state.py index d551f748..2d8acda0 100644 --- a/tests/test_dispatcher/test_filters/test_state.py +++ b/tests/test_dispatcher/test_filters/test_state.py @@ -41,6 +41,9 @@ class TestStateFilter: [[None], None, True], [None, "state", False], [[], "state", False], + [[State("state"), "state"], "state", True], + [[MyGroup(), State("state")], "@:state", True], + [[MyGroup, State("state")], "state", False], ], ) @pytestmark