Skip to content

Commit 21769ae

Browse files
committed
Preserve in-progress message history on agent run error
Previously, when the server returned an error (e.g. 503 timeout) mid-run, any assistant messages, tool calls, and tool results produced since the user's last prompt were lost. loopAgentSteps created a shallow copy of the initial agent state, so mutations didn't propagate back to the SDK's shared sessionState reference if an error threw up. Now loopAgentSteps mutates the caller's state directly, and the SDK's cancellation path detects runtime progress so the user prompt isn't duplicated.
1 parent 21d5dd3 commit 21769ae

4 files changed

Lines changed: 338 additions & 13 deletions

File tree

packages/agent-runtime/src/__tests__/main-prompt.test.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ describe('mainPrompt', () => {
375375
it('should update consecutiveAssistantMessages when new prompt is received', async () => {
376376
const sessionState = getInitialSessionState(mockFileContext)
377377
sessionState.mainAgentState.stepsRemaining = 12
378+
const initialStepsRemaining = sessionState.mainAgentState.stepsRemaining
378379

379380
const action = {
380381
type: 'prompt' as const,
@@ -394,7 +395,7 @@ describe('mainPrompt', () => {
394395

395396
// When there's a new prompt, consecutiveAssistantMessages should be set to 1
396397
expect(newSessionState.mainAgentState.stepsRemaining).toBe(
397-
sessionState.mainAgentState.stepsRemaining - 1,
398+
initialStepsRemaining - 1,
398399
)
399400
})
400401

packages/agent-runtime/src/run-agent-step.ts

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -800,12 +800,13 @@ export async function loopAgentSteps(
800800
return cachedAdditionalToolDefinitions
801801
}
802802

803-
let currentAgentState: AgentState = {
804-
...initialAgentState,
805-
messageHistory: initialMessages,
806-
systemPrompt: system,
807-
toolDefinitions,
808-
}
803+
// Mutate initialAgentState so that in-progress work propagates back to the
804+
// caller's shared reference (e.g. SDK's sessionState.mainAgentState) even if
805+
// an error is thrown before we return.
806+
initialAgentState.messageHistory = initialMessages
807+
initialAgentState.systemPrompt = system
808+
initialAgentState.toolDefinitions = toolDefinitions
809+
let currentAgentState: AgentState = initialAgentState
809810

810811
// Convert tool definitions to Anthropic format for accurate token counting
811812
// Tool definitions are stored as { [name]: { description, inputSchema } }
@@ -908,7 +909,8 @@ export async function loopAgentSteps(
908909
} = programmaticResult
909910
n = generateN
910911

911-
currentAgentState = programmaticAgentState
912+
Object.assign(initialAgentState, programmaticAgentState)
913+
currentAgentState = initialAgentState
912914
totalSteps = stepNumber
913915

914916
shouldEndTurn = endTurn
@@ -989,7 +991,8 @@ export async function loopAgentSteps(
989991
logger.error('No runId found for agent state after finishing agent run')
990992
}
991993

992-
currentAgentState = newAgentState
994+
Object.assign(initialAgentState, newAgentState)
995+
currentAgentState = initialAgentState
993996
shouldEndTurn = llmShouldEndTurn
994997
nResponses = generatedResponses
995998

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
import * as mainPromptModule from '@codebuff/agent-runtime/main-prompt'
2+
import { getInitialSessionState } from '@codebuff/common/types/session-state'
3+
import { getStubProjectFileContext } from '@codebuff/common/util/file'
4+
import { assistantMessage, userMessage } from '@codebuff/common/util/messages'
5+
import { afterEach, describe, expect, it, mock, spyOn } from 'bun:test'
6+
7+
import { CodebuffClient } from '../client'
8+
import * as databaseModule from '../impl/database'
9+
10+
interface ToolCallContentBlock {
11+
type: 'tool-call'
12+
toolCallId: string
13+
toolName: string
14+
input: Record<string, unknown>
15+
}
16+
17+
const setupDatabaseMocks = () => {
18+
spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({
19+
id: 'user-123',
20+
email: 'test@example.com',
21+
discord_id: null,
22+
referral_code: null,
23+
stripe_customer_id: null,
24+
banned: false,
25+
created_at: new Date('2024-01-01T00:00:00Z'),
26+
})
27+
spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null)
28+
spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1')
29+
spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined)
30+
spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1')
31+
}
32+
33+
describe('Error preserves in-progress message history', () => {
34+
afterEach(() => {
35+
mock.restore()
36+
})
37+
38+
it('preserves in-progress assistant work on error (simulated via shared state mutation)', async () => {
39+
setupDatabaseMocks()
40+
41+
// Simulate the agent runtime:
42+
// 1. Mutates the shared session state with the user message and partial work
43+
// 2. Then throws due to a downstream timeout/service error
44+
spyOn(mainPromptModule, 'callMainPrompt').mockImplementation(
45+
async (params: Parameters<typeof mainPromptModule.callMainPrompt>[0]) => {
46+
const history = params.action.sessionState.mainAgentState.messageHistory
47+
48+
// The runtime adds the user message as part of building its initial messages
49+
history.push({
50+
role: 'user',
51+
content: [{ type: 'text', text: 'Fix the bug in auth.ts' }],
52+
tags: ['USER_PROMPT'],
53+
})
54+
55+
// Step 1: assistant responds with a tool call (reading a file)
56+
history.push({
57+
role: 'assistant',
58+
content: [
59+
{ type: 'text', text: 'Let me read the auth file first.' },
60+
{
61+
type: 'tool-call',
62+
toolCallId: 'read-1',
63+
toolName: 'read_files',
64+
input: { paths: ['auth.ts'] },
65+
} as ToolCallContentBlock,
66+
],
67+
})
68+
69+
// Tool result
70+
history.push({
71+
role: 'tool',
72+
toolCallId: 'read-1',
73+
toolName: 'read_files',
74+
content: [
75+
{
76+
type: 'json',
77+
value: [{ path: 'auth.ts', content: 'const auth = ...' }],
78+
},
79+
],
80+
})
81+
82+
// Step 2: assistant continues with another tool call (writing the fix)
83+
history.push({
84+
role: 'assistant',
85+
content: [
86+
{ type: 'text', text: 'Found the issue, writing the fix now.' },
87+
{
88+
type: 'tool-call',
89+
toolCallId: 'write-1',
90+
toolName: 'write_file',
91+
input: { path: 'auth.ts', content: 'const auth = fixed' },
92+
} as ToolCallContentBlock,
93+
],
94+
})
95+
96+
history.push({
97+
role: 'tool',
98+
toolCallId: 'write-1',
99+
toolName: 'write_file',
100+
content: [{ type: 'json', value: { file: 'auth.ts', message: 'File written' } }],
101+
})
102+
103+
// Now simulate a server timeout on the next LLM call
104+
const timeoutError = new Error('Service Unavailable') as Error & {
105+
statusCode: number
106+
responseBody: string
107+
}
108+
timeoutError.statusCode = 503
109+
timeoutError.responseBody = JSON.stringify({
110+
message: 'Request timeout after 30s',
111+
})
112+
throw timeoutError
113+
},
114+
)
115+
116+
const client = new CodebuffClient({ apiKey: 'test-key' })
117+
const result = await client.run({
118+
agent: 'base2',
119+
prompt: 'Fix the bug in auth.ts',
120+
})
121+
122+
// Error output with correct status code
123+
expect(result.output.type).toBe('error')
124+
const errorOutput = result.output as {
125+
type: 'error'
126+
message: string
127+
statusCode?: number
128+
}
129+
expect(errorOutput.statusCode).toBe(503)
130+
131+
const history = result.sessionState!.mainAgentState.messageHistory
132+
133+
// The user's prompt should appear exactly once
134+
const userPromptMessages = history.filter(
135+
(m) =>
136+
m.role === 'user' &&
137+
(m.content as Array<{ type: string; text?: string }>).some(
138+
(c) => c.type === 'text' && c.text?.includes('Fix the bug'),
139+
),
140+
)
141+
expect(userPromptMessages.length).toBe(1)
142+
143+
// Assistant text messages from both steps should be preserved
144+
const firstAssistantText = history.find(
145+
(m) =>
146+
m.role === 'assistant' &&
147+
(m.content as Array<{ type: string; text?: string }>).some(
148+
(c) => c.type === 'text' && c.text?.includes('read the auth file'),
149+
),
150+
)
151+
expect(firstAssistantText).toBeDefined()
152+
153+
const secondAssistantText = history.find(
154+
(m) =>
155+
m.role === 'assistant' &&
156+
(m.content as Array<{ type: string; text?: string }>).some(
157+
(c) => c.type === 'text' && c.text?.includes('writing the fix'),
158+
),
159+
)
160+
expect(secondAssistantText).toBeDefined()
161+
162+
// Both tool calls and both tool results should be preserved
163+
const readToolCall = history.find(
164+
(m) =>
165+
m.role === 'assistant' &&
166+
(m.content as Array<{ type: string; toolCallId?: string }>).some(
167+
(c) => c.type === 'tool-call' && c.toolCallId === 'read-1',
168+
),
169+
)
170+
expect(readToolCall).toBeDefined()
171+
172+
const writeToolCall = history.find(
173+
(m) =>
174+
m.role === 'assistant' &&
175+
(m.content as Array<{ type: string; toolCallId?: string }>).some(
176+
(c) => c.type === 'tool-call' && c.toolCallId === 'write-1',
177+
),
178+
)
179+
expect(writeToolCall).toBeDefined()
180+
181+
const readToolResult = history.find(
182+
(m) => m.role === 'tool' && m.toolCallId === 'read-1',
183+
)
184+
expect(readToolResult).toBeDefined()
185+
186+
const writeToolResult = history.find(
187+
(m) => m.role === 'tool' && m.toolCallId === 'write-1',
188+
)
189+
expect(writeToolResult).toBeDefined()
190+
})
191+
192+
it('a subsequent run after error includes the preserved in-progress history', async () => {
193+
setupDatabaseMocks()
194+
195+
// Run 1: agent does some work then hits an error
196+
spyOn(mainPromptModule, 'callMainPrompt').mockImplementation(
197+
async (params: Parameters<typeof mainPromptModule.callMainPrompt>[0]) => {
198+
const history = params.action.sessionState.mainAgentState.messageHistory
199+
200+
history.push({
201+
role: 'user',
202+
content: [{ type: 'text', text: 'Investigate the login bug' }],
203+
tags: ['USER_PROMPT'],
204+
})
205+
history.push(assistantMessage('I found the problem in auth.ts on line 42.'))
206+
history.push({
207+
role: 'assistant',
208+
content: [
209+
{
210+
type: 'tool-call',
211+
toolCallId: 'read-login',
212+
toolName: 'read_files',
213+
input: { paths: ['login.ts'] },
214+
} as ToolCallContentBlock,
215+
],
216+
})
217+
history.push({
218+
role: 'tool',
219+
toolCallId: 'read-login',
220+
toolName: 'read_files',
221+
content: [{ type: 'json', value: [{ path: 'login.ts', content: 'login code' }] }],
222+
})
223+
224+
const error = new Error('Service Unavailable') as Error & {
225+
statusCode: number
226+
}
227+
error.statusCode = 503
228+
throw error
229+
},
230+
)
231+
232+
const client = new CodebuffClient({ apiKey: 'test-key' })
233+
const firstResult = await client.run({
234+
agent: 'base2',
235+
prompt: 'Investigate the login bug',
236+
})
237+
238+
expect(firstResult.output.type).toBe('error')
239+
240+
// Run 2: use the failed run as previousRun
241+
mock.restore()
242+
setupDatabaseMocks()
243+
244+
let historyReceivedByRuntime: unknown[] | undefined
245+
spyOn(mainPromptModule, 'callMainPrompt').mockImplementation(
246+
async (params: Parameters<typeof mainPromptModule.callMainPrompt>[0]) => {
247+
const { sendAction, promptId } = params
248+
historyReceivedByRuntime = [
249+
...params.action.sessionState.mainAgentState.messageHistory,
250+
]
251+
252+
const responseSessionState = getInitialSessionState(
253+
getStubProjectFileContext(),
254+
)
255+
responseSessionState.mainAgentState.messageHistory = [
256+
...params.action.sessionState.mainAgentState.messageHistory,
257+
userMessage('Now try again'),
258+
assistantMessage('Continuing with the fix.'),
259+
]
260+
261+
await sendAction({
262+
action: {
263+
type: 'prompt-response',
264+
promptId,
265+
sessionState: responseSessionState,
266+
output: { type: 'lastMessage', value: [] },
267+
},
268+
})
269+
270+
return {
271+
sessionState: responseSessionState,
272+
output: { type: 'lastMessage' as const, value: [] },
273+
}
274+
},
275+
)
276+
277+
const secondResult = await client.run({
278+
agent: 'base2',
279+
prompt: 'Now try again',
280+
previousRun: firstResult,
281+
})
282+
283+
// The runtime should have received history containing the work from the first run
284+
expect(historyReceivedByRuntime).toBeDefined()
285+
const receivedReadCall = historyReceivedByRuntime!.find(
286+
(m) =>
287+
(m as { role: string }).role === 'assistant' &&
288+
((m as { content: Array<{ type: string; toolCallId?: string }> })
289+
.content ?? []).some(
290+
(c) => c.type === 'tool-call' && c.toolCallId === 'read-login',
291+
),
292+
)
293+
expect(receivedReadCall).toBeDefined()
294+
295+
const receivedToolResult = historyReceivedByRuntime!.find(
296+
(m) =>
297+
(m as { role: string }).role === 'tool' &&
298+
(m as { toolCallId: string }).toolCallId === 'read-login',
299+
)
300+
expect(receivedToolResult).toBeDefined()
301+
302+
// Final result should preserve history
303+
const finalHistory = secondResult.sessionState!.mainAgentState.messageHistory
304+
const finalReadCall = finalHistory.find(
305+
(m) =>
306+
m.role === 'assistant' &&
307+
(m.content as Array<{ type: string; toolCallId?: string }>).some(
308+
(c) => c.type === 'tool-call' && c.toolCallId === 'read-login',
309+
),
310+
)
311+
expect(finalReadCall).toBeDefined()
312+
})
313+
})

sdk/src/run.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,16 +282,24 @@ async function runOnce({
282282
}
283283
}
284284

285+
// The agent runtime mutates sessionState.mainAgentState as it progresses, so any
286+
// messages added beyond this baseline reflect in-progress work that should be preserved.
287+
const initialHistoryLength = sessionState.mainAgentState.messageHistory.length
288+
285289
/** Calculates the current session state if cancelled.
286290
*
287-
* This is used when callMainPrompt throws an error (the server never processed the request).
288-
* We need to add the user's message here since the server didn't get a chance to add it.
291+
* This is used when callMainPrompt throws an error. If the agent runtime made
292+
* any progress (added messages to the shared session state), those messages are
293+
* preserved. Otherwise the user's message is added so it isn't lost.
289294
*/
290295
function getCancelledSessionState(message: string): SessionState {
291296
const state = cloneDeep(sessionState)
292297

293-
// Add the user's message since the server never processed it
294-
if (prompt || preparedContent) {
298+
const runtimeMadeProgress =
299+
state.mainAgentState.messageHistory.length > initialHistoryLength
300+
301+
// Only add the user's message if the runtime didn't get a chance to add it.
302+
if (!runtimeMadeProgress && (prompt || preparedContent)) {
295303
state.mainAgentState.messageHistory.push({
296304
role: 'user' as const,
297305
content: buildUserMessageContent(prompt, params, preparedContent),

0 commit comments

Comments
 (0)