diff --git a/CHANGES/1368.bugfix.rst b/CHANGES/1368.bugfix.rst new file mode 100644 index 00000000..f541e016 --- /dev/null +++ b/CHANGES/1368.bugfix.rst @@ -0,0 +1 @@ +Fixed a situation where a :code:`CallbackData` could not be parsed without a default value. diff --git a/aiogram/__init__.py b/aiogram/__init__.py index b945226f..9aedf85b 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -14,9 +14,10 @@ from .utils.text_decorations import html_decoration as html from .utils.text_decorations import markdown_decoration as md with suppress(ImportError): - import uvloop as _uvloop import asyncio + import uvloop as _uvloop + asyncio.set_event_loop_policy(_uvloop.EventLoopPolicy()) diff --git a/aiogram/filters/callback_data.py b/aiogram/filters/callback_data.py index 7a09dedb..7c0dadf8 100644 --- a/aiogram/filters/callback_data.py +++ b/aiogram/filters/callback_data.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from decimal import Decimal from enum import Enum from fractions import Fraction @@ -18,6 +19,7 @@ from uuid import UUID from magic_filter import MagicFilter from pydantic import BaseModel +from pydantic.fields import FieldInfo from aiogram.filters.base import Filter from aiogram.types import CallbackQuery @@ -121,7 +123,7 @@ class CallbackData(BaseModel): payload = {} for k, v in zip(names, parts): # type: str, Optional[str] if field := cls.model_fields.get(k): - if v == "" and not field.is_required(): + if v == "" and _check_field_is_nullable(field): v = None payload[k] = v return cls(**payload) @@ -180,3 +182,19 @@ class CallbackQueryFilter(Filter): if self.rule is None or self.rule.resolve(callback_data): return {"callback_data": callback_data} return False + + +def _check_field_is_nullable(field: FieldInfo) -> bool: + """ + Check if the given field is nullable. + + :param field: The FieldInfo object representing the field to check. + :return: True if the field is nullable, False otherwise. + + """ + if not field.is_required(): + return True + + return typing.get_origin(field.annotation) is typing.Union and type(None) in typing.get_args( + field.annotation + ) diff --git a/tests/test_filters/test_callback_data.py b/tests/test_filters/test_callback_data.py index e8721a41..4314aa34 100644 --- a/tests/test_filters/test_callback_data.py +++ b/tests/test_filters/test_callback_data.py @@ -1,7 +1,7 @@ from decimal import Decimal from enum import Enum, auto from fractions import Fraction -from typing import Optional +from typing import Optional, Union from uuid import UUID import pytest @@ -147,6 +147,22 @@ class TestCallbackData: assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42) assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42) + @pytest.mark.parametrize( + "hint", + [ + Union[int, None], + Optional[int], + ], + ) + def test_unpack_optional_wo_default(self, hint): + """Test CallbackData without default optional.""" + + class TgData(CallbackData, prefix="tg"): + chat_id: int + thread_id: hint + + assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None) + def test_build_filter(self): filter_object = MyCallback.filter(F.foo == "test") assert isinstance(filter_object.rule, MagicFilter) diff --git a/tests/test_fsm/test_state.py b/tests/test_fsm/test_state.py index dd240946..e1986079 100644 --- a/tests/test_fsm/test_state.py +++ b/tests/test_fsm/test_state.py @@ -1,6 +1,7 @@ -import pytest import sys +import pytest + from aiogram.fsm.state import State, StatesGroup, any_state PY312_OR_GREATER = sys.version_info >= (3, 12)