Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-continuation-chunk-emission.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@tanstack/ai': patch
---

Emit TOOL_CALL_START and TOOL_CALL_ARGS for pending tool calls during continuation re-executions
31 changes: 31 additions & 0 deletions packages/typescript/ai/src/activities/chat/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,13 @@ class TextEngine<
needsClientExecution: executionResult.needsClientExecution,
})

// Build args lookup so buildToolResultChunks can emit TOOL_CALL_START +
// TOOL_CALL_ARGS before TOOL_CALL_END during continuation re-executions.
const argsMap = new Map<string, string>()
for (const tc of pendingToolCalls) {
argsMap.set(tc.id, tc.function.arguments)
}

if (
executionResult.needsApproval.length > 0 ||
executionResult.needsClientExecution.length > 0
Expand All @@ -699,6 +706,7 @@ class TextEngine<
for (const chunk of this.buildToolResultChunks(
executionResult.results,
finishEvent,
argsMap,
)) {
yield chunk
}
Expand All @@ -725,6 +733,7 @@ class TextEngine<
const toolResultChunks = this.buildToolResultChunks(
executionResult.results,
finishEvent,
argsMap,
)

for (const chunk of toolResultChunks) {
Expand Down Expand Up @@ -997,12 +1006,34 @@ class TextEngine<
private buildToolResultChunks(
results: Array<ToolResult>,
finishEvent: RunFinishedEvent,
argsMap?: Map<string, string>,
): Array<StreamChunk> {
const chunks: Array<StreamChunk> = []

for (const result of results) {
const content = JSON.stringify(result.result)

// Emit TOOL_CALL_START + TOOL_CALL_ARGS before TOOL_CALL_END so that
// the client can reconstruct the full tool call during continuations.
if (argsMap) {
chunks.push({
type: 'TOOL_CALL_START',
timestamp: Date.now(),
model: finishEvent.model,
toolCallId: result.toolCallId,
toolName: result.toolName,
})

const args = argsMap.get(result.toolCallId) ?? '{}'
chunks.push({
type: 'TOOL_CALL_ARGS',
timestamp: Date.now(),
model: finishEvent.model,
toolCallId: result.toolCallId,
delta: args,
})
}

chunks.push({
type: 'TOOL_CALL_END',
timestamp: Date.now(),
Expand Down
225 changes: 225 additions & 0 deletions packages/typescript/ai/tests/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,231 @@ describe('chat()', () => {
expect(executeSpy).not.toHaveBeenCalled()
expect(calls).toHaveLength(1)
})

it('should emit TOOL_CALL_START and TOOL_CALL_ARGS before TOOL_CALL_END for pending tool calls', async () => {
  // Server-side tool whose execution we can observe via the spy.
  const weatherExecute = vi.fn().mockReturnValue({ temp: 72 })

  const { adapter } = createMockAdapter({
    iterations: [
      // After the pending tool runs, the engine asks the adapter for the
      // follow-up model response.
      [
        ev.runStarted(),
        ev.textStart(),
        ev.textContent('72F in NYC'),
        ev.textEnd(),
        ev.runFinished('stop'),
      ],
    ],
  })

  const stream = chat({
    adapter,
    messages: [
      { role: 'user', content: 'Weather?' },
      {
        role: 'assistant',
        content: 'Let me check.',
        toolCalls: [
          {
            id: 'call_1',
            type: 'function' as const,
            function: { name: 'getWeather', arguments: '{"city":"NYC"}' },
          },
        ],
      },
      // No tool result message -> pending!
    ],
    tools: [serverTool('getWeather', weatherExecute)],
  })

  const emitted = await collectChunks(stream as AsyncIterable<StreamChunk>)

  // The pending tool must have been executed exactly once.
  expect(weatherExecute).toHaveBeenCalledTimes(1)

  // Helper: every emitted chunk of the given type that belongs to call_1.
  const forCall = (type: string) =>
    emitted.filter(
      (c) => c.type === type && (c as any).toolCallId === 'call_1',
    )

  // The continuation re-execution should emit the full chunk sequence:
  // TOOL_CALL_START -> TOOL_CALL_ARGS -> TOOL_CALL_END
  // Without the fix, only TOOL_CALL_END is emitted, causing the client
  // to store the tool call with empty arguments {}.
  const startChunks = forCall('TOOL_CALL_START')
  expect(startChunks).toHaveLength(1)
  expect((startChunks[0] as any).toolName).toBe('getWeather')

  const argsChunks = forCall('TOOL_CALL_ARGS')
  expect(argsChunks).toHaveLength(1)
  expect((argsChunks[0] as any).delta).toBe('{"city":"NYC"}')

  const endChunks = forCall('TOOL_CALL_END')
  expect(endChunks).toHaveLength(1)

  // Ordering must be START before ARGS before END.
  const startIdx = emitted.indexOf(startChunks[0]!)
  const argsIdx = emitted.indexOf(argsChunks[0]!)
  const endIdx = emitted.indexOf(endChunks[0]!)
  expect(startIdx).toBeLessThan(argsIdx)
  expect(argsIdx).toBeLessThan(endIdx)
})

it('should emit TOOL_CALL_START and TOOL_CALL_ARGS for each pending tool call in a batch', async () => {
  // Two server-side tools, both left pending by the message history below.
  const weatherExec = vi.fn().mockReturnValue({ temp: 72 })
  const timeExec = vi.fn().mockReturnValue({ time: '3pm' })

  const { adapter } = createMockAdapter({
    iterations: [
      [
        ev.runStarted(),
        ev.textStart(),
        ev.textContent('Done.'),
        ev.textEnd(),
        ev.runFinished('stop'),
      ],
    ],
  })

  const stream = chat({
    adapter,
    messages: [
      { role: 'user', content: 'Weather and time?' },
      {
        role: 'assistant',
        content: '',
        toolCalls: [
          {
            id: 'call_weather',
            type: 'function' as const,
            function: { name: 'getWeather', arguments: '{"city":"NYC"}' },
          },
          {
            id: 'call_time',
            type: 'function' as const,
            function: { name: 'getTime', arguments: '{"tz":"EST"}' },
          },
        ],
      },
      // No tool results -> both pending
    ],
    tools: [
      serverTool('getWeather', weatherExec),
      serverTool('getTime', timeExec),
    ],
  })

  const emitted = await collectChunks(stream as AsyncIterable<StreamChunk>)

  // Both pending tools must have been executed.
  expect(weatherExec).toHaveBeenCalledTimes(1)
  expect(timeExec).toHaveBeenCalledTimes(1)

  // Assert the full START -> ARGS -> END sequence for one tool call id.
  const expectFullSequence = (id: string, name: string, args: string) => {
    const startChunks = emitted.filter(
      (c) => c.type === 'TOOL_CALL_START' && (c as any).toolCallId === id,
    )
    expect(startChunks).toHaveLength(1)
    expect((startChunks[0] as any).toolName).toBe(name)

    const argsChunks = emitted.filter(
      (c) => c.type === 'TOOL_CALL_ARGS' && (c as any).toolCallId === id,
    )
    expect(argsChunks).toHaveLength(1)
    expect((argsChunks[0] as any).delta).toBe(args)

    const endChunks = emitted.filter(
      (c) => c.type === 'TOOL_CALL_END' && (c as any).toolCallId === id,
    )
    expect(endChunks).toHaveLength(1)

    // Ordering: START before ARGS before END.
    const startIdx = emitted.indexOf(startChunks[0]!)
    const argsIdx = emitted.indexOf(argsChunks[0]!)
    const endIdx = emitted.indexOf(endChunks[0]!)
    expect(startIdx).toBeLessThan(argsIdx)
    expect(argsIdx).toBeLessThan(endIdx)
  }

  // Each pending tool should get the full START -> ARGS -> END sequence.
  expectFullSequence('call_weather', 'getWeather', '{"city":"NYC"}')
  expectFullSequence('call_time', 'getTime', '{"tz":"EST"}')
})

it('should emit TOOL_CALL_START and TOOL_CALL_ARGS for the server tool in a mixed pending batch', async () => {
  // Only the server-side tool runs; the client tool needs client execution,
  // so the adapter is never asked for a follow-up response.
  const weatherExec = vi.fn().mockReturnValue({ temp: 72 })

  const { adapter } = createMockAdapter({ iterations: [] })

  const stream = chat({
    adapter,
    messages: [
      { role: 'user', content: 'Weather and notify?' },
      {
        role: 'assistant',
        content: '',
        toolCalls: [
          {
            id: 'call_server',
            type: 'function' as const,
            function: { name: 'getWeather', arguments: '{"city":"NYC"}' },
          },
          {
            id: 'call_client',
            type: 'function' as const,
            function: {
              name: 'showNotification',
              arguments: '{"message":"done"}',
            },
          },
        ],
      },
      // No tool results -> both pending
    ],
    tools: [
      serverTool('getWeather', weatherExec),
      clientTool('showNotification'),
    ],
  })

  const emitted = await collectChunks(stream as AsyncIterable<StreamChunk>)

  // The server tool must have executed.
  expect(weatherExec).toHaveBeenCalledTimes(1)

  // Helper: every emitted chunk of the given type belonging to call_server.
  const forServerCall = (type: string) =>
    emitted.filter(
      (c) => c.type === type && (c as any).toolCallId === 'call_server',
    )

  // The executed server tool should get the full START -> ARGS -> END.
  const startChunks = forServerCall('TOOL_CALL_START')
  expect(startChunks).toHaveLength(1)
  expect((startChunks[0] as any).toolName).toBe('getWeather')

  const argsChunks = forServerCall('TOOL_CALL_ARGS')
  expect(argsChunks).toHaveLength(1)
  expect((argsChunks[0] as any).delta).toBe('{"city":"NYC"}')

  const endChunks = forServerCall('TOOL_CALL_END')
  expect(endChunks).toHaveLength(1)

  // Ordering: START before ARGS before END.
  const startIdx = emitted.indexOf(startChunks[0]!)
  const argsIdx = emitted.indexOf(argsChunks[0]!)
  const endIdx = emitted.indexOf(endChunks[0]!)
  expect(startIdx).toBeLessThan(argsIdx)
  expect(argsIdx).toBeLessThan(endIdx)
})
})

// ==========================================================================
Expand Down
Loading