From e17e3bc71c4e95c20655e35f8795a90d651ad65e Mon Sep 17 00:00:00 2001 From: Oleg A Date: Mon, 20 Nov 2023 23:39:09 +0300 Subject: [PATCH] Fix CallbackData without default Optional (#1370) * fix: CallbackData set optional as None * docs: add fix changelog * Add support for nullable fields in callback data This update extends the callback data handling by adding support for nullable fields. The code now uses the Python typing structures `Optional` and `Union` to parse such fields correctly. A helper function `_check_field_is_nullable` has been added to assist in efficiently checking if a given field is nullable. * Add support for nullable fields in callback data This update extends the callback data handling by adding support for nullable fields. The code now uses the Python typing structures `Optional` and `Union` to parse such fields correctly. A helper function `_check_field_is_nullable` has been added to assist in efficiently checking if a given field is nullable. --------- Co-authored-by: JRoot Junior --- CHANGES/1368.bugfix.rst | 1 + aiogram/__init__.py | 3 ++- aiogram/filters/callback_data.py | 20 +++++++++++++++++++- tests/test_filters/test_callback_data.py | 18 +++++++++++++++++- tests/test_fsm/test_state.py | 3 ++- 5 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 CHANGES/1368.bugfix.rst diff --git a/CHANGES/1368.bugfix.rst b/CHANGES/1368.bugfix.rst new file mode 100644 index 00000000..f541e016 --- /dev/null +++ b/CHANGES/1368.bugfix.rst @@ -0,0 +1 @@ +Fixed a situation where a :code:`CallbackData` could not be parsed without a default value. diff --git a/aiogram/__init__.py b/aiogram/__init__.py index b945226f..9aedf85b 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -14,9 +14,10 @@ from .utils.text_decorations import html_decoration as html from .utils.text_decorations import markdown_decoration as md with suppress(ImportError): - import uvloop as _uvloop import asyncio + import uvloop as _uvloop + asyncio.set_event_loop_policy(_uvloop.EventLoopPolicy()) diff --git a/aiogram/filters/callback_data.py b/aiogram/filters/callback_data.py index 7a09dedb..7c0dadf8 100644 --- a/aiogram/filters/callback_data.py +++ b/aiogram/filters/callback_data.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from decimal import Decimal from enum import Enum from fractions import Fraction @@ -18,6 +19,7 @@ from uuid import UUID from magic_filter import MagicFilter from pydantic import BaseModel +from pydantic.fields import FieldInfo from aiogram.filters.base import Filter from aiogram.types import CallbackQuery @@ -121,7 +123,7 @@ class CallbackData(BaseModel): payload = {} for k, v in zip(names, parts): # type: str, Optional[str] if field := cls.model_fields.get(k): - if v == "" and not field.is_required(): + if v == "" and _check_field_is_nullable(field): v = None payload[k] = v return cls(**payload) @@ -180,3 +182,19 @@ class CallbackQueryFilter(Filter): if self.rule is None or self.rule.resolve(callback_data): return {"callback_data": callback_data} return False + + +def _check_field_is_nullable(field: FieldInfo) -> bool: + """ + Check if the given field is nullable. + + :param field: The FieldInfo object representing the field to check. + :return: True if the field is nullable, False otherwise. + + """ + if not field.is_required(): + return True + + return typing.get_origin(field.annotation) is typing.Union and type(None) in typing.get_args( + field.annotation + ) diff --git a/tests/test_filters/test_callback_data.py b/tests/test_filters/test_callback_data.py index e8721a41..4314aa34 100644 --- a/tests/test_filters/test_callback_data.py +++ b/tests/test_filters/test_callback_data.py @@ -1,7 +1,7 @@ from decimal import Decimal from enum import Enum, auto from fractions import Fraction -from typing import Optional +from typing import Optional, Union from uuid import UUID import pytest @@ -147,6 +147,22 @@ class TestCallbackData: assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42) assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42) + @pytest.mark.parametrize( + "hint", + [ + Union[int, None], + Optional[int], + ], + ) + def test_unpack_optional_wo_default(self, hint): + """Test CallbackData without default optional.""" + + class TgData(CallbackData, prefix="tg"): + chat_id: int + thread_id: hint + + assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None) + def test_build_filter(self): filter_object = MyCallback.filter(F.foo == "test") assert isinstance(filter_object.rule, MagicFilter) diff --git a/tests/test_fsm/test_state.py b/tests/test_fsm/test_state.py index dd240946..e1986079 100644 --- a/tests/test_fsm/test_state.py +++ b/tests/test_fsm/test_state.py @@ -1,6 +1,7 @@ -import pytest import sys +import pytest + from aiogram.fsm.state import State, StatesGroup, any_state PY312_OR_GREATER = sys.version_info >= (3, 12)