diff --git a/examples/agent_patterns/README.md b/examples/agent_patterns/README.md index 96b48920c..2bdadce0d 100644 --- a/examples/agent_patterns/README.md +++ b/examples/agent_patterns/README.md @@ -28,6 +28,7 @@ The mental model for handoffs is that the new agent "takes over". It sees the pr For example, you could model the translation task above as tool calls instead: rather than handing over to the language-specific agent, you could call the agent as a tool, and then use the result in the next step. This enables things like translating multiple languages at once. See the [`agents_as_tools.py`](./agents_as_tools.py) file for an example of this. +See the [`agents_as_tools_streaming.py`](./agents_as_tools_streaming.py) file for a streaming variant that taps into nested agent events via `on_stream`. ## LLM-as-a-judge diff --git a/examples/agent_patterns/agents_as_tools_streaming.py b/examples/agent_patterns/agents_as_tools_streaming.py new file mode 100644 index 000000000..2eeda9989 --- /dev/null +++ b/examples/agent_patterns/agents_as_tools_streaming.py @@ -0,0 +1,59 @@ +import asyncio + +from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace + + +@function_tool( + name_override="billing_status_checker", + description_override="Answer questions about customer billing status.", +) +def billing_status_checker(customer_id: str | None = None, question: str = "") -> str: + """Return a canned billing answer or a fallback when the question is unrelated.""" + normalized = question.lower() + if "bill" in normalized or "billing" in normalized: + return f"This customer (ID: {customer_id})'s bill is $100" + return "I can only answer questions about billing." + + +def handle_stream(event: AgentToolStreamEvent) -> None: + """Print streaming events emitted by the nested billing agent.""" + stream = event["event"] + tool_call = event.get("tool_call") + tool_call_info = tool_call.call_id if tool_call is not None else "unknown" + print(f"[stream] agent={event['agent'].name} call={tool_call_info} type={stream.type} {stream}") + + +async def main() -> None: + with trace("Agents as tools streaming example"): + billing_agent = Agent( + name="Billing Agent", + instructions="You are a billing agent that answers billing questions.", + model_settings=ModelSettings(tool_choice="required"), + tools=[billing_status_checker], + ) + + billing_agent_tool = billing_agent.as_tool( + tool_name="billing_agent", + tool_description="You are a billing agent that answers billing questions.", + on_stream=handle_stream, + ) + + main_agent = Agent( + name="Customer Support Agent", + instructions=( + "You are a customer support agent. Always call the billing agent to answer billing " + "questions and return the billing agent response to the user." + ), + tools=[billing_agent_tool], + ) + + result = await Runner.run( + main_agent, + "Hello, my customer ID is ABC123. How much is my bill for this month?", + ) + + print(f"\nFinal response:\n{result.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/financial_research_agent/manager.py b/examples/financial_research_agent/manager.py index 58ec11bf2..6dfc631aa 100644 --- a/examples/financial_research_agent/manager.py +++ b/examples/financial_research_agent/manager.py @@ -6,7 +6,7 @@ from rich.console import Console -from agents import Runner, RunResult, custom_span, gen_trace_id, trace +from agents import Runner, RunResult, RunResultStreaming, custom_span, gen_trace_id, trace from .agents.financials_agent import financials_agent from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent @@ -17,7 +17,7 @@ from .printer import Printer -async def _summary_extractor(run_result: RunResult) -> str: +async def _summary_extractor(run_result: RunResult | RunResultStreaming) -> str: """Custom output extractor for sub‑agents that return an AnalysisSummary.""" # The financial/risk analyst agents emit an AnalysisSummary with a `summary` field. # We want the tool call to return just that summary text so the writer can drop it inline. diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6f4d0815d..00a5ca21e 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -8,6 +8,7 @@ from .agent import ( Agent, AgentBase, + AgentToolStreamEvent, StopAtTools, ToolsToFinalOutputFunction, ToolsToFinalOutputResult, @@ -214,6 +215,7 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", "AgentBase", + "AgentToolStreamEvent", "StopAtTools", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", diff --git a/src/agents/agent.py b/src/agents/agent.py index c479cc697..d7e780ba9 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -25,15 +25,19 @@ from .prompts import DynamicPromptFunction, Prompt, PromptUtil from .run_context import RunContextWrapper, TContext from .tool import FunctionTool, FunctionToolResult, Tool, function_tool +from .tool_context import ToolContext from .util import _transforms from .util._types import MaybeAwaitable if TYPE_CHECKING: + from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall + from .lifecycle import AgentHooks, RunHooks from .mcp import MCPServer from .memory.session import Session - from .result import RunResult + from .result import RunResult, RunResultStreaming from .run import RunConfig + from .stream_events import StreamEvent @dataclass @@ -58,6 +62,19 @@ class ToolsToFinalOutputResult: """ +class AgentToolStreamEvent(TypedDict): + """Streaming event emitted when an agent is invoked as a tool.""" + + event: StreamEvent + """The streaming event from the nested agent run.""" + + agent: Agent[Any] + """The nested agent emitting the event.""" + + tool_call: ResponseFunctionToolCall | None + """The originating tool call, if available.""" + + class StopAtTools(TypedDict): stop_at_tool_names: list[str] """A list of tool names, any of which will stop the agent from running further.""" @@ -382,9 +399,12 @@ def as_tool( self, tool_name: str | None, tool_description: str | None, - custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, + custom_output_extractor: ( + Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None + ) = None, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, + on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None, run_config: RunConfig | None = None, max_turns: int | None = None, hooks: RunHooks[TContext] | None = None, @@ -409,6 +429,10 @@ def as_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + on_stream: Optional callback (sync or async) to receive streaming events from the nested + agent run. The callback receives an `AgentToolStreamEvent` containing the nested + agent, the originating tool call (when available), and each stream event. When + provided, the nested agent is executed in streaming mode. """ @function_tool( @@ -416,26 +440,89 @@ def as_tool( description_override=tool_description or "", is_enabled=is_enabled, ) - async def run_agent(context: RunContextWrapper, input: str) -> Any: + async def run_agent(context: ToolContext, input: str) -> Any: from .run import DEFAULT_MAX_TURNS, Runner resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS - - output = await Runner.run( - starting_agent=self, - input=input, - context=context.context, - run_config=run_config, - max_turns=resolved_max_turns, - hooks=hooks, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) + run_result: RunResult | RunResultStreaming + + if on_stream is not None: + run_result = Runner.run_streamed( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + # Dispatch callbacks in the background so slow handlers do not block + # event consumption. + event_queue: asyncio.Queue[AgentToolStreamEvent | None] = asyncio.Queue() + + async def _run_handler(payload: AgentToolStreamEvent) -> None: + """Execute the user callback while capturing exceptions.""" + try: + maybe_result = on_stream(payload) + if inspect.isawaitable(maybe_result): + await maybe_result + except Exception: + logger.exception( + "Error while handling on_stream event for agent tool %s.", + self.name, + ) + + async def dispatch_stream_events() -> None: + while True: + payload = await event_queue.get() + is_sentinel = payload is None # None marks the end of the stream. + try: + if payload is not None: + await _run_handler(payload) + finally: + event_queue.task_done() + + if is_sentinel: + break + + dispatch_task = asyncio.create_task(dispatch_stream_events()) + + try: + from .stream_events import AgentUpdatedStreamEvent + + current_agent = run_result.current_agent + async for event in run_result.stream_events(): + if isinstance(event, AgentUpdatedStreamEvent): + current_agent = event.new_agent + + payload: AgentToolStreamEvent = { + "event": event, + "agent": current_agent, + "tool_call": context.tool_call, + } + await event_queue.put(payload) + finally: + await event_queue.put(None) + await event_queue.join() + await dispatch_task + else: + run_result = await Runner.run( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) if custom_output_extractor: - return await custom_output_extractor(output) + return await custom_output_extractor(run_result) - return output.final_output + return run_result.final_output return run_agent diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 5b81239f6..0fc354299 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -31,6 +31,9 @@ class ToolContext(RunContextWrapper[TContext]): tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments) """The raw arguments string of the tool call.""" + tool_call: Optional[ResponseFunctionToolCall] = None + """The tool call object associated with this invocation.""" + @classmethod def from_agent_context( cls, @@ -50,6 +53,11 @@ def from_agent_context( tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments() ) - return cls( - tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values + tool_context = cls( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_arguments=tool_args, + tool_call=tool_call, + **base_values, ) + return tool_context diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 51d8edf20..c28ce8fb1 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -1,14 +1,17 @@ from __future__ import annotations -from typing import Any +import asyncio +from typing import Any, cast import pytest from openai.types.responses import ResponseOutputMessage, ResponseOutputText +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from pydantic import BaseModel from agents import ( Agent, AgentBase, + AgentToolStreamEvent, FunctionTool, MessageOutputItem, RunConfig, @@ -18,6 +21,7 @@ Session, TResponseInputItem, ) +from agents.stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent from agents.tool_context import ToolContext @@ -373,3 +377,573 @@ async def extractor(result) -> str: output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}') assert output == "custom output" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streams_events_with_on_stream( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [ + RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})), + RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hi"})), + ] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "streamed output" + self.current_agent = agent + + async def stream_events(self): + for ev in stream_events: + yield ev + + run_calls: list[dict[str, Any]] = [] + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append( + { + "starting_agent": starting_agent, + "input": input, + "context": context, + "max_turns": max_turns, + "hooks": hooks, + "run_config": run_config, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + "session": session, + } + ) + return DummyStreamingResult() + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received_events: list[AgentToolStreamEvent] = [] + + async def on_stream(payload: AgentToolStreamEvent) -> None: + received_events.append(payload) + + tool_call = ResponseFunctionToolCall( + id="call_123", + arguments='{"input": "run streaming"}', + call_id="call-123", + name="stream_tool", + type="function_call", + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + on_stream=on_stream, + ), + ) + + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "run streaming"}') + + assert output == "streamed output" + assert len(received_events) == len(stream_events) + assert received_events[0]["agent"] is agent + assert received_events[0]["tool_call"] is tool_call + assert received_events[0]["event"] == stream_events[0] + assert run_calls[0]["input"] == "run streaming" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_updates_agent_on_handoff( + monkeypatch: pytest.MonkeyPatch, +) -> None: + first_agent = Agent(name="primary") + handed_off_agent = Agent(name="delegate") + + events = [ + AgentUpdatedStreamEvent(new_agent=first_agent), + RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})), + AgentUpdatedStreamEvent(new_agent=handed_off_agent), + RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hello"})), + ] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "delegated output" + self.current_agent = first_agent + + async def stream_events(self): + for ev in events: + yield ev + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + return DummyStreamingResult() + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + seen_agents: list[Agent[Any]] = [] + + async def on_stream(payload: AgentToolStreamEvent) -> None: + seen_agents.append(payload["agent"]) + + tool = cast( + FunctionTool, + first_agent.as_tool( + tool_name="delegate_tool", + tool_description="Streams handoff events", + on_stream=on_stream, + ), + ) + + tool_call = ResponseFunctionToolCall( + id="call_delegate", + arguments='{"input": "handoff"}', + call_id="call-delegate", + name="delegate_tool", + type="function_call", + ) + tool_context = ToolContext( + context=None, + tool_name="delegate_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "handoff"}') + + assert output == "delegated output" + assert seen_agents == [first_agent, first_agent, handed_off_agent, handed_off_agent] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_works_with_custom_extractor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] + stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "raw output" + self.current_agent = agent + + async def stream_events(self): + for ev in stream_events: + yield ev + + streamed_instance = DummyStreamingResult() + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + return streamed_instance + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received: list[Any] = [] + + async def extractor(result) -> str: + received.append(result) + return "custom value" + + callbacks: list[Any] = [] + + async def on_stream(payload: AgentToolStreamEvent) -> None: + callbacks.append(payload["event"]) + + tool_call = ResponseFunctionToolCall( + id="call_abc", + arguments='{"input": "stream please"}', + call_id="call-abc", + name="stream_tool", + type="function_call", + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + custom_output_extractor=extractor, + on_stream=on_stream, + ), + ) + + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "stream please"}') + + assert output == "custom value" + assert received == [streamed_instance] + assert callbacks == stream_events + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_accepts_sync_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="sync_handler_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + self.current_agent = agent + + async def stream_events(self): + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + calls: list[str] = [] + + def sync_handler(event: AgentToolStreamEvent) -> None: + calls.append(event["event"].type) + + tool_call = ResponseFunctionToolCall( + id="call_sync", + arguments='{"input": "go"}', + call_id="call-sync", + name="sync_tool", + type="function_call", + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="sync_tool", + tool_description="Uses sync handler", + on_stream=sync_handler, + ), + ) + tool_context = ToolContext( + context=None, + tool_name="sync_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + assert calls == ["raw_response_event"] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_dispatches_without_blocking( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """on_stream handlers should not block streaming iteration.""" + agent = Agent(name="nonblocking_agent") + + first_handler_started = asyncio.Event() + allow_handler_to_continue = asyncio.Event() + second_event_yielded = asyncio.Event() + second_event_handled = asyncio.Event() + + first_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) + second_event = RawResponsesStreamEvent( + data=cast(Any, {"type": "output_text_delta", "delta": "hi"}) + ) + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + self.current_agent = agent + + async def stream_events(self): + yield first_event + second_event_yielded.set() + yield second_event + + dummy_result = DummyStreamingResult() + + monkeypatch.setattr(Runner, "run_streamed", classmethod(lambda *args, **kwargs: dummy_result)) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + async def on_stream(payload: AgentToolStreamEvent) -> None: + if payload["event"] is first_event: + first_handler_started.set() + await allow_handler_to_continue.wait() + else: + second_event_handled.set() + + tool_call = ResponseFunctionToolCall( + id="call_nonblocking", + arguments='{"input": "go"}', + call_id="call-nonblocking", + name="nonblocking_tool", + type="function_call", + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="nonblocking_tool", + tool_description="Uses non-blocking streaming handler", + on_stream=on_stream, + ), + ) + tool_context = ToolContext( + context=None, + tool_name="nonblocking_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + + async def _invoke_tool() -> Any: + return await tool.on_invoke_tool(tool_context, '{"input": "go"}') + + invoke_task: asyncio.Task[Any] = asyncio.create_task(_invoke_tool()) + + await asyncio.wait_for(first_handler_started.wait(), timeout=1.0) + await asyncio.wait_for(second_event_yielded.wait(), timeout=1.0) + assert invoke_task.done() is False + + allow_handler_to_continue.set() + await asyncio.wait_for(second_event_handled.wait(), timeout=1.0) + output = await asyncio.wait_for(invoke_task, timeout=1.0) + + assert output == "ok" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="handler_error_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + self.current_agent = agent + + async def stream_events(self): + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + def bad_handler(event: AgentToolStreamEvent) -> None: + raise RuntimeError("boom") + + tool_call = ResponseFunctionToolCall( + id="call_bad", + arguments='{"input": "go"}', + call_id="call-bad", + name="error_tool", + type="function_call", + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="error_tool", + tool_description="Handler throws", + on_stream=bad_handler, + ), + ) + tool_context = ToolContext( + context=None, + tool_name="error_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + + +@pytest.mark.asyncio +async def test_agent_as_tool_without_stream_uses_run( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="nostream_agent") + + class DummyResult: + def __init__(self) -> None: + self.final_output = "plain" + + run_calls: list[dict[str, Any]] = [] + + async def fake_run( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append({"input": input}) + return DummyResult() + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + monkeypatch.setattr( + Runner, + "run_streamed", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))), + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="nostream_tool", + tool_description="No streaming path", + ), + ) + tool_context = ToolContext( + context=None, + tool_name="nostream_tool", + tool_call_id="call-no", + tool_arguments='{"input": "plain"}', + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "plain"}') + + assert output == "plain" + assert run_calls == [{"input": "plain"}] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_sets_tool_call_from_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="direct_invocation_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + self.current_agent = agent + + async def stream_events(self): + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + captured: list[AgentToolStreamEvent] = [] + + async def on_stream(event: AgentToolStreamEvent) -> None: + captured.append(event) + + tool_call = ResponseFunctionToolCall( + id="call_direct", + arguments='{"input": "hi"}', + call_id="direct-call-id", + name="direct_stream_tool", + type="function_call", + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="direct_stream_tool", + tool_description="Direct invocation", + on_stream=on_stream, + ), + ) + tool_context = ToolContext( + context=None, + tool_name="direct_stream_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') + + assert output == "ok" + assert captured[0]["tool_call"] is tool_call