Add __eq__ and __hash__ methods to State class (#928)

* Fix the ability to copy the state, now copying the state will return the same state.

* reformat

* full implement deepcopy with memo dict, add typehints

* Update aiogram/dispatcher/fsm/state.py

Co-authored-by: Oleg A. <t0rr@mail.ru>

* update tests

Co-authored-by: Oleg A. <t0rr@mail.ru>

* remove deepcopy in tests

Co-authored-by: Oleg A. <t0rr@mail.ru>

* remove deepcopy method

Co-authored-by: Oleg A. <t0rr@mail.ru>

* update changes description

Co-authored-by: Oleg A. <t0rr@mail.ru>

* update __eq__  method

Co-authored-by: Oleg A. <t0rr@mail.ru>

* add typehints, tests

* return False for not equal objects
creating FilterObject use getfullargspec that check State equality with `type` and `object` builtins, raising Error in `__eq__` method of State break this behavior

* return NotImplemented for other types

* use `!=` instead of 'not x == y' in tests

Co-authored-by: Oleg A. <t0rr@mail.ru>
This commit is contained in:
darksidecat 2022-07-08 02:26:49 +03:00 committed by GitHub
parent 416460e013
commit bc5b26de5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 0 deletions

1
CHANGES/927.bugfix.rst Normal file
View file

@ -0,0 +1 @@
Fixed the ability to compare the state, now comparison to copy of the state will return `True`.

View file

@ -54,6 +54,16 @@ class State:
return True
return raw_state == self.state
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self.state == other.state
if isinstance(other, str):
return self.state == other
return NotImplemented
def __hash__(self) -> int:
return hash(self.state)
class StatesGroupMeta(type):
__parent__: "Optional[Type[StatesGroup]]"

View file

@ -1,7 +1,9 @@
from copy import copy
from inspect import isclass
import pytest
from aiogram.dispatcher.event.handler import FilterObject
from aiogram.dispatcher.filters import StateFilter
from aiogram.dispatcher.fsm.state import State, StatesGroup
from aiogram.types import Update
@ -50,3 +52,23 @@ class TestStateFilter:
async def test_filter(self, state, current_state, result):
f = StateFilter(state=state)
assert bool(await f(obj=Update(update_id=42), raw_state=current_state)) is result
@pytestmark
async def test_create_filter_from_state(self):
FilterObject(callback=State(state="state"))
@pytestmark
async def test_state_copy(self):
class SG(StatesGroup):
state = State()
assert SG.state == copy(SG.state)
assert SG.state == "SG:state"
assert "SG:state" == SG.state
assert State() == State()
assert SG.state != 1
states = {SG.state: "OK"}
assert states.get(copy(SG.state)) == "OK"