Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 84 additions & 8 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# Use the typing_extensions backport so this works on 3.12 too.
from typing_extensions import TypeVar

from .. import models, types, util
from .. import models, type_utils, types, util
from ..types import builders
from ..types import events as events_
from ..types.messages import MessageBundle
Expand Down Expand Up @@ -351,6 +351,67 @@ async def render(prompt: str) -> StreamingTextTool:
"""


def _return_type_from_callable(fn: Callable[..., Any]) -> Any:
try:
return get_type_hints(fn, include_extras=True).get("return")
except Exception:
return None


def _stream_item_type(return_type: Any) -> Any:
resolved = type_utils.resolve_type_alias(return_type)
if typing.get_origin(resolved) is typing.Annotated:
resolved = typing.get_args(resolved)[0]

if typing.get_origin(resolved) is not AsyncGenerator:
return None

args = typing.get_args(resolved)
return args[0] if args else Any


def _aggregator_result_type(
aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None,
stream_item_type: Any,
) -> Any:
agg_cls = _aggregator_cls(aggregator)
if agg_cls is None:
return None

args = type_utils.generic_base_args(agg_cls, events_.Aggregator)
if args is None:
return None

item_type, result_type, _model_input_type = args
bindings = type_utils.bind_typevars(item_type, stream_item_type)
return type_utils.replace_typevars(result_type, bindings)


def _tool_result_type_from_return_annotation(
return_type: Any,
aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None,
) -> Any:
"""Return the type used to validate ``ToolResultPart.result``.

Normal tools store the function's return value directly, so the function's
return annotation is the result type.

Async-generator tools store ``aggregator.snapshot()`` instead. For those,
read the ``Result`` parameter from ``Aggregator[Item, Result, ModelInput]``
on the configured aggregator class, binding generic aggregators from the
stream item type when possible (for example ``StreamingStatusTool[str]`` +
``LastAggregator[T]`` becomes ``str | None``).
"""
if return_type is None:
return None

item_type = _stream_item_type(return_type)
if item_type is not None:
return _aggregator_result_type(aggregator, item_type)

return return_type


def _aggregate_from_return_type(fn: Callable[..., Any]) -> Aggregate | None:
"""Find an ``Aggregate`` marker in *fn*'s return-type metadata, if any.

Expand All @@ -360,11 +421,7 @@ def _aggregate_from_return_type(fn: Callable[..., Any]) -> Aggregate | None:
* a PEP 695 alias ``type Foo = Annotated[X, Aggregate(...)]``,
* a parameterized alias ``type Foo[T] = Annotated[X[T], Aggregate(...)]``.
"""
try:
hints = get_type_hints(fn, include_extras=True)
except Exception:
return None
ret = hints.get("return")
ret = _return_type_from_callable(fn)
if ret is None:
return None

Expand Down Expand Up @@ -398,6 +455,7 @@ class AgentTool:
validator: type[pydantic.BaseModel] | None = None
is_gen: bool = False
aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None = None
return_type: Any = None

@property
def name(self) -> str:
Expand Down Expand Up @@ -435,7 +493,17 @@ def tool[**P, T](fn: Callable[P, AsyncGenerator[T]], /) -> AgentTool: ...

@overload
def tool[**P](
*, require_approval: bool
*,
require_approval: bool,
return_type: Any = None,
) -> Callable[[Callable[P, Any]], AgentTool]: ...


@overload
def tool[**P](
*,
return_type: Any,
require_approval: bool = False,
) -> Callable[[Callable[P, Any]], AgentTool]: ...


Expand All @@ -444,6 +512,7 @@ def tool[**P](
*,
aggregator: Callable[[], events_.Aggregator[Any, Any, Any]],
require_approval: bool = False,
return_type: Any = None,
) -> Callable[[Callable[P, AsyncGenerator[Any]]], AgentTool]: ...


Expand All @@ -455,6 +524,7 @@ def tool[**P, T, R](
*,
aggregator: Callable[[], events_.Aggregator[Any, Any, Any]] | None = None,
require_approval: bool = False,
return_type: Any = None,
) -> (
Callable[[Callable[P, AsyncGenerator[Any]]], AgentTool]
| Callable[[Callable[P, Awaitable[R]]], AgentTool]
Expand All @@ -466,7 +536,8 @@ def tool[**P, T, R](
``aggregator=`` keyword argument or by annotating the return type
with an :class:`Aggregate` marker (e.g. via the :data:`SubAgentTool`
or :data:`StreamingStatusTool` aliases). Specifying both raises
``TypeError``.
``TypeError``. Pass ``return_type=`` to override the type used when
validating round-tripped tool results.
"""

def wrap(fn: Any) -> AgentTool:
Expand All @@ -483,6 +554,7 @@ def wrap(fn: Any) -> AgentTool:

validator = pydantic.create_model(f"{fn.__name__}_Args", **fields)

annotated_return_type = _return_type_from_callable(fn)
annotated_aggregate = _aggregate_from_return_type(fn)
if annotated_aggregate is not None and aggregator is not None:
raise TypeError(
Expand All @@ -508,6 +580,10 @@ def wrap(fn: Any) -> AgentTool:
validator=validator,
is_gen=inspect.isasyncgenfunction(fn),
aggregator=effective_aggregator,
return_type=return_type
or _tool_result_type_from_return_annotation(
annotated_return_type, effective_aggregator
),
)

if fn is None:
Expand Down
34 changes: 26 additions & 8 deletions src/ai/agents/ui/ai_sdk/inbound_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import logging
from typing import Any
from typing import TYPE_CHECKING, Any

from ....types import messages as messages_
from ....types.messages import MessageBundle
Expand All @@ -17,6 +17,9 @@
from .approvals import ApprovalResponse, extract_approvals
from .tool_utils import normalize_tool_args

if TYPE_CHECKING:
from ...agent import AgentTool

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -64,6 +67,7 @@ def _build_result_part(
output: Any,
is_error: bool,
kind_hint: str | None = None,
validate_context: dict[str, Any] | None = None,
) -> messages_.ToolResultPart:
"""Reconstruct a tool result from its wire form.

Expand Down Expand Up @@ -93,16 +97,21 @@ def _build_result_part(
else ui_messages_.UIMessage.model_validate(m)
for m in raw
]
result = MessageBundle(messages=tuple(_parse(ui_msgs)))
result = MessageBundle(
messages=tuple(_parse(ui_msgs, validate_context=validate_context))
)
result_kind = "special"
else:
result = _normalize_tool_result(output)
result_kind = "json"
return messages_.ToolResultPart(
tool_call_id=tool_call_id,
tool_name=tool_name,
result=result,
result_kind=result_kind,
return messages_.ToolResultPart.model_validate(
dict( # noqa: C408
tool_call_id=tool_call_id,
tool_name=tool_name,
result=result,
result_kind=result_kind,
),
context=validate_context,
)


Expand Down Expand Up @@ -205,6 +214,8 @@ def _patch_pending_hook_aborts(

def _parse(
ui_messages: list[ui_messages_.UIMessage],
*,
validate_context: dict[str, Any] | None = None,
) -> list[messages_.Message]:
result: list[messages_.Message] = []

Expand Down Expand Up @@ -276,6 +287,7 @@ def _parse(
kind_hint=result_kinds.get(
inv.tool_invocation_id
),
validate_context=validate_context,
)
)

Expand Down Expand Up @@ -340,6 +352,7 @@ def _parse(
output=_tool_result_output(tp),
is_error=is_error,
kind_hint=result_kinds.get(tp.tool_call_id),
validate_context=validate_context,
)
)
if tp.result_provider_metadata is not None:
Expand Down Expand Up @@ -507,6 +520,8 @@ def _split_assistant_parts(

def to_messages(
ui_messages: list[ui_messages_.UIMessage],
*,
tools: list[AgentTool] | None = None,
) -> tuple[list[messages_.Message], list[ApprovalResponse]]:
"""Parse a UI request into runtime messages + extracted approvals.

Expand All @@ -526,11 +541,14 @@ def to_messages(
resolutions via :func:`apply_approvals` before calling
:meth:`Agent.run` if the run should resume from a hook.
"""
validate_context = (
messages_.tool_validate_context(tools) if tools is not None else None
)
normalized = _normalize_ui_messages(ui_messages)
approval_responses = extract_approvals(normalized)
messages = [
m
for m in _parse(normalized)
for m in _parse(normalized, validate_context=validate_context)
if not approvals.is_resolved_approval_message(m)
]
_patch_pending_hook_aborts(messages, approval_responses)
Expand Down
118 changes: 118 additions & 0 deletions src/ai/type_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Helpers for runtime type inspection."""

from __future__ import annotations

import functools
import operator
import types
import typing
from typing import Any

_T = typing.TypeVar("_T")
_TYPEVAR_TYPE = type(_T)


def replace_typevars(value: Any, bindings: dict[Any, Any]) -> Any:
if isinstance(value, _TYPEVAR_TYPE):
return bindings.get(value, value)

origin = typing.get_origin(value)
if origin is None:
return value

args = typing.get_args(value)
if not args:
return value

if origin is typing.Annotated:
return typing.Annotated.__class_getitem__(
(replace_typevars(args[0], bindings), *args[1:])
)

replaced = tuple(replace_typevars(arg, bindings) for arg in args)
if origin in (typing.Union, types.UnionType):
return functools.reduce(operator.or_, replaced)

return origin[replaced]


def bind_typevars(pattern: Any, value: Any) -> dict[Any, Any]:
"""Bind type variables in ``pattern`` to matching positions in ``value``."""
if isinstance(pattern, _TYPEVAR_TYPE):
return {pattern: value}

pattern_origin = typing.get_origin(pattern)
value_origin = typing.get_origin(value)
if pattern_origin is None or pattern_origin != value_origin:
return {}

bindings: dict[Any, Any] = {}
for pattern_arg, value_arg in zip(
typing.get_args(pattern), typing.get_args(value), strict=False
):
bindings.update(bind_typevars(pattern_arg, value_arg))
return bindings


def resolve_type_alias(value: Any) -> Any:
"""Resolve a PEP 695 type alias, preserving applied type arguments."""
if isinstance(value, typing.TypeAliasType):
return value.__value__

origin = typing.get_origin(value)
if isinstance(origin, typing.TypeAliasType):
params = origin.__type_params__
args = typing.get_args(value)
return replace_typevars(
origin.__value__, dict(zip(params, args, strict=False))
)

return value


def generic_base_args(child: Any, base: type[Any]) -> tuple[Any, ...] | None:
"""Return ``base`` type arguments for ``child``.

``child`` may be a concrete class or a parameterized generic alias. For
example, given ``class Box[T](Base[list[T]])``,
``generic_base_args(Box[int], Base)`` returns ``(list[int],)``.
"""
return _generic_base_args(child, base, {})


def _generic_base_args(
child: Any,
base: type[Any],
bindings: dict[Any, Any],
) -> tuple[Any, ...] | None:
child_origin = typing.get_origin(child)
if child_origin is not None:
child_args = typing.get_args(child)
child_params = getattr(child_origin, "__parameters__", ())
child_bindings = {
param: replace_typevars(arg, bindings)
for param, arg in zip(child_params, child_args, strict=False)
}
bindings = {**bindings, **child_bindings}
child = child_origin

for parent in getattr(child, "__orig_bases__", ()):
parent_origin = typing.get_origin(parent)
parent_args = typing.get_args(parent)

if parent_origin is base:
return tuple(replace_typevars(arg, bindings) for arg in parent_args)

if isinstance(parent_origin, type):
parent_params = getattr(parent_origin, "__parameters__", ())
parent_bindings = {
param: replace_typevars(arg, bindings)
for param, arg in zip(parent_params, parent_args, strict=False)
}
result = _generic_base_args(
parent_origin, base, {**bindings, **parent_bindings}
)
if result is not None:
return result

return None
Loading
Loading