
Conversation


@dk67604 dk67604 commented Dec 20, 2025

Description

This PR enables token-in/token-out workflows for agentic RL by preserving provider-specific token ID data end-to-end when using OpenAI-compatible servers (validated with vLLM).

Changes:

  • Forward per-request model kwargs: invocation_state["model_kwargs"] is now passed through the event loop into Model.stream(), so OpenAI-compatible extensions such as extra_body={"return_token_ids": True} can be set per request (see the usage sketch after this list).
  • Preserve extra provider fields: messageStop.additionalModelResponseFields is now carried into the final AgentResult.message so downstream RL code can read token IDs without retokenizing.
  • Collect token IDs during streaming: OpenAIModel.stream() now collects prompt_token_ids and streamed token_ids (vLLM) and emits them via additionalModelResponseFields.
  • Tests: added unit tests covering the new plumbing.
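
For illustration, here is a minimal usage sketch (not part of the diff) of the intended flow against a local vLLM server; the base_url, model_id, and max_tokens values are placeholders borrowed from the integration test further down:

import asyncio

from strands import Agent
from strands.models.openai import OpenAIModel


async def main() -> None:
    # Point the OpenAI-compatible provider at a local vLLM server (placeholder values).
    model = OpenAIModel(
        client_args={"api_key": "EMPTY", "base_url": "http://localhost:8000/v1"},
        model_id="AMead10/Llama-3.2-3B-Instruct-AWQ",
        params={"max_tokens": 16},
    )
    agent = Agent(model=model)

    # Per-request model kwargs flow through the event loop into Model.stream(),
    # so vLLM's return_token_ids extension can be enabled for just this call.
    result = await agent.invoke_async(
        "hi",
        invocation_state={"model_kwargs": {"extra_body": {"return_token_ids": True}}},
    )

    # The provider fields are preserved on the final message, so downstream RL code
    # can read token IDs directly instead of retokenizing the text.
    extra = result.message.get("additionalModelResponseFields", {})
    print(extra.get("prompt_token_ids"))
    print(extra.get("token_ids"))


asyncio.run(main())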

Related Issues

Fixes #1368

Documentation PR

N/A

Type of Change

New feature

Testing

  • hatch test tests/strands/models/test_openai.py tests/strands/event_loop/test_streaming.py
  • STRANDS_RUN_VLLM_INTEG=1 pytest -q -k vllm_token_ids
  • I ran hatch run prepare

Checklist

  • I have read the CONTRIBUTING document
  • I have added any necessary tests that prove my fix is effective or my feature works
  • I have updated the documentation accordingly (not needed for this change)
  • I have added an appropriate example to the documentation (not needed for this change)
  • My changes generate no new warnings beyond existing repo warnings
  • Any dependent changes have been merged and published (N/A)


dk67604 commented Dec 20, 2025

vLLM Test

import os

import pytest

from strands import Agent
from strands.models.openai import OpenAIModel


@pytest.mark.asyncio
async def test_vllm_token_ids_are_preserved_in_agent_result() -> None:
    """Local integration test: vLLM token IDs survive into the final AgentResult.

    This is intentionally opt-in because it requires a locally running vLLM server.

    Run:
        STRANDS_RUN_VLLM_INTEG=1 pytest -q -k vllm_token_ids

    Expected vLLM server (defaults, override via env vars):
        VLLM_BASE_URL=http://localhost:8000/v1
        VLLM_MODEL_ID=AMead10/Llama-3.2-3B-Instruct-AWQ
    """
    if os.getenv("STRANDS_RUN_VLLM_INTEG") != "1":
        pytest.skip("Set STRANDS_RUN_VLLM_INTEG=1 to run vLLM integration test.")

    debug = os.getenv("STRANDS_DEBUG_VLLM_INTEG") == "1"

    base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
    model_id = os.getenv("VLLM_MODEL_ID", "AMead10/Llama-3.2-3B-Instruct-AWQ")

    model = OpenAIModel(
        client_args={"api_key": "EMPTY", "base_url": base_url},
        model_id=model_id,
        params={"max_tokens": 16},
    )
    agent = Agent(model=model)

    events: list[dict] = []
    async for event in agent.stream_async(
        "hi",
        invocation_state={"model_kwargs": {"extra_body": {"return_token_ids": True}}},
    ):
        events.append(event)

    result_event = events[-1]
    assert "result" in result_event, f"unexpected last event: {result_event}"
    agent_result = result_event["result"]

    message = agent_result.message
    additional = message.get("additionalModelResponseFields")

    if debug:
        # Print to stdout so you can inspect full state during local runs:
        # STRANDS_RUN_VLLM_INTEG=1 STRANDS_DEBUG_VLLM_INTEG=1 pytest -q -k vllm_token_ids -s
        print("vLLM integ debug:")
        print("  base_url:", base_url)
        print("  model_id:", model_id)
        print("  assistant_text:", str(agent_result).strip())
        print("  message_keys:", sorted(message.keys()))
        print("  additional_fields_keys:", sorted(additional.keys()) if isinstance(additional, dict) else None)
        if isinstance(additional, dict):
            pti = additional.get("prompt_token_ids")
            ti = additional.get("token_ids")
            print("  prompt_token_ids_len:", len(pti) if isinstance(pti, list) else None)
            print("  prompt_token_ids_head:", pti[:16] if isinstance(pti, list) else None)
            print("  token_ids_len:", len(ti) if isinstance(ti, list) else None)
            print("  token_ids:", ti if isinstance(ti, list) else None)

    assert isinstance(additional, dict), f"missing additionalModelResponseFields in message: {message}"

    prompt_token_ids = additional.get("prompt_token_ids")
    token_ids = additional.get("token_ids")

    assert isinstance(prompt_token_ids, list) and prompt_token_ids, f"missing prompt_token_ids: {additional}"
    assert isinstance(token_ids, list) and token_ids, f"missing token_ids: {additional}"

Output:

{"name": "print", "parameters": {"s": "hi"}}vLLM integ debug:
  base_url: http://localhost:8000/v1
  model_id: AMead10/Llama-3.2-3B-Instruct-AWQ
  assistant_text: {"name": "print", "parameters": {"s": "hi"}}
  message_keys: ['additionalModelResponseFields', 'content', 'role']
  additional_fields_keys: ['prompt_token_ids', 'token_ids']
  prompt_token_ids_len: 92
  prompt_token_ids_head: [128000, 128006, 9125, 128007, 271, 13013, 25, 6125, 27993, 198, 38766, 1303, 33025, 2696, 25, 6790]
  token_ids_len: 16
  token_ids: [5018, 609, 794, 330, 1374, 498, 330, 14105, 794, 5324, 82, 794, 330, 6151, 32075, 128008]
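
As a rough sketch of the downstream RL use mentioned in the description, the returned IDs can be assembled into a token-level training sample without retokenizing. The field names come from the output above; the prompt-plus-completion layout with a completion-only loss mask is just one common convention, not something this PR prescribes:

def build_sample(prompt_token_ids: list[int], token_ids: list[int]) -> dict[str, list[int]]:
    """Concatenate prompt and completion IDs and mark completion positions for the loss."""
    return {
        "input_ids": prompt_token_ids + token_ids,
        "loss_mask": [0] * len(prompt_token_ids) + [1] * len(token_ids),
    }


# Using the heads of the IDs printed in the debug output above:
sample = build_sample(
    prompt_token_ids=[128000, 128006, 9125, 128007, 271],  # first 5 of the 92 prompt tokens
    token_ids=[5018, 609, 794, 330, 1374],                  # first 5 of the 16 completion tokens
)
assert len(sample["input_ids"]) == len(sample["loss_mask"]) == 10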

Add SGLang native /generate provider with token-in/out and SSE streaming.
Refactor vLLM provider for token-in/out, tool-use streaming, and preserve provider fields.
Update event loop streaming to carry additionalModelResponseFields.
@dk67604 force-pushed the feat/token-ids-vllm-openai branch from 9b53a86 to f0ef731 on December 21, 2025 at 21:59

dk67604 commented Dec 21, 2025

SGLang Test

import os

import pytest

from strands import Agent, tool
from strands.models.sglang import SGLangModel


def _summarize_int_list(values: object, *, head: int = 16, tail: int = 16) -> str:
    if not isinstance(values, list) or not all(isinstance(x, int) for x in values):
        return str(values)
    if len(values) <= head + tail:
        return str(values)
    return f"len={len(values)} head={values[:head]} tail={values[-tail:]}"


@pytest.mark.asyncio
async def test_sglang_token_ids_are_preserved_in_agent_result(capsys) -> None:
    """Local integration test: SGLang `/generate` token IDs survive into the final AgentResult.

    This runs against a local SGLang server:
        SGLANG_BASE_URL=http://localhost:30000
        SGLANG_MODEL_ID=<optional>
    """
    base_url = os.getenv("SGLANG_BASE_URL", "http://localhost:30000")
    model_id = os.getenv("SGLANG_MODEL_ID") or None

    model = SGLangModel(
        base_url=base_url,
        model_id=model_id,
        params={"temperature": 0, "max_new_tokens": 64},
        return_token_ids=True,
    )
    agent = Agent(model=model)

    user_prompt = "hi"
    events: list[dict] = []
    async for event in agent.stream_async(
        user_prompt,
        invocation_state={
            "model_kwargs": {
                "return_token_ids": True,
                # Ensure the model stops normally (avoid Agent MaxTokensReachedException).
                "sampling_params": {"max_new_tokens": 64, "stop": ["\n"]},
            }
        },
    ):
        events.append(event)

    result_event = events[-1]
    assert "result" in result_event, f"unexpected last event: {result_event}"
    result = result_event["result"]

    message = result.message
    additional = message.get("additionalModelResponseFields")

    with capsys.disabled():
        print("SGLang integ debug:")
        print("  base_url:", base_url)
        print("  model_id:", model_id)
        print("  user_prompt:", user_prompt)
        print("  assistant_text:", str(result).strip())
        print("  message_keys:", sorted(message.keys()))
        print("  additional_fields_keys:", sorted(additional.keys()) if isinstance(additional, dict) else None)
        if isinstance(additional, dict):
            print("  prompt_token_ids:", _summarize_int_list(additional.get("prompt_token_ids")))
            print("  token_ids:", _summarize_int_list(additional.get("token_ids")))

    assert isinstance(additional, dict), f"missing additionalModelResponseFields: {message}"
    assert isinstance(additional.get("prompt_token_ids"), list) and additional["prompt_token_ids"]
    assert isinstance(additional.get("token_ids"), list) and additional["token_ids"]


@pytest.mark.asyncio
async def test_sglang_token_in_round_trip_prompt_token_ids(capsys) -> None:
    """Local integration test: round-trip token-in using SGLang-returned prompt_token_ids."""
    base_url = os.getenv("SGLANG_BASE_URL", "http://localhost:30000")
    model_id = os.getenv("SGLANG_MODEL_ID") or None

    model = SGLangModel(
        base_url=base_url,
        model_id=model_id,
        params={"temperature": 0, "max_new_tokens": 64},
        return_token_ids=True,
    )
    agent = Agent(model=model)

    # 1) Get prompt_token_ids for a text prompt (token-out).
    user_prompt_1 = "hi"
    res1 = await agent.invoke_async(
        user_prompt_1,
        invocation_state={"model_kwargs": {"return_token_ids": True, "sampling_params": {"max_new_tokens": 64, "stop": ["\n"]}}},
    )
    add1 = res1.message.get("additionalModelResponseFields")
    assert isinstance(add1, dict), f"missing additionalModelResponseFields: {res1.message}"
    pti = add1.get("prompt_token_ids")
    assert isinstance(pti, list) and pti, f"missing prompt_token_ids: {add1}"

    with capsys.disabled():
        print("SGLang token-in debug (step1):")
        print("  user_prompt:", user_prompt_1)
        print("  prompt_token_ids:", _summarize_int_list(pti))
        print("  token_ids:", _summarize_int_list(add1.get("token_ids")))

    # 2) Token-in: send prompt_token_ids back to /generate.
    user_prompt_2 = "ignored"
    res2 = await agent.invoke_async(
        user_prompt_2,
        invocation_state={"model_kwargs": {"prompt_token_ids": pti, "sampling_params": {"max_new_tokens": 64, "stop": ["\n"]}}},
    )
    add2 = res2.message.get("additionalModelResponseFields")
    assert isinstance(add2, dict), f"missing additionalModelResponseFields: {res2.message}"
    assert add2.get("prompt_token_ids") == pti, "token-in call did not preserve prompt_token_ids."
    assert isinstance(add2.get("token_ids"), list) and add2["token_ids"], f"missing token_ids: {add2}"

    with capsys.disabled():
        print("SGLang token-in debug (step2):")
        print("  user_prompt:", user_prompt_2)
        print("  prompt_token_ids:", _summarize_int_list(add2.get("prompt_token_ids")))
        print("  token_ids:", _summarize_int_list(add2.get("token_ids")))


@pytest.mark.asyncio
async def test_sglang_generate_rejects_tool_support(alist) -> None:
    """SGLang `/generate` does not support tools; ensure provider fails loudly."""
    base_url = os.getenv("SGLANG_BASE_URL", "http://localhost:30000")
    model = SGLangModel(base_url=base_url, model_id=None, params={"temperature": 0, "max_new_tokens": 8})

    tool_specs = [
        {
            "name": "echo_tool",
            "description": "Echo input text.",
            "inputSchema": {
                "json": {"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
            },
        }
    ]

    with pytest.raises(TypeError, match="does not support tool_specs/tool_choice"):
        await alist(
            model.stream(
                [{"role": "user", "content": [{"text": "hi"}]}],
                tool_specs=tool_specs,
            )
        )


Output:

SGLang integ debug:
  base_url: http://localhost:30000
  model_id: None
  user_prompt: hi
  assistant_text: assistant: Hi! How can I help you today?
  message_keys: ['additionalModelResponseFields', 'content', 'role']
  additional_fields_keys: ['prompt_token_ids', 'token_ids']
  prompt_token_ids: [882, 25, 15960, 198, 78191, 512]
  token_ids: [78191, 25, 21694, 0, 2650, 649, 358, 1520, 499, 3432, 30, 128009]
.SGLang token-in debug (step1):
  user_prompt: hi
  prompt_token_ids: [882, 25, 15960, 198, 78191, 512]
  token_ids: [78191, 25, 21694, 0, 2650, 649, 358, 1520, 499, 3432, 30, 128009]
SGLang token-in debug (step2):
  user_prompt: ignored
  prompt_token_ids: [882, 25, 15960, 198, 78191, 512]
  token_ids: [40, 2846, 6380, 311, 6369, 449, 499, 13, 2650, 596, 701, 1938, 2133, 779, 3117, 1980]
..
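
Distilled from the round-trip test above, a minimal token-in sketch; it assumes a local SGLang server at the default base_url and reuses the model_kwargs shape from step 2 of that test (the text prompt is only a placeholder once prompt_token_ids is supplied):

import asyncio

from strands import Agent
from strands.models.sglang import SGLangModel


async def replay_by_tokens(prompt_token_ids: list[int]) -> list[int]:
    """Send previously captured prompt token IDs back to /generate and return the sampled IDs."""
    model = SGLangModel(
        base_url="http://localhost:30000",  # assumes a local SGLang server
        model_id=None,
        params={"temperature": 0, "max_new_tokens": 64},
        return_token_ids=True,
    )
    agent = Agent(model=model)

    result = await agent.invoke_async(
        "ignored",  # mirrors the test: generation is driven by the supplied token IDs
        invocation_state={
            "model_kwargs": {
                "prompt_token_ids": prompt_token_ids,
                "sampling_params": {"max_new_tokens": 64, "stop": ["\n"]},
            }
        },
    )
    extra = result.message.get("additionalModelResponseFields", {})
    return extra.get("token_ids", [])


# Example: replay the prompt captured in the output above.
# print(asyncio.run(replay_by_tokens([882, 25, 15960, 198, 78191, 512])))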

@dharamendrak

@pgrayy What do you think about integrating this for RL training? I'm happy to coordinate on this feature.
