diff --git a/.changeset/fix-continuation-chunk-emission.md b/.changeset/fix-continuation-chunk-emission.md new file mode 100644 index 00000000..a0dbd033 --- /dev/null +++ b/.changeset/fix-continuation-chunk-emission.md @@ -0,0 +1,5 @@ +--- +'@tanstack/ai': patch +--- + +Emit TOOL_CALL_START and TOOL_CALL_ARGS for pending tool calls during continuation re-executions diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index 74c5d5ce..790387e6 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -691,6 +691,13 @@ class TextEngine< needsClientExecution: executionResult.needsClientExecution, }) + // Build args lookup so buildToolResultChunks can emit TOOL_CALL_START + + // TOOL_CALL_ARGS before TOOL_CALL_END during continuation re-executions. + const argsMap = new Map() + for (const tc of pendingToolCalls) { + argsMap.set(tc.id, tc.function.arguments) + } + if ( executionResult.needsApproval.length > 0 || executionResult.needsClientExecution.length > 0 @@ -699,6 +706,7 @@ class TextEngine< for (const chunk of this.buildToolResultChunks( executionResult.results, finishEvent, + argsMap, )) { yield chunk } @@ -725,6 +733,7 @@ class TextEngine< const toolResultChunks = this.buildToolResultChunks( executionResult.results, finishEvent, + argsMap, ) for (const chunk of toolResultChunks) { @@ -997,12 +1006,34 @@ class TextEngine< private buildToolResultChunks( results: Array, finishEvent: RunFinishedEvent, + argsMap?: Map, ): Array { const chunks: Array = [] for (const result of results) { const content = JSON.stringify(result.result) + // Emit TOOL_CALL_START + TOOL_CALL_ARGS before TOOL_CALL_END so that + // the client can reconstruct the full tool call during continuations. + if (argsMap) { + chunks.push({ + type: 'TOOL_CALL_START', + timestamp: Date.now(), + model: finishEvent.model, + toolCallId: result.toolCallId, + toolName: result.toolName, + }) + + const args = argsMap.get(result.toolCallId) ?? '{}' + chunks.push({ + type: 'TOOL_CALL_ARGS', + timestamp: Date.now(), + model: finishEvent.model, + toolCallId: result.toolCallId, + delta: args, + }) + } + chunks.push({ type: 'TOOL_CALL_END', timestamp: Date.now(), diff --git a/packages/typescript/ai/tests/chat.test.ts b/packages/typescript/ai/tests/chat.test.ts index 1b9b8eec..84b06d6e 100644 --- a/packages/typescript/ai/tests/chat.test.ts +++ b/packages/typescript/ai/tests/chat.test.ts @@ -646,6 +646,231 @@ describe('chat()', () => { expect(executeSpy).not.toHaveBeenCalled() expect(calls).toHaveLength(1) }) + + it('should emit TOOL_CALL_START and TOOL_CALL_ARGS before TOOL_CALL_END for pending tool calls', async () => { + const executeSpy = vi.fn().mockReturnValue({ temp: 72 }) + + const { adapter } = createMockAdapter({ + iterations: [ + // After pending tool is executed, the engine calls the adapter for the next response + [ + ev.runStarted(), + ev.textStart(), + ev.textContent('72F in NYC'), + ev.textEnd(), + ev.runFinished('stop'), + ], + ], + }) + + const stream = chat({ + adapter, + messages: [ + { role: 'user', content: 'Weather?' }, + { + role: 'assistant', + content: 'Let me check.', + toolCalls: [ + { + id: 'call_1', + type: 'function' as const, + function: { name: 'getWeather', arguments: '{"city":"NYC"}' }, + }, + ], + }, + // No tool result message -> pending! + ], + tools: [serverTool('getWeather', executeSpy)], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + // Tool should have been executed + expect(executeSpy).toHaveBeenCalledTimes(1) + + // The continuation re-execution should emit the full chunk sequence: + // TOOL_CALL_START -> TOOL_CALL_ARGS -> TOOL_CALL_END + // Without the fix, only TOOL_CALL_END is emitted, causing the client + // to store the tool call with empty arguments {}. + const toolStartChunks = chunks.filter( + (c) => + c.type === 'TOOL_CALL_START' && (c as any).toolCallId === 'call_1', + ) + expect(toolStartChunks).toHaveLength(1) + expect((toolStartChunks[0] as any).toolName).toBe('getWeather') + + const toolArgsChunks = chunks.filter( + (c) => + c.type === 'TOOL_CALL_ARGS' && (c as any).toolCallId === 'call_1', + ) + expect(toolArgsChunks).toHaveLength(1) + expect((toolArgsChunks[0] as any).delta).toBe('{"city":"NYC"}') + + const toolEndChunks = chunks.filter( + (c) => c.type === 'TOOL_CALL_END' && (c as any).toolCallId === 'call_1', + ) + expect(toolEndChunks).toHaveLength(1) + + // Verify ordering: START before ARGS before END + const startIdx = chunks.indexOf(toolStartChunks[0]!) + const argsIdx = chunks.indexOf(toolArgsChunks[0]!) + const endIdx = chunks.indexOf(toolEndChunks[0]!) + expect(startIdx).toBeLessThan(argsIdx) + expect(argsIdx).toBeLessThan(endIdx) + }) + + it('should emit TOOL_CALL_START and TOOL_CALL_ARGS for each pending tool call in a batch', async () => { + const weatherSpy = vi.fn().mockReturnValue({ temp: 72 }) + const timeSpy = vi.fn().mockReturnValue({ time: '3pm' }) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textStart(), + ev.textContent('Done.'), + ev.textEnd(), + ev.runFinished('stop'), + ], + ], + }) + + const stream = chat({ + adapter, + messages: [ + { role: 'user', content: 'Weather and time?' }, + { + role: 'assistant', + content: '', + toolCalls: [ + { + id: 'call_weather', + type: 'function' as const, + function: { name: 'getWeather', arguments: '{"city":"NYC"}' }, + }, + { + id: 'call_time', + type: 'function' as const, + function: { name: 'getTime', arguments: '{"tz":"EST"}' }, + }, + ], + }, + // No tool results -> both pending + ], + tools: [ + serverTool('getWeather', weatherSpy), + serverTool('getTime', timeSpy), + ], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + // Both tools should have been executed + expect(weatherSpy).toHaveBeenCalledTimes(1) + expect(timeSpy).toHaveBeenCalledTimes(1) + + // Each pending tool should get the full START -> ARGS -> END sequence + for (const { id, name, args } of [ + { id: 'call_weather', name: 'getWeather', args: '{"city":"NYC"}' }, + { id: 'call_time', name: 'getTime', args: '{"tz":"EST"}' }, + ]) { + const starts = chunks.filter( + (c) => c.type === 'TOOL_CALL_START' && (c as any).toolCallId === id, + ) + expect(starts).toHaveLength(1) + expect((starts[0] as any).toolName).toBe(name) + + const argChunks = chunks.filter( + (c) => c.type === 'TOOL_CALL_ARGS' && (c as any).toolCallId === id, + ) + expect(argChunks).toHaveLength(1) + expect((argChunks[0] as any).delta).toBe(args) + + const ends = chunks.filter( + (c) => c.type === 'TOOL_CALL_END' && (c as any).toolCallId === id, + ) + expect(ends).toHaveLength(1) + + // Verify ordering + const startIdx = chunks.indexOf(starts[0]!) + const argsIdx = chunks.indexOf(argChunks[0]!) + const endIdx = chunks.indexOf(ends[0]!) + expect(startIdx).toBeLessThan(argsIdx) + expect(argsIdx).toBeLessThan(endIdx) + } + }) + + it('should emit TOOL_CALL_START and TOOL_CALL_ARGS for the server tool in a mixed pending batch', async () => { + const weatherSpy = vi.fn().mockReturnValue({ temp: 72 }) + + const { adapter } = createMockAdapter({ iterations: [] }) + + const stream = chat({ + adapter, + messages: [ + { role: 'user', content: 'Weather and notify?' }, + { + role: 'assistant', + content: '', + toolCalls: [ + { + id: 'call_server', + type: 'function' as const, + function: { name: 'getWeather', arguments: '{"city":"NYC"}' }, + }, + { + id: 'call_client', + type: 'function' as const, + function: { + name: 'showNotification', + arguments: '{"message":"done"}', + }, + }, + ], + }, + // No tool results -> both pending + ], + tools: [ + serverTool('getWeather', weatherSpy), + clientTool('showNotification'), + ], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + // Server tool should have executed + expect(weatherSpy).toHaveBeenCalledTimes(1) + + // The executed server tool should get the full START -> ARGS -> END + const starts = chunks.filter( + (c) => + c.type === 'TOOL_CALL_START' && + (c as any).toolCallId === 'call_server', + ) + expect(starts).toHaveLength(1) + expect((starts[0] as any).toolName).toBe('getWeather') + + const argChunks = chunks.filter( + (c) => + c.type === 'TOOL_CALL_ARGS' && + (c as any).toolCallId === 'call_server', + ) + expect(argChunks).toHaveLength(1) + expect((argChunks[0] as any).delta).toBe('{"city":"NYC"}') + + const ends = chunks.filter( + (c) => + c.type === 'TOOL_CALL_END' && (c as any).toolCallId === 'call_server', + ) + expect(ends).toHaveLength(1) + + // Verify ordering + const startIdx = chunks.indexOf(starts[0]!) + const argsIdx = chunks.indexOf(argChunks[0]!) + const endIdx = chunks.indexOf(ends[0]!) + expect(startIdx).toBeLessThan(argsIdx) + expect(argsIdx).toBeLessThan(endIdx) + }) }) // ==========================================================================