Add __eq__ and __hash__ methods to State class (#928)

* Fix the ability to copy the state, now copying the state will return the same state.

* reformat

* full implement deepcopy with memo dict, add typehints

* Update aiogram/dispatcher/fsm/state.py

Co-authored-by: Oleg A. <t0rr@mail.ru>

* update tests

Co-authored-by: Oleg A. <t0rr@mail.ru>

* remove deepcopy in tests

Co-authored-by: Oleg A. <t0rr@mail.ru>

* remove deepcopy method

Co-authored-by: Oleg A. <t0rr@mail.ru>

* update changes description

Co-authored-by: Oleg A. <t0rr@mail.ru>

* update __eq__  method

Co-authored-by: Oleg A. <t0rr@mail.ru>

* add typehints, tests

* return False for not equal objects
creating FilterObject use getfullargspec that check State equality with `type` and `object` builtins, raising Error in `__eq__` method of State break this behavior

* return NotImplemented for other types

* use `!=` instead of 'not x == y' in tests

Co-authored-by: Oleg A. <t0rr@mail.ru>
This commit is contained in:
darksidecat 2022-07-08 02:26:49 +03:00 committed by GitHub
parent 416460e013
commit bc5b26de5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 0 deletions

1
CHANGES/927.bugfix.rst Normal file
View file

@ -0,0 +1 @@
Fixed the ability to compare the state, now comparison to copy of the state will return `True`.

View file

@ -54,6 +54,16 @@ class State:
return True return True
return raw_state == self.state return raw_state == self.state
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self.state == other.state
if isinstance(other, str):
return self.state == other
return NotImplemented
def __hash__(self) -> int:
return hash(self.state)
class StatesGroupMeta(type): class StatesGroupMeta(type):
__parent__: "Optional[Type[StatesGroup]]" __parent__: "Optional[Type[StatesGroup]]"

View file

@ -1,7 +1,9 @@
from copy import copy
from inspect import isclass from inspect import isclass
import pytest import pytest
from aiogram.dispatcher.event.handler import FilterObject
from aiogram.dispatcher.filters import StateFilter from aiogram.dispatcher.filters import StateFilter
from aiogram.dispatcher.fsm.state import State, StatesGroup from aiogram.dispatcher.fsm.state import State, StatesGroup
from aiogram.types import Update from aiogram.types import Update
@ -50,3 +52,23 @@ class TestStateFilter:
async def test_filter(self, state, current_state, result): async def test_filter(self, state, current_state, result):
f = StateFilter(state=state) f = StateFilter(state=state)
assert bool(await f(obj=Update(update_id=42), raw_state=current_state)) is result assert bool(await f(obj=Update(update_id=42), raw_state=current_state)) is result
@pytestmark
async def test_create_filter_from_state(self):
FilterObject(callback=State(state="state"))
@pytestmark
async def test_state_copy(self):
class SG(StatesGroup):
state = State()
assert SG.state == copy(SG.state)
assert SG.state == "SG:state"
assert "SG:state" == SG.state
assert State() == State()
assert SG.state != 1
states = {SG.state: "OK"}
assert states.get(copy(SG.state)) == "OK"