Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion libs/langchain_v1/langchain/agents/middleware/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Literal, cast

from langchain_core.messages import (
AIMessage,
AnyMessage,
MessageLikeRepresentation,
RemoveMessage,
Expand Down Expand Up @@ -469,7 +470,41 @@ def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -
return 0

target_cutoff = len(messages) - messages_to_keep
return self._find_safe_cutoff_point(messages, target_cutoff)
cutoff = self._find_safe_cutoff_point(messages, target_cutoff)

# Scan preserved messages to find any orphaned ToolCalls
ids_to_find = set()
for i in range(cutoff, len(messages)):
msg = messages[i]
if isinstance(msg, ToolMessage):
ids_to_find.add(msg.tool_call_id)
# If we find the parent in preserved, we are good for that ID
if isinstance(msg, AIMessage):
for tc in msg.tool_calls:
if tc["id"] in ids_to_find:
ids_to_find.remove(tc["id"])

# If ids_to_find is not empty, it means we have orphans.
# We must look BACKWARDS from cutoff to find the parents.
if ids_to_find:
# We iterate backwards from just before the cutoff
for i in range(cutoff - 1, -1, -1):
msg = messages[i]
if isinstance(msg, AIMessage):
found_any = False
for tc in msg.tool_calls:
if tc["id"] in ids_to_find:
ids_to_find.remove(tc["id"])
found_any = True

if found_any:
# We found a parent! We must move the cutoff here.
cutoff = i

if not ids_to_find:
break

return cutoff

def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,39 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
# Index 2 is an AIMessage (safe cutoff point), so no adjustment needed
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4)
assert cutoff == 2


def test_summarization_middleware_orphaned_tool_prevention() -> None:
    """Test that middleware prevents orphaned ToolMessages by preserving parent AIMessages."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(), trigger=("messages", 2), keep=("messages", 2)
    )

    # Scenario: Human -> AIMessage(tool_calls) -> HumanMessage -> ToolMessage.
    # With keep=("messages", 2) the naive cutoff lands at index 2, preserving only
    # [HumanMessage, ToolMessage] — orphaning the ToolMessage from its parent
    # AIMessage at index 1. The fix must move the cutoff back to include the AIMessage.
    messages: list[AnyMessage] = [
        HumanMessage(content="context"),
        AIMessage(content="call", tool_calls=[{"name": "tool", "args": {}, "id": "call-1"}]),
        HumanMessage(content="interruption"),
        ToolMessage(content="result", tool_call_id="call-1"),
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)  # type: ignore[arg-type]

    assert result is not None
    assert "messages" in result

    # Expected result layout: [RemoveMessage, summary HumanMessage, AI, Human, Tool]
    # — i.e. 3 preserved messages after the removal marker and the summary.
    preserved_messages = result["messages"][2:]
    assert len(preserved_messages) == 3
    assert isinstance(preserved_messages[0], AIMessage)
    assert isinstance(preserved_messages[1], HumanMessage)
    assert isinstance(preserved_messages[2], ToolMessage)

    # The preserved AI/Tool pair must share the same tool-call id (no orphan).
    assert preserved_messages[0].tool_calls[0]["id"] == "call-1"
    assert preserved_messages[2].tool_call_id == "call-1"
Loading