Fix CallbackData without default Optional (#1370)

* fix: CallbackData set optional as None

* docs: add fix changelog

* Add support for nullable fields in callback data

This update extends the callback data handling by adding support for nullable fields. The code now uses the Python typing structures `Optional` and `Union` to parse such fields correctly. A helper function `_check_field_is_nullable` has been added to assist in efficiently checking if a given field is nullable.

* Add support for nullable fields in callback data

This update extends the callback data handling by adding support for nullable fields. The code now uses the Python typing structures `Optional` and `Union` to parse such fields correctly. A helper function `_check_field_is_nullable` has been added to assist in efficiently checking if a given field is nullable.

---------

Co-authored-by: JRoot Junior <jroot.junior@gmail.com>
This commit is contained in:
Oleg A 2023-11-20 23:39:09 +03:00 committed by GitHub
parent ebade3d51f
commit e17e3bc71c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 4 deletions

1
CHANGES/1368.bugfix.rst Normal file
View file

@ -0,0 +1 @@
Fixed a situation where a :code:`CallbackData` could not be parsed without a default value.

View file

@ -14,9 +14,10 @@ from .utils.text_decorations import html_decoration as html
from .utils.text_decorations import markdown_decoration as md from .utils.text_decorations import markdown_decoration as md
with suppress(ImportError): with suppress(ImportError):
import uvloop as _uvloop
import asyncio import asyncio
import uvloop as _uvloop
asyncio.set_event_loop_policy(_uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(_uvloop.EventLoopPolicy())

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
@ -18,6 +19,7 @@ from uuid import UUID
from magic_filter import MagicFilter from magic_filter import MagicFilter
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo
from aiogram.filters.base import Filter from aiogram.filters.base import Filter
from aiogram.types import CallbackQuery from aiogram.types import CallbackQuery
@ -121,7 +123,7 @@ class CallbackData(BaseModel):
payload = {} payload = {}
for k, v in zip(names, parts): # type: str, Optional[str] for k, v in zip(names, parts): # type: str, Optional[str]
if field := cls.model_fields.get(k): if field := cls.model_fields.get(k):
if v == "" and not field.is_required(): if v == "" and _check_field_is_nullable(field):
v = None v = None
payload[k] = v payload[k] = v
return cls(**payload) return cls(**payload)
@ -180,3 +182,19 @@ class CallbackQueryFilter(Filter):
if self.rule is None or self.rule.resolve(callback_data): if self.rule is None or self.rule.resolve(callback_data):
return {"callback_data": callback_data} return {"callback_data": callback_data}
return False return False
def _check_field_is_nullable(field: FieldInfo) -> bool:
"""
Check if the given field is nullable.
:param field: The FieldInfo object representing the field to check.
:return: True if the field is nullable, False otherwise.
"""
if not field.is_required():
return True
return typing.get_origin(field.annotation) is typing.Union and type(None) in typing.get_args(
field.annotation
)

View file

@ -1,7 +1,7 @@
from decimal import Decimal from decimal import Decimal
from enum import Enum, auto from enum import Enum, auto
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional, Union
from uuid import UUID from uuid import UUID
import pytest import pytest
@ -147,6 +147,22 @@ class TestCallbackData:
assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42) assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42)
assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42) assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42)
@pytest.mark.parametrize(
"hint",
[
Union[int, None],
Optional[int],
],
)
def test_unpack_optional_wo_default(self, hint):
"""Test CallbackData without default optional."""
class TgData(CallbackData, prefix="tg"):
chat_id: int
thread_id: hint
assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)
def test_build_filter(self): def test_build_filter(self):
filter_object = MyCallback.filter(F.foo == "test") filter_object = MyCallback.filter(F.foo == "test")
assert isinstance(filter_object.rule, MagicFilter) assert isinstance(filter_object.rule, MagicFilter)

View file

@ -1,6 +1,7 @@
import pytest
import sys import sys
import pytest
from aiogram.fsm.state import State, StatesGroup, any_state from aiogram.fsm.state import State, StatesGroup, any_state
PY312_OR_GREATER = sys.version_info >= (3, 12) PY312_OR_GREATER = sys.version_info >= (3, 12)