Skip to content

Commit cd62267

Browse files
committed
Validate full sampling tool result history
1 parent 161834d commit cd62267

3 files changed

Lines changed: 125 additions & 22 deletions

File tree

src/mcp/server/validation.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from mcp.shared.exceptions import MCPError
8-
from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, Tool, ToolChoice
8+
from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, SamplingMessageContentBlock, Tool, ToolChoice
99

1010

1111
def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> bool:
@@ -53,6 +53,7 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None:
5353
1. Messages with tool_result content contain ONLY tool_result content
5454
2. tool_result messages are preceded by a message with tool_use
5555
3. tool_result IDs match the tool_use IDs from the previous message
56+
4. Every tool_use message that has a following message is answered by matching tool_result content in that next message (a tool_use in the final message is not checked)
5657
5758
See: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577
5859
@@ -65,24 +66,26 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None:
6566
if not messages:
6667
return
6768

68-
last_content = messages[-1].content_as_list
69-
has_tool_results = any(c.type == "tool_result" for c in last_content)
70-
71-
previous_content = messages[-2].content_as_list if len(messages) >= 2 else None
72-
has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content)
73-
74-
if has_tool_results:
75-
# Per spec: "SamplingMessage with tool result content blocks
76-
# MUST NOT contain other content types."
77-
if any(c.type != "tool_result" for c in last_content):
78-
raise ValueError("The last message must contain only tool_result content if any is present")
79-
if previous_content is None:
80-
raise ValueError("tool_result requires a previous message containing tool_use")
81-
if not has_previous_tool_use:
82-
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
83-
84-
if has_previous_tool_use and previous_content:
85-
tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
86-
tool_result_ids = {c.tool_use_id for c in last_content if c.type == "tool_result"}
87-
if tool_use_ids != tool_result_ids:
88-
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
69+
previous_content: list[SamplingMessageContentBlock] | None = None
70+
for content in (message.content_as_list for message in messages):
71+
has_tool_results = any(c.type == "tool_result" for c in content)
72+
previous_tool_use_ids: set[str] = set()
73+
if previous_content is not None:
74+
previous_tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
75+
76+
if has_tool_results:
77+
# Per spec: "SamplingMessage with tool result content blocks
78+
# MUST NOT contain other content types."
79+
if any(c.type != "tool_result" for c in content):
80+
raise ValueError("A message must contain only tool_result content if any is present")
81+
if previous_content is None:
82+
raise ValueError("tool_result requires a previous message containing tool_use")
83+
if not previous_tool_use_ids:
84+
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
85+
86+
if previous_tool_use_ids:
87+
tool_result_ids = {c.tool_use_id for c in content if c.type == "tool_result"}
88+
if previous_tool_use_ids != tool_result_ids:
89+
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
90+
91+
previous_content = content

tests/server/test_session.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,21 @@ async def test_create_message_tool_result_validation():
347347
tools=[tool],
348348
)
349349

350+
# Case 4b: earlier mismatched tool result with a later plain message
351+
with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"):
352+
await session.create_message(
353+
messages=[
354+
types.SamplingMessage(role="assistant", content=tool_use),
355+
types.SamplingMessage(
356+
role="user",
357+
content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]),
358+
),
359+
types.SamplingMessage(role="assistant", content=text),
360+
],
361+
max_tokens=100,
362+
tools=[tool],
363+
)
364+
350365
# Case 5: text-only message with tools (no tool_results) - passes validation
351366
# Covers has_tool_results=False branch.
352367
# We use move_on_after because validation happens synchronously before

tests/server/test_validation.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,27 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_mixed_with_ot
108108
validate_tool_use_result_messages(messages)
109109

110110

111+
def test_validate_tool_use_result_messages_raises_for_earlier_mixed_tool_result() -> None:
112+
"""Raises when an earlier message mixes tool_result with other content."""
113+
messages = [
114+
SamplingMessage(
115+
role="assistant",
116+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
117+
),
118+
SamplingMessage(
119+
role="user",
120+
content=[
121+
ToolResultContent(type="tool_result", tool_use_id="tool-1"),
122+
TextContent(type="text", text="also this"),
123+
],
124+
),
125+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
126+
]
127+
128+
with pytest.raises(ValueError, match="only tool_result content"):
129+
validate_tool_use_result_messages(messages)
130+
131+
111132
def test_validate_tool_use_result_messages_raises_when_tool_result_without_previous_tool_use() -> None:
112133
"""Raises when tool_result appears without preceding tool_use."""
113134
messages = [
@@ -136,6 +157,39 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_matc
136157
validate_tool_use_result_messages(messages)
137158

138159

160+
def test_validate_tool_use_result_messages_raises_when_earlier_tool_result_ids_dont_match_tool_use() -> None:
161+
"""Raises when an earlier tool_result does not match the previous tool_use."""
162+
messages = [
163+
SamplingMessage(
164+
role="assistant",
165+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
166+
),
167+
SamplingMessage(
168+
role="user",
169+
content=ToolResultContent(type="tool_result", tool_use_id="tool-2"),
170+
),
171+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
172+
]
173+
174+
with pytest.raises(ValueError, match="do not match"):
175+
validate_tool_use_result_messages(messages)
176+
177+
178+
def test_validate_tool_use_result_messages_raises_when_tool_use_is_not_answered() -> None:
179+
"""Raises when a tool_use is followed by a non-tool_result message."""
180+
messages = [
181+
SamplingMessage(
182+
role="assistant",
183+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
184+
),
185+
SamplingMessage(role="user", content=TextContent(type="text", text="not a result")),
186+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
187+
]
188+
189+
with pytest.raises(ValueError, match="do not match"):
190+
validate_tool_use_result_messages(messages)
191+
192+
139193
def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_tool_use() -> None:
140194
"""No error when tool_result IDs match tool_use IDs."""
141195
messages = [
@@ -149,3 +203,34 @@ def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_too
149203
),
150204
]
151205
validate_tool_use_result_messages(messages) # Should not raise
206+
207+
208+
def test_validate_tool_use_result_messages_no_error_for_multiple_tool_pairs() -> None:
209+
"""No error when every tool_use in the history has a matching tool_result."""
210+
messages = [
211+
SamplingMessage(role="user", content=TextContent(type="text", text="first")),
212+
SamplingMessage(
213+
role="assistant",
214+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
215+
),
216+
SamplingMessage(
217+
role="user",
218+
content=ToolResultContent(type="tool_result", tool_use_id="tool-1"),
219+
),
220+
SamplingMessage(
221+
role="assistant",
222+
content=[
223+
ToolUseContent(type="tool_use", id="tool-2", name="test", input={}),
224+
ToolUseContent(type="tool_use", id="tool-3", name="test", input={}),
225+
],
226+
),
227+
SamplingMessage(
228+
role="user",
229+
content=[
230+
ToolResultContent(type="tool_result", tool_use_id="tool-3"),
231+
ToolResultContent(type="tool_result", tool_use_id="tool-2"),
232+
],
233+
),
234+
]
235+
236+
validate_tool_use_result_messages(messages)

0 commit comments

Comments
 (0)