mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
"Add get_mounted_bot function and improve model comparison in tests"
In this commit, a new function `get_mounted_bot` was added to `context_controller.py` that returns the bot mounted in context. This function was needed to bypass the limitation in pydantic BaseModel's properties, which neither support computed fields nor serialization/validation. Various tests were also updated to compare models using `model_dump_json()` method rather than comparing the models directly. This change provides more accurate comparisons by considering default values in the models. Further, the dispatcher was adjusted to enforce update re-mounting if the mounted bot differs from the current update. This allows shortcuts to be used in the bot's current instance and ensures the correct propagation of the context to all the nested objects and attributes of Updates.
This commit is contained in:
parent
31c11c31e0
commit
74e00a30b1
4 changed files with 31 additions and 7 deletions
|
|
@ -13,6 +13,13 @@ class BotContextController(BaseModel):
|
||||||
def model_post_init(self, __context: Any) -> None:
|
def model_post_init(self, __context: Any) -> None:
|
||||||
self._bot = __context.get("bot") if __context else None
|
self._bot = __context.get("bot") if __context else None
|
||||||
|
|
||||||
|
def get_mounted_bot(self) -> Optional["Bot"]:
|
||||||
|
# Properties are not supported in pydantic BaseModel
|
||||||
|
# @computed_field decorator is not a solution for this case in due to
|
||||||
|
# it produces an additional field in model with validation and serialization that
|
||||||
|
# we don't need here
|
||||||
|
return self._bot
|
||||||
|
|
||||||
def as_(self, bot: Optional["Bot"]) -> Self:
|
def as_(self, bot: Optional["Bot"]) -> Self:
|
||||||
"""
|
"""
|
||||||
Bind object to a bot instance.
|
Bind object to a bot instance.
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,16 @@ class Dispatcher(Router):
|
||||||
handled = False
|
handled = False
|
||||||
start_time = loop.time()
|
start_time = loop.time()
|
||||||
|
|
||||||
token = Bot.set_current(bot)
|
if update.get_mounted_bot() != bot:
|
||||||
|
# Re-mounting update to the current bot instance for making possible to
|
||||||
|
# use it in shortcuts.
|
||||||
|
# Here is update is re-created because we need to propagate context to
|
||||||
|
# all nested objects and attributes of the Update, but it
|
||||||
|
# is impossible without roundtrip to JSON :(
|
||||||
|
# The preferred way is that pass already mounted Bot instance to this update
|
||||||
|
# before call feed_update method
|
||||||
|
update = Update.model_validate(update.model_dump(), context={"bot": bot})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.update.wrap_outer_middleware(
|
response = await self.update.wrap_outer_middleware(
|
||||||
self.update.trigger,
|
self.update.trigger,
|
||||||
|
|
@ -165,7 +174,6 @@ class Dispatcher(Router):
|
||||||
duration,
|
duration,
|
||||||
bot.id,
|
bot.id,
|
||||||
)
|
)
|
||||||
Bot.reset_current(token)
|
|
||||||
|
|
||||||
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
|
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
@ -367,7 +375,7 @@ class Dispatcher(Router):
|
||||||
self, bot: Bot, update: Union[Update, Dict[str, Any]], _timeout: float = 55, **kwargs: Any
|
self, bot: Bot, update: Union[Update, Dict[str, Any]], _timeout: float = 55, **kwargs: Any
|
||||||
) -> Optional[TelegramMethod[TelegramType]]:
|
) -> Optional[TelegramMethod[TelegramType]]:
|
||||||
if not isinstance(update, Update): # Allow to use raw updates
|
if not isinstance(update, Update): # Allow to use raw updates
|
||||||
update = Update(**update)
|
update = Update.model_validate(update, context={"bot": bot})
|
||||||
|
|
||||||
ctx = contextvars.copy_context()
|
ctx = contextvars.copy_context()
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,10 @@ class MockedSession(BaseSession):
|
||||||
self.requests.append(method)
|
self.requests.append(method)
|
||||||
response: Response[TelegramType] = self.responses.pop()
|
response: Response[TelegramType] = self.responses.pop()
|
||||||
self.check_response(
|
self.check_response(
|
||||||
bot=bot, method=method, status_code=response.error_code, content=response.json()
|
bot=bot,
|
||||||
|
method=method,
|
||||||
|
status_code=response.error_code,
|
||||||
|
content=response.model_dump_json(),
|
||||||
)
|
)
|
||||||
return response.result # type: ignore
|
return response.result # type: ignore
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -460,7 +460,9 @@ class TestDispatcher:
|
||||||
|
|
||||||
@observer()
|
@observer()
|
||||||
async def my_handler(event: Any, **kwargs: Any):
|
async def my_handler(event: Any, **kwargs: Any):
|
||||||
assert event == getattr(update, event_type)
|
assert event.model_dump(exclude_defaults=True) == getattr(
|
||||||
|
update, event_type
|
||||||
|
).model_dump(exclude_defaults=True)
|
||||||
if has_chat:
|
if has_chat:
|
||||||
assert kwargs["event_chat"]
|
assert kwargs["event_chat"]
|
||||||
if has_user:
|
if has_user:
|
||||||
|
|
@ -469,7 +471,9 @@ class TestDispatcher:
|
||||||
|
|
||||||
result = await router.feed_update(bot, update, test="PASS")
|
result = await router.feed_update(bot, update, test="PASS")
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert result["event_update"] == update
|
assert result["event_update"].model_dump(exclude_defaults=True) == update.model_dump(
|
||||||
|
exclude_defaults=True
|
||||||
|
)
|
||||||
assert result["event_router"] == router
|
assert result["event_router"] == router
|
||||||
assert result["test"] == "PASS"
|
assert result["test"] == "PASS"
|
||||||
|
|
||||||
|
|
@ -532,7 +536,9 @@ class TestDispatcher:
|
||||||
)
|
)
|
||||||
result = await dp.feed_update(bot, update, test="PASS")
|
result = await dp.feed_update(bot, update, test="PASS")
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert result["event_update"] == update
|
assert result["event_update"].model_dump(exclude_defaults=True) == update.model_dump(
|
||||||
|
exclude_defaults=True
|
||||||
|
)
|
||||||
assert result["event_router"] == router1
|
assert result["event_router"] == router1
|
||||||
assert result["test"] == "PASS"
|
assert result["test"] == "PASS"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue