From 51160a4084e86438150f606d72fff07f77225f06 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 30 Jun 2026 17:16:44 -0700 Subject: [PATCH] Validate tool results with return types We have to do a bunch of annoying type annotation magic to extract the result type of aggregators from from type annotations. A follow-up question here: should we use the aggregator to populate model_input *here*, instead of in Agent? I will leave that to a second PR. The invocation for parsing tools with pydantic validation is: ``` ai.messages.Message.model_validate( data, context=ai.messages.tool_validate_context([weather]), ) ``` It is also now an argument to `ai.agents.ui.ai_sdk.to_messages`. Happy to bikeshed about names. --- src/ai/agents/agent.py | 92 +++++++++- src/ai/agents/ui/ai_sdk/inbound_messages.py | 34 +++- src/ai/type_utils.py | 118 +++++++++++++ src/ai/types/messages.py | 45 ++++- tests/agents/test_aggregate_marker.py | 62 +++++++ .../agents/ui/ai_sdk/test_inbound_messages.py | 167 ++++++++++++++++++ tests/test_type_utils.py | 78 ++++++++ tests/types/test_messages.py | 107 +++++++++++ 8 files changed, 683 insertions(+), 20 deletions(-) create mode 100644 src/ai/type_utils.py create mode 100644 tests/test_type_utils.py diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 55501b88..5e7ed6d9 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -35,7 +35,7 @@ # Use the typing_extensions backport so this works on 3.12 too. from typing_extensions import TypeVar -from .. import models, types, util +from .. import models, type_utils, types, util from ..types import builders from ..types import events as events_ from ..types.messages import MessageBundle @@ -351,6 +351,67 @@ async def render(prompt: str) -> StreamingTextTool: """ +def _return_type_from_callable(fn: Callable[..., Any]) -> Any: + try: + return get_type_hints(fn, include_extras=True).get("return") + except Exception: + return None + + +def _stream_item_type(return_type: Any) -> Any: + resolved = type_utils.resolve_type_alias(return_type) + if typing.get_origin(resolved) is typing.Annotated: + resolved = typing.get_args(resolved)[0] + + if typing.get_origin(resolved) is not AsyncGenerator: + return None + + args = typing.get_args(resolved) + return args[0] if args else Any + + +def _aggregator_result_type( + aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None, + stream_item_type: Any, +) -> Any: + agg_cls = _aggregator_cls(aggregator) + if agg_cls is None: + return None + + args = type_utils.generic_base_args(agg_cls, events_.Aggregator) + if args is None: + return None + + item_type, result_type, _model_input_type = args + bindings = type_utils.bind_typevars(item_type, stream_item_type) + return type_utils.replace_typevars(result_type, bindings) + + +def _tool_result_type_from_return_annotation( + return_type: Any, + aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None, +) -> Any: + """Return the type used to validate ``ToolResultPart.result``. + + Normal tools store the function's return value directly, so the function's + return annotation is the result type. + + Async-generator tools store ``aggregator.snapshot()`` instead. For those, + read the ``Result`` parameter from ``Aggregator[Item, Result, ModelInput]`` + on the configured aggregator class, binding generic aggregators from the + stream item type when possible (for example ``StreamingStatusTool[str]`` + + ``LastAggregator[T]`` becomes ``str | None``). + """ + if return_type is None: + return None + + item_type = _stream_item_type(return_type) + if item_type is not None: + return _aggregator_result_type(aggregator, item_type) + + return return_type + + def _aggregate_from_return_type(fn: Callable[..., Any]) -> Aggregate | None: """Find an ``Aggregate`` marker in *fn*'s return-type metadata, if any. @@ -360,11 +421,7 @@ def _aggregate_from_return_type(fn: Callable[..., Any]) -> Aggregate | None: * a PEP 695 alias ``type Foo = Annotated[X, Aggregate(...)]``, * a parameterized alias ``type Foo[T] = Annotated[X[T], Aggregate(...)]``. """ - try: - hints = get_type_hints(fn, include_extras=True) - except Exception: - return None - ret = hints.get("return") + ret = _return_type_from_callable(fn) if ret is None: return None @@ -398,6 +455,7 @@ class AgentTool: validator: type[pydantic.BaseModel] | None = None is_gen: bool = False aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None = None + return_type: Any = None @property def name(self) -> str: @@ -435,7 +493,17 @@ def tool[**P, T](fn: Callable[P, AsyncGenerator[T]], /) -> AgentTool: ... @overload def tool[**P]( - *, require_approval: bool + *, + require_approval: bool, + return_type: Any = None, +) -> Callable[[Callable[P, Any]], AgentTool]: ... + + +@overload +def tool[**P]( + *, + return_type: Any, + require_approval: bool = False, ) -> Callable[[Callable[P, Any]], AgentTool]: ... @@ -444,6 +512,7 @@ def tool[**P]( *, aggregator: Callable[[], events_.Aggregator[Any, Any, Any]], require_approval: bool = False, + return_type: Any = None, ) -> Callable[[Callable[P, AsyncGenerator[Any]]], AgentTool]: ... @@ -455,6 +524,7 @@ def tool[**P, T, R]( *, aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None = None, require_approval: bool = False, + return_type: Any = None, ) -> ( Callable[[Callable[P, AsyncGenerator[Any]]], AgentTool] | Callable[[Callable[P, Awaitable[R]]], AgentTool] @@ -466,7 +536,8 @@ def tool[**P, T, R]( ``aggregator=`` keyword argument or by annotating the return type with an :class:`Aggregate` marker (e.g. via the :data:`SubAgentTool` or :data:`StreamingStatusTool` aliases). Specifying both raises - ``TypeError``. + ``TypeError``. Pass ``return_type=`` to override the type used when + validating round-tripped tool results. """ def wrap(fn: Any) -> AgentTool: @@ -483,6 +554,7 @@ def wrap(fn: Any) -> AgentTool: validator = pydantic.create_model(f"{fn.__name__}_Args", **fields) + annotated_return_type = _return_type_from_callable(fn) annotated_aggregate = _aggregate_from_return_type(fn) if annotated_aggregate is not None and aggregator is not None: raise TypeError( @@ -508,6 +580,10 @@ def wrap(fn: Any) -> AgentTool: validator=validator, is_gen=inspect.isasyncgenfunction(fn), aggregator=effective_aggregator, + return_type=return_type + or _tool_result_type_from_return_annotation( + annotated_return_type, effective_aggregator + ), ) if fn is None: diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index 74e312bf..278896b1 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -8,7 +8,7 @@ import json import logging -from typing import Any +from typing import TYPE_CHECKING, Any from ....types import messages as messages_ from ....types.messages import MessageBundle @@ -17,6 +17,9 @@ from .approvals import ApprovalResponse, extract_approvals from .tool_utils import normalize_tool_args +if TYPE_CHECKING: + from ...agent import AgentTool + logger = logging.getLogger(__name__) @@ -64,6 +67,7 @@ def _build_result_part( output: Any, is_error: bool, kind_hint: str | None = None, + validate_context: dict[str, Any] | None = None, ) -> messages_.ToolResultPart: """Reconstruct a tool result from its wire form. @@ -93,16 +97,21 @@ def _build_result_part( else ui_messages_.UIMessage.model_validate(m) for m in raw ] - result = MessageBundle(messages=tuple(_parse(ui_msgs))) + result = MessageBundle( + messages=tuple(_parse(ui_msgs, validate_context=validate_context)) + ) result_kind = "special" else: result = _normalize_tool_result(output) result_kind = "json" - return messages_.ToolResultPart( - tool_call_id=tool_call_id, - tool_name=tool_name, - result=result, - result_kind=result_kind, + return messages_.ToolResultPart.model_validate( + dict( # noqa: C408 + tool_call_id=tool_call_id, + tool_name=tool_name, + result=result, + result_kind=result_kind, + ), + context=validate_context, ) @@ -205,6 +214,8 @@ def _patch_pending_hook_aborts( def _parse( ui_messages: list[ui_messages_.UIMessage], + *, + validate_context: dict[str, Any] | None = None, ) -> list[messages_.Message]: result: list[messages_.Message] = [] @@ -276,6 +287,7 @@ def _parse( kind_hint=result_kinds.get( inv.tool_invocation_id ), + validate_context=validate_context, ) ) @@ -340,6 +352,7 @@ def _parse( output=_tool_result_output(tp), is_error=is_error, kind_hint=result_kinds.get(tp.tool_call_id), + validate_context=validate_context, ) ) if tp.result_provider_metadata is not None: @@ -507,6 +520,8 @@ def _split_assistant_parts( def to_messages( ui_messages: list[ui_messages_.UIMessage], + *, + tools: list[AgentTool] | None = None, ) -> tuple[list[messages_.Message], list[ApprovalResponse]]: """Parse a UI request into runtime messages + extracted approvals. @@ -526,11 +541,14 @@ def to_messages( resolutions via :func:`apply_approvals` before calling :meth:`Agent.run` if the run should resume from a hook. """ + validate_context = ( + messages_.tool_validate_context(tools) if tools is not None else None + ) normalized = _normalize_ui_messages(ui_messages) approval_responses = extract_approvals(normalized) messages = [ m - for m in _parse(normalized) + for m in _parse(normalized, validate_context=validate_context) if not approvals.is_resolved_approval_message(m) ] _patch_pending_hook_aborts(messages, approval_responses) diff --git a/src/ai/type_utils.py b/src/ai/type_utils.py new file mode 100644 index 00000000..78f43366 --- /dev/null +++ b/src/ai/type_utils.py @@ -0,0 +1,118 @@ +"""Helpers for runtime type inspection.""" + +from __future__ import annotations + +import functools +import operator +import types +import typing +from typing import Any + +_T = typing.TypeVar("_T") +_TYPEVAR_TYPE = type(_T) + + +def replace_typevars(value: Any, bindings: dict[Any, Any]) -> Any: + if isinstance(value, _TYPEVAR_TYPE): + return bindings.get(value, value) + + origin = typing.get_origin(value) + if origin is None: + return value + + args = typing.get_args(value) + if not args: + return value + + if origin is typing.Annotated: + return typing.Annotated.__class_getitem__( + (replace_typevars(args[0], bindings), *args[1:]) + ) + + replaced = tuple(replace_typevars(arg, bindings) for arg in args) + if origin in (typing.Union, types.UnionType): + return functools.reduce(operator.or_, replaced) + + return origin[replaced] + + +def bind_typevars(pattern: Any, value: Any) -> dict[Any, Any]: + """Bind type variables in ``pattern`` to matching positions in ``value``.""" + if isinstance(pattern, _TYPEVAR_TYPE): + return {pattern: value} + + pattern_origin = typing.get_origin(pattern) + value_origin = typing.get_origin(value) + if pattern_origin is None or pattern_origin != value_origin: + return {} + + bindings: dict[Any, Any] = {} + for pattern_arg, value_arg in zip( + typing.get_args(pattern), typing.get_args(value), strict=False + ): + bindings.update(bind_typevars(pattern_arg, value_arg)) + return bindings + + +def resolve_type_alias(value: Any) -> Any: + """Resolve a PEP 695 type alias, preserving applied type arguments.""" + if isinstance(value, typing.TypeAliasType): + return value.__value__ + + origin = typing.get_origin(value) + if isinstance(origin, typing.TypeAliasType): + params = origin.__type_params__ + args = typing.get_args(value) + return replace_typevars( + origin.__value__, dict(zip(params, args, strict=False)) + ) + + return value + + +def generic_base_args(child: Any, base: type[Any]) -> tuple[Any, ...] | None: + """Return ``base`` type arguments for ``child``. + + ``child`` may be a concrete class or a parameterized generic alias. For + example, given ``class Box[T](Base[list[T]])``, + ``generic_base_args(Box[int], Base)`` returns ``(list[int],)``. + """ + return _generic_base_args(child, base, {}) + + +def _generic_base_args( + child: Any, + base: type[Any], + bindings: dict[Any, Any], +) -> tuple[Any, ...] | None: + child_origin = typing.get_origin(child) + if child_origin is not None: + child_args = typing.get_args(child) + child_params = getattr(child_origin, "__parameters__", ()) + child_bindings = { + param: replace_typevars(arg, bindings) + for param, arg in zip(child_params, child_args, strict=False) + } + bindings = {**bindings, **child_bindings} + child = child_origin + + for parent in getattr(child, "__orig_bases__", ()): + parent_origin = typing.get_origin(parent) + parent_args = typing.get_args(parent) + + if parent_origin is base: + return tuple(replace_typevars(arg, bindings) for arg in parent_args) + + if isinstance(parent_origin, type): + parent_params = getattr(parent_origin, "__parameters__", ()) + parent_bindings = { + param: replace_typevars(arg, bindings) + for param, arg in zip(parent_params, parent_args, strict=False) + } + result = _generic_base_args( + parent_origin, base, {**bindings, **parent_bindings} + ) + if result is not None: + return result + + return None diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 95e60d38..8f732549 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -3,8 +3,8 @@ import contextvars import functools import random -from collections.abc import AsyncIterator, Callable, Iterator -from typing import Annotated, Any, Literal, Self, overload +from collections.abc import AsyncIterator, Callable, Iterator, Sequence +from typing import Annotated, Any, Literal, Protocol, Self, cast, overload import pydantic @@ -225,6 +225,20 @@ class _ModelInputUnset: ResultKind = Literal["error", "json", "special"] +class ToolResultValidatorTool(Protocol): + @property + def name(self) -> str: ... + + @property + def return_type(self) -> Any: ... + + +def tool_validate_context( + tools: Sequence[ToolResultValidatorTool], +) -> dict[str, dict[str, ToolResultValidatorTool]]: + return {"ai.tools": {tool.name: tool for tool in tools}} + + class ToolResultPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str @@ -259,7 +273,7 @@ class ToolResultPart(pydantic.BaseModel): @pydantic.model_validator(mode="before") @classmethod - def _restore_content(cls, data: Any) -> Any: + def _restore_content(cls, data: Any, info: pydantic.ValidationInfo) -> Any: """Rebuild a typed :class:`SpecialToolResult` after a JSON round-trip. ``result`` is ``Any``, so pydantic restores a serialized @@ -275,11 +289,34 @@ def _restore_content(cls, data: Any) -> Any: data = { **data, "result": _SPECIAL_TOOL_RESULT_ADAPTER.validate_python( - data["result"] + data["result"], context=info.context ), } return data + @pydantic.model_validator(mode="after") + def _validate_tool_result(self, info: pydantic.ValidationInfo) -> Self: + if self.is_error or self.result_kind == "special": + return self + + if info.context is None: + return self + + context = cast( + "dict[str, dict[str, ToolResultValidatorTool]]", info.context + ) + tools_by_name = context.get("ai.tools", {}) + tool = tools_by_name.get(self.tool_name) + if tool is None or tool.return_type is None: + return self + + result = pydantic.TypeAdapter(tool.return_type).validate_python( + self.result, context=info.context + ) + if result is not self.result: + object.__setattr__(self, "result", result) + return self + @staticmethod def kind_for(result: Any) -> ResultKind: """Derive ``result_kind`` for a non-error result value. diff --git a/tests/agents/test_aggregate_marker.py b/tests/agents/test_aggregate_marker.py index 22683111..5de6defb 100644 --- a/tests/agents/test_aggregate_marker.py +++ b/tests/agents/test_aggregate_marker.py @@ -24,6 +24,68 @@ def _factory(t: ai.AgentTool) -> object: return factory() +def test_tool_return_type_can_be_overridden() -> None: + @ai.tool(return_type=dict[str, int]) + async def t() -> object: + return {"x": 1} + + assert t.return_type == dict[str, int] + + +def test_aggregator_return_type_can_be_overridden() -> None: + @ai.tool(aggregator=ai.agents.LastAggregator, return_type=int) + async def t() -> AsyncGenerator[str]: + yield "1" + + assert t.return_type is int + + +def test_aggregator_result_type_extracted_from_status_tool_alias() -> None: + @ai.tool + async def t() -> ai.StreamingStatusTool[str]: + yield "x" + + assert t.return_type == str | None + + +def test_aggregator_result_type_extracted_from_streaming_text_alias() -> None: + @ai.tool + async def t() -> ai.StreamingTextTool: + yield "hello" + + assert t.return_type is str + + +def test_aggregator_result_type_extracted_from_sub_agent_alias() -> None: + @ai.tool + async def t() -> ai.SubAgentTool: + yield ai.events.StreamStart() + + assert t.return_type is ai.messages.MessageBundle + + +def test_aggregator_result_type_binds_from_item_to_result() -> None: + class Box[T](ai.events.Aggregator[T, list[T], str]): + def __init__(self) -> None: + self.items: list[T] = [] + + def feed(self, item: T) -> None: + self.items.append(item) + + def snapshot(self) -> list[T]: + return self.items + + @classmethod + def to_model_input(cls, snapshot: list[T]) -> str: + return "" + + @ai.tool(aggregator=Box) + async def t() -> AsyncGenerator[int]: + yield 1 + + assert t.return_type == list[int] + + def test_aggregate_marker_extracted_from_direct_annotated() -> None: """Bare ``Annotated[..., Aggregate(...)]`` on the return type.""" diff --git a/tests/agents/ui/ai_sdk/test_inbound_messages.py b/tests/agents/ui/ai_sdk/test_inbound_messages.py index 463ed263..1cb21187 100644 --- a/tests/agents/ui/ai_sdk/test_inbound_messages.py +++ b/tests/agents/ui/ai_sdk/test_inbound_messages.py @@ -1,9 +1,12 @@ from __future__ import annotations +from collections.abc import AsyncGenerator from typing import Any +import pydantic import pytest +import ai from ai.agents.ui.ai_sdk import to_messages, to_ui_messages from ai.agents.ui.ai_sdk.inbound_messages import _normalize_ui_messages from ai.agents.ui.ai_sdk.ui_messages import UIMessage, UIToolPart @@ -11,6 +14,10 @@ from ai.types.messages import MessageBundle +class InboundWeather(pydantic.BaseModel): + temp: int + + def _ui(role: str, *parts: dict[str, Any], id: str = "m1") -> UIMessage: return UIMessage.model_validate( {"id": id, "role": role, "parts": list(parts)} @@ -316,6 +323,166 @@ def test_content_result_round_trips_via_metadata() -> None: assert file_part.media_type == "image/png" +def test_to_messages_validates_agent_tool_normal_return_annotation() -> None: + @ai.tool + async def weather() -> InboundWeather: + return InboundWeather(temp=1) + + messages, _ = to_messages( + [ + _ui( + "assistant", + _tool( + "weather", + "tc1", + "output-available", + input={}, + output={"temp": "72"}, + ), + id="a1", + ) + ], + tools=[weather], + ) + + tool_msgs = [m for m in messages if m.role == "tool"] + part = tool_msgs[0].tool_results[0] + assert isinstance(part.result, InboundWeather) + assert part.result.temp == 72 + + +def test_to_messages_validates_tool_outputs_with_agent_tool_aggregator() -> ( + None +): + @ai.tool(aggregator=ai.agents.LastAggregator) + async def weather() -> AsyncGenerator[InboundWeather]: + yield InboundWeather(temp=1) + + messages, _ = to_messages( + [ + _ui( + "assistant", + _tool( + "weather", + "tc1", + "output-available", + input={}, + output={"temp": "72"}, + ), + id="a1", + ) + ], + tools=[weather], + ) + + tool_msgs = [m for m in messages if m.role == "tool"] + part = tool_msgs[0].tool_results[0] + assert isinstance(part.result, InboundWeather) + assert part.result.temp == 72 + + +def test_to_messages_validates_agent_tool_annotated_aggregator_output() -> None: + @ai.tool + async def weather() -> ai.StreamingStatusTool[InboundWeather]: + yield InboundWeather(temp=1) + + messages, _ = to_messages( + [ + _ui( + "assistant", + _tool( + "weather", + "tc1", + "output-available", + input={}, + output={"temp": "72"}, + ), + id="a1", + ) + ], + tools=[weather], + ) + + tool_msgs = [m for m in messages if m.role == "tool"] + part = tool_msgs[0].tool_results[0] + assert isinstance(part.result, InboundWeather) + assert part.result.temp == 72 + + +def test_to_messages_validates_agent_tool_annotated_message_aggregator() -> ( + None +): + @ai.tool + async def research() -> ai.SubAgentTool: + yield ai.events.StreamStart() + + messages, _ = to_messages( + [ + _ui( + "assistant", + _tool( + "research", + "tc1", + "output-available", + input={}, + output={ + "type": "messages", + "messages": [ + { + "role": "assistant", + "parts": [{"kind": "text", "text": "ok"}], + } + ], + }, + ), + id="a1", + ) + ], + tools=[research], + ) + + tool_msgs = [m for m in messages if m.role == "tool"] + part = tool_msgs[0].tool_results[0] + assert isinstance(part.result, MessageBundle) + assert part.result.messages[0].text == "ok" + + +def test_to_messages_validates_agent_tool_passed_aggregator_output() -> None: + @ai.tool(aggregator=ai.agents.MessageAggregator) + async def research() -> AsyncGenerator[ai.events.AgentEvent]: + yield ai.events.StreamStart() + + messages, _ = to_messages( + [ + _ui( + "assistant", + _tool( + "research", + "tc1", + "output-available", + input={}, + output={ + "type": "messages", + "messages": [ + { + "role": "assistant", + "parts": [{"kind": "text", "text": "ok"}], + } + ], + }, + ), + id="a1", + ) + ], + tools=[research], + ) + + tool_msgs = [m for m in messages if m.role == "tool"] + part = tool_msgs[0].tool_results[0] + assert isinstance(part.result, MessageBundle) + assert part.result.messages[0].text == "ok" + + def test_to_messages_passthrough_keeps_wire_shape() -> None: """Non-UIMessage tool outputs stay in their wire form.""" ui = [ diff --git a/tests/test_type_utils.py b/tests/test_type_utils.py new file mode 100644 index 00000000..60f26da1 --- /dev/null +++ b/tests/test_type_utils.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any + +from ai import type_utils +from ai.types import events + + +class Base[T, U]: + pass + + +class Box[T](Base[list[T], str]): + pass + + +class IntBox(Box[int]): + pass + + +def test_generic_base_args_from_parameterized_child() -> None: + assert type_utils.generic_base_args(Box[int], Base) == (list[int], str) + + +def test_generic_base_args_through_intermediate_base() -> None: + assert type_utils.generic_base_args(IntBox, Base) == (list[int], str) + + +class ListAggregator[T](events.Aggregator[T, list[T], str]): + def feed(self, item: T) -> None: + pass + + def snapshot(self) -> list[T]: + return [] + + @classmethod + def to_model_input(cls, snapshot: list[T]) -> str: + return "" + + +def test_generic_base_args_from_aggregator() -> None: + assert type_utils.generic_base_args( + ListAggregator[int], events.Aggregator + ) == ( + int, + list[int], + str, + ) + + +def test_bind_typevars_and_replace_typevars() -> None: + args = type_utils.generic_base_args(ListAggregator, events.Aggregator) + assert args is not None + item_type, result_type, _model_input_type = args + + bindings = type_utils.bind_typevars(item_type, int) + assert type_utils.replace_typevars(result_type, bindings) == list[int] + + +# Exercise a non-generic class path too. +class PlainAggregator(events.Aggregator[str, dict[str, Any], str]): + def feed(self, item: str) -> None: + pass + + def snapshot(self) -> dict[str, Any]: + return {} + + @classmethod + def to_model_input(cls, snapshot: dict[str, Any]) -> str: + return "" + + +def test_generic_base_args_from_plain_aggregator() -> None: + assert type_utils.generic_base_args(PlainAggregator, events.Aggregator) == ( + str, + dict[str, Any], + str, + ) diff --git a/tests/types/test_messages.py b/tests/types/test_messages.py index 75b47bbb..7e092e10 100644 --- a/tests/types/test_messages.py +++ b/tests/types/test_messages.py @@ -3,16 +3,24 @@ from __future__ import annotations import asyncio +import dataclasses import random import subprocess import sys from typing import Any +import pydantic import pytest from ai.types import messages, usage +@dataclasses.dataclass +class ToolForResultValidation: + name: str + return_type: Any + + def test_usage_add_merges_optional_fields() -> None: a = usage.Usage( input_tokens=100, @@ -211,6 +219,105 @@ def test_tool_result_content_output_with_file_part_round_trip() -> None: assert file_part.media_type == "image/png" +def test_tool_result_validates_result_from_context_tools() -> None: + class Weather(pydantic.BaseModel): + temp: int + city: str + + weather_tool = ToolForResultValidation(name="weather", return_type=Weather) + + restored = messages.ToolResultPart.model_validate( + { + "tool_call_id": "tc", + "tool_name": "weather", + "result": {"temp": "72", "city": "SF"}, + }, + context=messages.tool_validate_context([weather_tool]), + ) + + assert isinstance(restored.result, Weather) + assert restored.result.temp == 72 + assert restored.result.city == "SF" + + +def test_tool_result_uses_tool_name_for_context_lookup() -> None: + class Weather(pydantic.BaseModel): + temp: int + + class Search(pydantic.BaseModel): + hits: list[str] + + weather_tool = ToolForResultValidation(name="weather", return_type=Weather) + search_tool = ToolForResultValidation(name="search", return_type=Search) + + restored = messages.ToolResultPart.model_validate( + { + "tool_call_id": "tc", + "tool_name": "search", + "result": {"hits": ["a", "b"]}, + }, + context=messages.tool_validate_context([weather_tool, search_tool]), + ) + + assert isinstance(restored.result, Search) + assert restored.result.hits == ["a", "b"] + + +def test_message_validates_tool_result_parts_with_context_tools() -> None: + class Weather(pydantic.BaseModel): + temp: int + + weather_tool = ToolForResultValidation(name="weather", return_type=Weather) + + msg = messages.Message.model_validate( + { + "role": "tool", + "parts": [ + { + "kind": "tool_result", + "tool_call_id": "tc", + "tool_name": "weather", + "result": {"temp": "72"}, + } + ], + }, + context=messages.tool_validate_context([weather_tool]), + ) + + part = msg.tool_results[0] + assert isinstance(part.result, Weather) + assert part.result.temp == 72 + + +def test_tool_result_without_context_stores_raw() -> None: + part = messages.ToolResultPart.model_validate( + { + "tool_call_id": "tc", + "tool_name": "weather", + "result": {"temp": "72"}, + } + ) + + assert part.result == {"temp": "72"} + + +def test_tool_result_context_validation_error() -> None: + class Weather(pydantic.BaseModel): + temp: int + + weather_tool = ToolForResultValidation(name="weather", return_type=Weather) + + with pytest.raises(pydantic.ValidationError, match="temp"): + messages.ToolResultPart.model_validate( + { + "tool_call_id": "tc", + "tool_name": "weather", + "result": {"temp": "hot"}, + }, + context=messages.tool_validate_context([weather_tool]), + ) + + def test_tool_result_plain_values_stored_raw() -> None: """Plain str / dict / list / None results are stored as-is and round-trip.