Skip to content

Commit c12acee

Browse files
committed
fix: add tests to ensure ToolApprovalItem hashability and preserve persisted item count during streamed run resumption
1 parent 0128424 commit c12acee

File tree

1 file changed

+94
-4
lines changed

1 file changed

+94
-4
lines changed

tests/test_hitl_error_scenarios.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
from agents._run_impl import (
2727
NextStepInterruption,
2828
)
29+
from agents.items import MessageOutputItem, ModelResponse
2930
from agents.run_context import RunContextWrapper
3031
from agents.run_state import RunState as RunStateClass
32+
from agents.usage import Usage
3133

3234
from .fake_model import FakeModel
3335
from .test_responses import get_text_message
@@ -174,18 +176,18 @@ async def needs_approval(_ctx, _operation, _call_id) -> bool:
174176
async def test_resuming_pending_mcp_approvals_raises_typeerror():
175177
"""Test that ToolApprovalItem can be added to a set (should be hashable).
176178
177-
At line 783 in _run_impl.py, resolve_interrupted_turn tries:
179+
In resolve_interrupted_turn, the code tries:
178180
pending_hosted_mcp_approvals.add(approval_item)
179181
where approval_item is a ToolApprovalItem. This currently raises TypeError because
180182
ToolApprovalItem is not hashable.
181183
182-
BUG: ToolApprovalItem lacks __hash__, so line 783 will raise TypeError.
184+
BUG: ToolApprovalItem lacks __hash__, so adding it to a set will raise TypeError.
183185
This test will FAIL with TypeError when the bug exists, and PASS when fixed.
184186
"""
185187
model = FakeModel()
186188
agent = Agent(name="TestAgent", model=model, tools=[])
187189

188-
# Create a ToolApprovalItem - this is what line 783 tries to add to a set
190+
# Create a ToolApprovalItem - this is what the code tries to add to a set
189191
mcp_raw_item = {
190192
"type": "hosted_tool_call",
191193
"id": "mcp-approval-1",
@@ -196,7 +198,7 @@ async def test_resuming_pending_mcp_approvals_raises_typeerror():
196198
)
197199

198200
# BUG: This will raise TypeError because ToolApprovalItem is not hashable
199-
# This is exactly what happens at line 783: pending_hosted_mcp_approvals.add(approval_item)
201+
# This is exactly what happens: pending_hosted_mcp_approvals.add(approval_item)
200202
pending_hosted_mcp_approvals: set[ToolApprovalItem] = set()
201203
pending_hosted_mcp_approvals.add(
202204
mcp_approval_item
@@ -775,3 +777,91 @@ async def test_deserializing_interruptions_assumes_function_tool_calls_mcp():
775777
f"from_json assumes all interruptions are function tool calls, "
776778
f"but this is an MCP/hosted tool approval. Error: {e}"
777779
)
780+
781+
782+
@pytest.mark.asyncio
async def test_preserve_persisted_item_counter_when_resuming_streamed_runs():
    """Check that resuming a streamed run keeps the state's persisted-item counter.

    A RunResultStreaming built from a RunState is expected to carry over
    _current_turn_persisted_item_count exactly as stored on the state, rather than
    recomputing it as len(run_state._generated_items). For Python-to-Python resumes
    the stored counter is the record of how many items were actually persisted
    before the interruption.

    BUG: whenever run_state._generated_items is truthy, the counter is overwritten
    with len(run_state._generated_items), discarding the value saved in the state.
    A turn interrupted mid-persistence (e.g. 5 items generated, only 3 persisted)
    then loses history in the session.
    """
    from openai.types.responses import ResponseOutputMessage, ResponseOutputText

    generated_total = 5
    persisted_before_interrupt = 3

    fake_model = FakeModel()
    test_agent = Agent(name="TestAgent", model=fake_model)

    # Build a state that mimics an interruption in the middle of persistence:
    # more items were generated than were written to the session, and the state
    # remembers the smaller, persisted number.
    run_context: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={})
    state = RunState(
        context=run_context,
        original_input="test input",
        starting_agent=test_agent,
        max_turns=10,
    )

    # Several assistant messages stand in for the outputs produced before the
    # interruption.
    generated_messages = [
        MessageOutputItem(
            agent=test_agent,
            raw_item=ResponseOutputMessage(
                id=f"msg_{i}",
                type="message",
                role="assistant",
                status="completed",
                content=[
                    ResponseOutputText(
                        type="output_text", text=f"Message {i}", annotations=[], logprobs=[]
                    )
                ],
            ),
        )
        for i in range(generated_total)
    ]
    state._generated_items.extend(generated_messages)

    # Only part of the generated items made it into the session.
    state._current_turn_persisted_item_count = persisted_before_interrupt

    # A prior model response is attached so the state is valid for resumption.
    prior_response = ModelResponse(
        output=[get_text_message("test")],
        usage=Usage(),
        response_id="resp_1",
    )
    state._model_responses = [prior_response]

    # The fake model answers with a final output right away so the run completes.
    fake_model.set_next_output([get_text_message("done")])

    # Resume through the streaming entry point.
    # BUG: constructing RunResultStreaming sets the counter to
    # len(_generated_items)=5 instead of keeping the persisted value of 3.
    result = Runner.run_streamed(test_agent, state)

    # Fails while the bug exists (counter is 5); passes once fixed (counter is 3).
    assert result._current_turn_persisted_item_count == persisted_before_interrupt, (
        f"Expected _current_turn_persisted_item_count=3 (the actual persisted count), "
        f"but got {result._current_turn_persisted_item_count}. "
        f"The bug incorrectly resets the counter to "
        f"len(run_state._generated_items)={len(state._generated_items)} instead of "
        f"preserving the actual persisted count from the state. This causes missing "
        f"history in sessions when resuming after mid-persistence interruptions."
    )

    # Drain the event stream so the run finishes.
    async for _event in result.stream_events():
        pass

0 commit comments

Comments
 (0)