From 2e6bc9585edbdf37b1698ddb3fb1485a6d961eb4 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 21 Apr 2026 19:42:39 +0200 Subject: [PATCH 01/10] feat(appkit): shared agent types and Databricks adapter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundation layer for the agents feature. Adds the portable type surface that every downstream layer builds on, plus the Databricks Model Serving adapter so the agents plugin (later PR) can target workspace-hosted models. `packages/shared/src/agent.ts` — no behavior, just the type vocabulary: `AgentAdapter`, `AgentEvent`, `AgentInput`, `AgentRunContext`, `AgentToolDefinition`, `Message`, `Thread`, `ThreadStore`, `ToolAnnotations`, `ToolCall`, `ToolProvider`, `ResponseStreamEvent`. Exported from the shared barrel. `packages/appkit/src/agents/databricks.ts` — `DatabricksAdapter`: streams OpenAI-compatible completions against a Databricks Model Serving endpoint (raw fetch + SSE, no vendor SDKs). Also ships `createDatabricksModel`, a Vercel-AI-SDK helper that returns a model object you can pass to `streamText`/`useChat`/etc. — handles URL rewriting (`/chat/completions` -> `/invocations`), per-request auth refresh, and tool-name sanitization. `@ai-sdk/openai` is a devDependency consumed by `createDatabricksModel` via dynamic `import()`; consumers who use that helper install it alongside `@databricks/appkit`. Signed-off-by: MarioCadenas --- knip.json | 6 +- packages/appkit/package.json | 5 + packages/appkit/src/agents/databricks.ts | 609 ++++++++++++++++++ .../src/agents/tests/databricks.test.ts | 486 ++++++++++++++ packages/appkit/tsdown.config.ts | 2 +- packages/shared/src/agent.ts | 212 ++++++ packages/shared/src/index.ts | 1 + pnpm-lock.yaml | 3 +- 8 files changed, 1320 insertions(+), 4 deletions(-) create mode 100644 packages/appkit/src/agents/databricks.ts create mode 100644 packages/appkit/src/agents/tests/databricks.test.ts create mode 100644 packages/shared/src/agent.ts diff --git a/knip.json b/knip.json index b777d8c2a..036404ee4 100644 --- a/knip.json +++ b/knip.json @@ -7,7 +7,6 @@ "docs" ], "workspaces": { - "packages/appkit": {}, "packages/appkit-ui": { "ignoreDependencies": ["tailwindcss", "tw-animate-css"] } @@ -17,6 +16,11 @@ "**/*.example.tsx", "**/*.css", "packages/appkit/src/plugins/vector-search/**", + "packages/appkit/src/plugin/index.ts", + "packages/appkit/src/plugins/agents/index.ts", + "packages/appkit/src/plugins/agents/tools/index.ts", + "packages/appkit/src/plugins/agents/from-plugin.ts", + "packages/appkit/src/plugins/agents/load-agents.ts", "template/**", "tools/**", "docs/**" diff --git a/packages/appkit/package.json b/packages/appkit/package.json index bddc99a8b..f2f8e366e 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -33,6 +33,10 @@ "development": "./src/beta.ts", "default": "./dist/beta.js" }, + "./agents/databricks": { + "development": "./src/agents/databricks.ts", + "default": "./dist/agents/databricks.js" + }, "./type-generator": { "types": "./dist/type-generator/index.d.ts", "development": "./src/type-generator/index.ts", @@ -100,6 +104,7 @@ "exports": { ".": "./dist/index.js", "./beta": "./dist/beta.js", + "./agents/databricks": "./dist/agents/databricks.js", "./dist/shared/src/plugin": "./dist/shared/src/plugin.d.ts", "./type-generator": "./dist/type-generator/index.js", "./package.json": "./package.json" diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts new file mode 100644 index 000000000..b74053eef --- /dev/null +++ b/packages/appkit/src/agents/databricks.ts @@ -0,0 +1,609 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; +import { stream as servingStream } from "../connectors/serving/client"; + +/** + * Transport shim: given an OpenAI-compatible request body, returns the raw + * SSE byte stream from the serving endpoint. Injected at construction time so + * callers can swap in the workspace SDK (factory paths), a bare `fetch` + * (the raw constructor), or a test fake. + */ +type StreamBody = ( + body: Record, + signal?: AbortSignal, +) => Promise>; + +/** + * Escape-hatch options: provide an `endpointUrl` + `authenticate()` and the + * adapter uses a bare `fetch()` to call it. Useful for tests and for pointing + * the adapter at non-workspace endpoints (reverse proxies, mocks). + */ +interface RawFetchAdapterOptions { + endpointUrl: string; + authenticate: () => Promise>; + maxSteps?: number; + maxTokens?: number; +} + +/** + * Preferred options: caller provides the transport function directly. + * The `fromServingEndpoint` / `fromModelServing` factories use this to route + * through `connectors/serving/stream`, which centralises URL encoding, auth + * via the SDK's `apiClient.request`, and any future retries/telemetry. + */ +interface StreamBodyAdapterOptions { + streamBody: StreamBody; + maxSteps?: number; + maxTokens?: number; +} + +type DatabricksAdapterOptions = + | RawFetchAdapterOptions + | StreamBodyAdapterOptions; + +function isStreamBodyOptions( + o: DatabricksAdapterOptions, +): o is StreamBodyAdapterOptions { + return "streamBody" in o; +} + +/** + * Duck-typed subset of the Databricks SDK `WorkspaceClient`. Callers of + * `fromServingEndpoint` and `fromModelServing` pass a real `WorkspaceClient`, + * but we only need the `apiClient.request` surface — so we declare the minimal + * interface rather than importing the SDK type directly. This keeps the adapter + * free of a hard compile-time dependency on `@databricks/sdk-experimental`. + */ +interface WorkspaceClientLike { + apiClient: { + request(options: Record): Promise; + }; +} + +interface ServingEndpointOptions { + workspaceClient: WorkspaceClientLike; + endpointName: string; + maxSteps?: number; + maxTokens?: number; +} + +interface ModelServingOptions { + maxSteps?: number; + maxTokens?: number; + workspaceClient?: WorkspaceClientLike; +} + +interface OpenAIMessage { + role: "system" | "user" | "assistant" | "tool"; + content: string | null; + tool_calls?: OpenAIToolCall[]; + tool_call_id?: string; +} + +interface OpenAIToolCall { + id: string; + type: "function"; + function: { name: string; arguments: string }; +} + +interface OpenAITool { + type: "function"; + function: { + name: string; + description: string; + parameters: unknown; + }; +} + +interface DeltaToolCall { + index: number; + id?: string; + type?: string; + function?: { name?: string; arguments?: string }; +} + +/** + * Adapter that talks directly to Databricks Model Serving `/invocations` endpoint. + * + * No dependency on the Vercel AI SDK or LangChain. Uses raw `fetch()` to POST + * OpenAI-compatible payloads and parses the SSE stream itself. Calls + * `authenticate()` per-request so tokens are always fresh. + * + * Handles both structured `tool_calls` responses and text-based tool call + * fallback parsing for models that output tool calls as text. + * + * @example Using the factory (recommended) + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { DatabricksAdapter } from "@databricks/appkit/agents/databricks"; + * import { WorkspaceClient } from "@databricks/sdk-experimental"; + * + * const adapter = DatabricksAdapter.fromServingEndpoint({ + * workspaceClient: new WorkspaceClient({}), + * endpointName: "my-endpoint", + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: adapter, + * }), + * }, + * }), + * ], + * }); + * ``` + * + * @example Using the raw constructor + * ```ts + * const adapter = new DatabricksAdapter({ + * endpointUrl: "https://host/serving-endpoints/my-endpoint/invocations", + * authenticate: async () => ({ Authorization: `Bearer ${token}` }), + * }); + * ``` + */ +export class DatabricksAdapter implements AgentAdapter { + private streamBody: StreamBody; + private maxSteps: number; + private maxTokens: number; + + constructor(options: DatabricksAdapterOptions) { + this.maxSteps = options.maxSteps ?? 10; + this.maxTokens = options.maxTokens ?? 4096; + + if (isStreamBodyOptions(options)) { + this.streamBody = options.streamBody; + } else { + const { endpointUrl, authenticate } = options; + this.streamBody = async (body, signal) => { + const authHeaders = await authenticate(); + const response = await fetch(endpointUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + ...authHeaders, + }, + body: JSON.stringify(body), + signal, + }); + if (!response.ok) { + const errorText = await response.text().catch(() => "Unknown error"); + throw new Error( + `Databricks API error (${response.status}): ${errorText}`, + ); + } + if (!response.body) throw new Error("No response body"); + return response.body; + }; + } + } + + /** + * Creates a DatabricksAdapter for a Databricks Model Serving endpoint. + * + * Routes through the shared `connectors/serving/stream` helper, which + * delegates to the SDK's `apiClient.request({ raw: true })`. That gives the + * adapter centralised URL encoding + authentication with the rest of the + * serving surface — no bespoke `fetch()` + `authenticate()` plumbing. + */ + static async fromServingEndpoint( + options: ServingEndpointOptions, + ): Promise { + const { workspaceClient, endpointName, maxSteps, maxTokens } = options; + return new DatabricksAdapter({ + streamBody: (body) => + // Cast through the structural shape: the connector types + // `workspaceClient` as the SDK's concrete `WorkspaceClient`, but we + // only need `apiClient.request`. + servingStream( + workspaceClient as unknown as Parameters[0], + endpointName, + body, + ), + maxSteps, + maxTokens, + }); + } + + /** + * Creates a DatabricksAdapter from a Model Serving endpoint name. + * Auto-creates a WorkspaceClient internally. Reads the endpoint name + * from the argument or the `DATABRICKS_AGENT_ENDPOINT` env var. + * + * @example + * ```ts + * // Reads endpoint from DATABRICKS_AGENT_ENDPOINT env var + * const adapter = await DatabricksAdapter.fromModelServing(); + * + * // Explicit endpoint + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint"); + * + * // With options + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint", { + * maxSteps: 5, + * maxTokens: 2048, + * }); + * ``` + */ + static async fromModelServing( + endpointName?: string, + options?: ModelServingOptions, + ): Promise { + const resolvedEndpoint = + endpointName ?? process.env.DATABRICKS_AGENT_ENDPOINT; + + if (!resolvedEndpoint) { + throw new Error( + "No endpoint name provided and DATABRICKS_AGENT_ENDPOINT env var is not set. " + + "Pass an endpoint name or set the environment variable.", + ); + } + + let workspaceClient: WorkspaceClientLike | undefined = + options?.workspaceClient; + if (!workspaceClient) { + const sdk = await import("@databricks/sdk-experimental"); + workspaceClient = new sdk.WorkspaceClient( + {}, + ) as unknown as WorkspaceClientLike; + } + + return DatabricksAdapter.fromServingEndpoint({ + workspaceClient, + endpointName: resolvedEndpoint, + maxSteps: options?.maxSteps, + maxTokens: options?.maxTokens, + }); + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + // Databricks API requires tool names to match [a-zA-Z0-9_-]. + // Our tool names use dots (e.g. "analytics.query"), so we swap dots + // for double-underscores in the wire format and map back on receipt. + const nameToWire = new Map(); + const wireToName = new Map(); + for (const tool of input.tools) { + const wire = tool.name.replace(/\./g, "__"); + nameToWire.set(tool.name, wire); + wireToName.set(wire, tool.name); + } + + const tools = this.buildTools(input.tools, nameToWire); + const messages = this.buildMessages(input.messages); + + yield { type: "status", status: "running" }; + + for (let step = 0; step < this.maxSteps; step++) { + if (context.signal?.aborted) break; + + const { text, toolCalls } = yield* this.streamCompletion( + messages, + tools, + context, + ); + + if (toolCalls.length === 0) { + const parsed = parseTextToolCalls(text); + if (parsed.length > 0) { + yield* this.executeToolCalls(parsed, messages, context); + continue; + } + break; + } + + messages.push({ + role: "assistant", + content: text || null, + tool_calls: toolCalls, + }); + + for (const tc of toolCalls) { + const wireName = tc.function.name; + const originalName = wireToName.get(wireName) ?? wireName; + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name: originalName, args }; + + try { + const result = await context.executeTool(originalName, args); + const resultStr = + typeof result === "string" ? result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + } + } + + private async *streamCompletion( + messages: OpenAIMessage[], + tools: OpenAITool[], + context: AgentRunContext, + ): AsyncGenerator< + AgentEvent, + { text: string; toolCalls: OpenAIToolCall[] }, + unknown + > { + const body: Record = { + messages, + stream: true, + max_tokens: this.maxTokens, + }; + + if (tools.length > 0) { + body.tools = tools; + } + + const responseBody = await this.streamBody(body, context.signal); + const reader = responseBody.getReader(); + + const decoder = new TextDecoder(); + let buffer = ""; + let fullText = ""; + const toolCallAccumulator = new Map< + number, + { id: string; name: string; arguments: string } + >(); + + try { + while (true) { + if (context.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed.startsWith("data: ")) continue; + const data = trimmed.slice(6); + if (data === "[DONE]") continue; + + let parsed: any; + try { + parsed = JSON.parse(data); + } catch { + continue; + } + + const delta = parsed.choices?.[0]?.delta; + if (!delta) continue; + + if (delta.content) { + fullText += delta.content; + yield { type: "message_delta" as const, content: delta.content }; + } + + if (delta.tool_calls) { + for (const tc of delta.tool_calls as DeltaToolCall[]) { + const existing = toolCallAccumulator.get(tc.index); + if (existing) { + if (tc.function?.arguments) { + existing.arguments += tc.function.arguments; + } + } else { + toolCallAccumulator.set(tc.index, { + id: tc.id ?? `call_${tc.index}`, + name: tc.function?.name ?? "", + arguments: tc.function?.arguments ?? "", + }); + } + } + } + } + } + } finally { + reader.releaseLock(); + } + + const toolCalls: OpenAIToolCall[] = Array.from( + toolCallAccumulator.values(), + ).map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { name: tc.name, arguments: tc.arguments || "{}" }, + })); + + return { text: fullText, toolCalls }; + } + + private async *executeToolCalls( + calls: Array<{ name: string; args: unknown }>, + messages: OpenAIMessage[], + context: AgentRunContext, + ): AsyncGenerator { + const toolCallObjs: OpenAIToolCall[] = calls.map((c, i) => ({ + id: `text_call_${i}`, + type: "function" as const, + function: { + name: c.name, + arguments: JSON.stringify(c.args), + }, + })); + + messages.push({ + role: "assistant", + content: null, + tool_calls: toolCallObjs, + }); + + for (const tc of toolCallObjs) { + const name = tc.function.name; + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name, args }; + + try { + const result = await context.executeTool(name, args); + const resultStr = + typeof result === "string" ? result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + } + + private buildMessages(messages: AgentInput["messages"]): OpenAIMessage[] { + return messages.map((m) => ({ + role: m.role as OpenAIMessage["role"], + content: m.content, + })); + } + + private buildTools( + definitions: AgentToolDefinition[], + nameToWire: Map, + ): OpenAITool[] { + return definitions.map((def) => ({ + type: "function" as const, + function: { + name: nameToWire.get(def.name) ?? def.name, + description: def.description, + parameters: def.parameters, + }, + })); + } +} + +// --------------------------------------------------------------------------- +// Text-based tool call parsing (fallback) +// --------------------------------------------------------------------------- + +/** + * Parses text-based tool calls from model output. + * + * Handles two formats: + * 1. Llama native: `[{"name": "tool_name", "parameters": {"arg": "val"}}]` + * 2. Python-style: `[tool_name(arg1='val1', arg2='val2')]` + */ +export function parseTextToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const trimmed = text.trim(); + + const jsonResult = tryParseLlamaJsonToolCalls(trimmed); + if (jsonResult.length > 0) return jsonResult; + + const pyResult = tryParsePythonStyleToolCalls(trimmed); + if (pyResult.length > 0) return pyResult; + + return []; +} + +function tryParseLlamaJsonToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const match = text.match(/\[\s*\{[\s\S]*\}\s*\]/); + if (!match) return []; + + try { + const parsed = JSON.parse(match[0]); + if (!Array.isArray(parsed)) return []; + + return parsed + .filter( + (item: any) => + typeof item === "object" && + item !== null && + typeof item.name === "string", + ) + .map((item: any) => ({ + name: item.name, + args: item.parameters ?? item.arguments ?? item.args ?? {}, + })); + } catch { + return []; + } +} + +function tryParsePythonStyleToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const pattern = /\[?([a-zA-Z_][\w.]*)\(([^)]*)\)\]?/g; + const results: Array<{ name: string; args: unknown }> = []; + + for (const match of text.matchAll(pattern)) { + const name = match[1]; + const argsStr = match[2]; + + const args: Record = {}; + const argPattern = /(\w+)\s*=\s*(?:'([^']*)'|"([^"]*)"|(\S+))/g; + for (const argMatch of argsStr.matchAll(argPattern)) { + const key = argMatch[1]; + const value = argMatch[2] ?? argMatch[3] ?? argMatch[4]; + args[key] = value; + } + + results.push({ name, args }); + } + + return results; +} diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts new file mode 100644 index 000000000..8a835094e --- /dev/null +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -0,0 +1,486 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { DatabricksAdapter, parseTextToolCalls } from "../databricks"; + +const mockAuthenticate = vi + .fn() + .mockResolvedValue({ Authorization: "Bearer test-token" }); + +function sseChunk(data: string): string { + return `data: ${data}\n\n`; +} + +function textDelta(content: string): string { + return sseChunk( + JSON.stringify({ + choices: [{ delta: { content } }], + }), + ); +} + +function toolCallDelta( + index: number, + id: string | undefined, + name: string | undefined, + args: string, +): string { + return sseChunk( + JSON.stringify({ + choices: [ + { + delta: { + tool_calls: [ + { + index, + ...(id && { id }), + ...(name && { type: "function" }), + function: { + ...(name && { name }), + arguments: args, + }, + }, + ], + }, + }, + ], + }), + ); +} + +function createReadableStream(chunks: string[]): ReadableStream { + const encoder = new TextEncoder(); + let i = 0; + return new ReadableStream({ + pull(controller) { + if (i < chunks.length) { + controller.enqueue(encoder.encode(chunks[i])); + i++; + } else { + controller.close(); + } + }, + }); +} + +function mockFetch(chunks: string[]): typeof globalThis.fetch { + return vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream(chunks), + text: () => Promise.resolve(""), + }); +} + +function createTestMessages(): Message[] { + return [{ id: "1", role: "user", content: "Hello", createdAt: new Date() }]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: "analytics.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { query: { type: "string" } }, + required: ["query"], + }, + }, + ]; +} + +function createAdapter(overrides?: { + endpointUrl?: string; + authenticate?: () => Promise>; + maxSteps?: number; + maxTokens?: number; +}) { + return new DatabricksAdapter({ + endpointUrl: + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + authenticate: mockAuthenticate, + ...overrides, + }); +} + +describe("DatabricksAdapter", () => { + const originalFetch = globalThis.fetch; + + afterEach(() => { + globalThis.fetch = originalFetch; + mockAuthenticate.mockClear(); + }); + + test("streams text deltas from the model", async () => { + globalThis.fetch = mockFetch([ + textDelta("Hello"), + textDelta(" world"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hello" }); + expect(events[2]).toEqual({ type: "message_delta", content: " world" }); + }); + + test("calls authenticate() per request for fresh headers", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(mockAuthenticate).toHaveBeenCalledTimes(1); + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + expect(init.headers.Authorization).toBe("Bearer test-token"); + }); + + test("handles structured tool calls and executes them", async () => { + const executeTool = vi.fn().mockResolvedValue([{ trip_id: 1 }]); + + let callCount = 0; + globalThis.fetch = vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta(0, "call_1", "analytics__query", ""), + toolCallDelta(0, undefined, undefined, '{"query":'), + toolCallDelta(0, undefined, undefined, '"SELECT 1"}'), + sseChunk("[DONE]"), + ]), + }); + } + return Promise.resolve({ + ok: true, + body: createReadableStream([ + textDelta("Here are the results"), + sseChunk("[DONE]"), + ]), + }); + }); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "call_1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }); + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + + expect(events).toContainEqual( + expect.objectContaining({ + type: "tool_result", + callId: "call_1", + result: [{ trip_id: 1 }], + }), + ); + + expect(events).toContainEqual({ + type: "message_delta", + content: "Here are the results", + }); + + // authenticate() called once per streamCompletion + expect(mockAuthenticate).toHaveBeenCalledTimes(2); + }); + + test("respects maxSteps limit", async () => { + globalThis.fetch = vi.fn().mockImplementation(() => + Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta( + 0, + "call_loop", + "analytics__query", + '{"query":"SELECT 1"}', + ), + sseChunk("[DONE]"), + ]), + }), + ); + + const adapter = createAdapter({ maxSteps: 2 }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn().mockResolvedValue("ok") }, + )) { + events.push(event); + } + + expect(globalThis.fetch).toHaveBeenCalledTimes(2); + }); + + test("sends correct request to endpoint URL", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [url, init] = (globalThis.fetch as any).mock.calls[0]; + expect(url).toBe( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + ); + + const body = JSON.parse(init.body); + expect(body.stream).toBe(true); + expect(body.tools).toHaveLength(1); + expect(body.tools[0].function.name).toBe("analytics__query"); + expect(body.messages[0]).toEqual({ + role: "user", + content: "Hello", + }); + }); + + test("throws on non-ok response", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 401, + text: () => Promise.resolve("Unauthorized"), + }); + + const adapter = createAdapter(); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow("Databricks API error (401): Unauthorized"); + }); +}); + +describe("DatabricksAdapter.fromServingEndpoint", () => { + test("routes tool-free chat through apiClient.request with a streaming payload", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my-model", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(apiClient.request).toHaveBeenCalledTimes(1); + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe("/serving-endpoints/my-model/invocations"); + expect(requestArgs.method).toBe("POST"); + expect(requestArgs.raw).toBe(true); + expect(requestArgs.payload.stream).toBe(true); + // Auth + url encoding are the connector's (and the SDK's) concerns — the + // adapter no longer reaches into the workspace config. + }); + + test("URL-encodes endpoint names with special characters", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my model/with spaces", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/my%20model%2Fwith%20spaces/invocations", + ); + }); +}); + +describe("DatabricksAdapter.fromModelServing", () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + test("reads endpoint from DATABRICKS_AGENT_ENDPOINT env var", async () => { + process.env.DATABRICKS_AGENT_ENDPOINT = "my-model"; + + vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn().mockImplementation(() => ({ + apiClient: { request: vi.fn() }, + })), + })); + + const adapter = await DatabricksAdapter.fromModelServing(); + expect(adapter).toBeInstanceOf(DatabricksAdapter); + }); + + test("throws when no endpoint name and no env var", async () => { + delete process.env.DATABRICKS_AGENT_ENDPOINT; + + await expect(DatabricksAdapter.fromModelServing()).rejects.toThrow( + "No endpoint name provided", + ); + }); + + test("explicit endpoint name takes precedence over env var", async () => { + process.env.DATABRICKS_AGENT_ENDPOINT = "env-model"; + + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromModelServing("explicit-model", { + workspaceClient: { apiClient }, + }); + + expect(adapter).toBeInstanceOf(DatabricksAdapter); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/explicit-model/invocations", + ); + }); +}); + +describe("parseTextToolCalls", () => { + test("parses Llama JSON format", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "analytics.query", args: { query: "SELECT 1" } }, + ]); + }); + + test("parses multiple Llama JSON tool calls", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}, {"name": "files.uploads.list", "parameters": {}}]'; + const result = parseTextToolCalls(text); + + expect(result).toHaveLength(2); + expect(result[0].name).toBe("analytics.query"); + expect(result[1].name).toBe("files.uploads.list"); + }); + + test("parses Python-style tool calls", () => { + const text = + "[analytics.query(query='SELECT * FROM trips ORDER BY date DESC LIMIT 10')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "analytics.query", + args: { + query: "SELECT * FROM trips ORDER BY date DESC LIMIT 10", + }, + }, + ]); + }); + + test("parses Python-style with multiple args", () => { + const text = + "[files.uploads.read(path='/data/file.csv', encoding='utf-8')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "files.uploads.read", + args: { path: "/data/file.csv", encoding: "utf-8" }, + }, + ]); + }); + + test("returns empty array for plain text", () => { + expect(parseTextToolCalls("Hello, how can I help?")).toEqual([]); + expect(parseTextToolCalls("")).toEqual([]); + expect(parseTextToolCalls("The answer is 42")).toEqual([]); + }); + + test("handles Llama format with 'arguments' key", () => { + const text = + '[{"name": "lakebase.query", "arguments": {"text": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "lakebase.query", args: { text: "SELECT 1" } }, + ]); + }); +}); diff --git a/packages/appkit/tsdown.config.ts b/packages/appkit/tsdown.config.ts index d61e8c534..fdfdd721d 100644 --- a/packages/appkit/tsdown.config.ts +++ b/packages/appkit/tsdown.config.ts @@ -4,7 +4,7 @@ export default defineConfig([ { publint: true, name: "@databricks/appkit", - entry: ["src/index.ts", "src/beta.ts"], + entry: ["src/index.ts", "src/beta.ts", "src/agents/databricks.ts"], outDir: "dist", hash: false, format: "esm", diff --git a/packages/shared/src/agent.ts b/packages/shared/src/agent.ts new file mode 100644 index 000000000..c4f76b294 --- /dev/null +++ b/packages/shared/src/agent.ts @@ -0,0 +1,212 @@ +import type { JSONSchema7 } from "json-schema"; + +// --------------------------------------------------------------------------- +// Tool definitions +// --------------------------------------------------------------------------- + +export interface ToolAnnotations { + readOnly?: boolean; + destructive?: boolean; + idempotent?: boolean; + requiresUserContext?: boolean; +} + +export interface AgentToolDefinition { + name: string; + description: string; + parameters: JSONSchema7; + annotations?: ToolAnnotations; +} + +export interface ToolProvider { + getAgentTools(): AgentToolDefinition[]; + executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise; +} + +// --------------------------------------------------------------------------- +// Messages & threads +// --------------------------------------------------------------------------- + +export interface Message { + id: string; + role: "user" | "assistant" | "system" | "tool"; + content: string; + toolCallId?: string; + toolCalls?: ToolCall[]; + createdAt: Date; +} + +export interface ToolCall { + id: string; + name: string; + args: unknown; +} + +export interface Thread { + id: string; + userId: string; + messages: Message[]; + createdAt: Date; + updatedAt: Date; +} + +// --------------------------------------------------------------------------- +// Thread store +// --------------------------------------------------------------------------- + +export interface ThreadStore { + create(userId: string): Promise; + get(threadId: string, userId: string): Promise; + list(userId: string): Promise; + addMessage(threadId: string, userId: string, message: Message): Promise; + delete(threadId: string, userId: string): Promise; +} + +// --------------------------------------------------------------------------- +// Agent events (SSE protocol) +// --------------------------------------------------------------------------- + +export type AgentEvent = + | { type: "message_delta"; content: string } + | { type: "message"; content: string } + | { type: "tool_call"; callId: string; name: string; args: unknown } + | { + type: "tool_result"; + callId: string; + result: unknown; + error?: string; + } + | { type: "thinking"; content: string } + | { + type: "status"; + status: "running" | "waiting" | "complete" | "error"; + error?: string; + } + | { type: "metadata"; data: Record }; + +// --------------------------------------------------------------------------- +// Responses API types (OpenAI-compatible wire format for HTTP boundary) +// Self-contained — no openai package dependency. +// --------------------------------------------------------------------------- + +export interface OutputTextContent { + type: "output_text"; + text: string; +} + +export interface ResponseOutputMessage { + type: "message"; + id: string; + status: "in_progress" | "completed"; + role: "assistant"; + content: OutputTextContent[]; +} + +export interface ResponseFunctionToolCall { + type: "function_call"; + id: string; + call_id: string; + name: string; + arguments: string; +} + +export interface ResponseFunctionCallOutput { + type: "function_call_output"; + id: string; + call_id: string; + output: string; +} + +export type ResponseOutputItem = + | ResponseOutputMessage + | ResponseFunctionToolCall + | ResponseFunctionCallOutput; + +export interface ResponseOutputItemAddedEvent { + type: "response.output_item.added"; + output_index: number; + item: ResponseOutputItem; + sequence_number: number; +} + +export interface ResponseOutputItemDoneEvent { + type: "response.output_item.done"; + output_index: number; + item: ResponseOutputItem; + sequence_number: number; +} + +export interface ResponseTextDeltaEvent { + type: "response.output_text.delta"; + item_id: string; + output_index: number; + content_index: number; + delta: string; + sequence_number: number; +} + +export interface ResponseCompletedEvent { + type: "response.completed"; + sequence_number: number; + response: Record; +} + +export interface ResponseErrorEvent { + type: "error"; + error: string; + sequence_number: number; +} + +export interface ResponseFailedEvent { + type: "response.failed"; + sequence_number: number; +} + +export interface AppKitThinkingEvent { + type: "appkit.thinking"; + content: string; + sequence_number: number; +} + +export interface AppKitMetadataEvent { + type: "appkit.metadata"; + data: Record; + sequence_number: number; +} + +export type ResponseStreamEvent = + | ResponseOutputItemAddedEvent + | ResponseOutputItemDoneEvent + | ResponseTextDeltaEvent + | ResponseCompletedEvent + | ResponseErrorEvent + | ResponseFailedEvent + | AppKitThinkingEvent + | AppKitMetadataEvent; + +// --------------------------------------------------------------------------- +// Adapter contract +// --------------------------------------------------------------------------- + +export interface AgentInput { + messages: Message[]; + tools: AgentToolDefinition[]; + threadId: string; + signal?: AbortSignal; +} + +export interface AgentRunContext { + executeTool: (name: string, args: unknown) => Promise; + signal?: AbortSignal; +} + +export interface AgentAdapter { + run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator; +} diff --git a/packages/shared/src/index.ts b/packages/shared/src/index.ts index 627d70d6c..9829729a7 100644 --- a/packages/shared/src/index.ts +++ b/packages/shared/src/index.ts @@ -1,3 +1,4 @@ +export * from "./agent"; export * from "./cache"; export * from "./execute"; export * from "./genie"; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 684f6e2e4..c1d8f247e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -5551,7 +5551,7 @@ packages: basic-ftp@5.0.5: resolution: {integrity: sha512-4Bcg1P8xhUuqcii/S0Z9wiHIrQVPMermM1any+MX5GeGD7faD3/msQUDGLol9wOcz4/jbg/WJnGqoJF6LiBdtg==} engines: {node: '>=10.0.0'} - deprecated: Security vulnerability fixed in 5.2.1, please upgrade + deprecated: Security vulnerability fixed in 5.2.0, please upgrade batch@0.6.1: resolution: {integrity: sha512-x+VAiMRL6UPkx+kudNvxTl6hB2XNNCG2r+7wixVfIYwu/2HKRXimwQyaumLjMveWvT2Hkd/cAJw+QBMfJ/EKVw==} @@ -6665,7 +6665,6 @@ packages: dottie@2.0.6: resolution: {integrity: sha512-iGCHkfUc5kFekGiqhe8B/mdaurD+lakO9txNnTvKtA6PISrw86LgqHvRzWYPyoE2Ph5aMIrCw9/uko6XHTKCwA==} - deprecated: Package no longer supported. Contact Support at https://www.npmjs.com/support for more info. drizzle-orm@0.45.1: resolution: {integrity: sha512-Te0FOdKIistGNPMq2jscdqngBRfBpC8uMFVwqjf6gtTVJHIQ/dosgV/CLBU2N4ZJBsXL5savCba9b0YJskKdcA==} From a0a5d9fe4a016f1977150871bb56d3ee2f6eda61 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Mon, 4 May 2026 15:27:32 +0200 Subject: [PATCH 02/10] fix(appkit): forward AbortSignal through serving stream and cancel SSE reader - Bridge AbortSignal to SDK CancellationToken via Context on apiClient.request - Pass signal from fromServingEndpoint streamBody into serving stream() - Cancel SSE reader in streamCompletion finally for clean teardown Tests: expect second request arg and Context when signal provided Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 14 +++- .../appkit/src/connectors/serving/client.ts | 70 ++++++++++++++++--- .../connectors/serving/tests/client.test.ts | 19 +++++ 3 files changed, 90 insertions(+), 13 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index b74053eef..a5e01831e 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -199,7 +199,7 @@ export class DatabricksAdapter implements AgentAdapter { ): Promise { const { workspaceClient, endpointName, maxSteps, maxTokens } = options; return new DatabricksAdapter({ - streamBody: (body) => + streamBody: (body, signal) => // Cast through the structural shape: the connector types // `workspaceClient` as the SDK's concrete `WorkspaceClient`, but we // only need `apiClient.request`. @@ -207,6 +207,7 @@ export class DatabricksAdapter implements AgentAdapter { workspaceClient as unknown as Parameters[0], endpointName, body, + signal, ), maxSteps, maxTokens, @@ -434,7 +435,16 @@ export class DatabricksAdapter implements AgentAdapter { } } } finally { - reader.releaseLock(); + try { + await reader.cancel(); + } catch { + // Best-effort: reader may already be closed or the stream errored. + } + try { + reader.releaseLock(); + } catch { + // Lock may already be released after cancel. + } } const toolCalls: OpenAIToolCall[] = Array.from( diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts index 886d2bb3f..83f065e69 100644 --- a/packages/appkit/src/connectors/serving/client.ts +++ b/packages/appkit/src/connectors/serving/client.ts @@ -1,8 +1,46 @@ -import type { serving, WorkspaceClient } from "@databricks/sdk-experimental"; +import type { + CancellationToken, + serving, + WorkspaceClient, +} from "@databricks/sdk-experimental"; +import { Context } from "@databricks/sdk-experimental"; import { createLogger } from "../../logging/logger"; const logger = createLogger("connectors:serving"); +/** + * Bridges {@link AbortSignal} to the SDK's {@link CancellationToken} so + * `apiClient.request` can abort the outbound HTTP request (and stop pulling + * the SSE body) when the agent run is cancelled. + */ +function cancellationTokenFromAbortSignal( + signal: AbortSignal, +): CancellationToken { + const listeners = new Set<() => void>(); + const fire = () => { + for (const cb of listeners) { + try { + cb(); + } catch { + // ignore listener failures — abort must stay best-effort + } + } + }; + signal.addEventListener("abort", fire, { passive: true }); + + return { + get isCancellationRequested() { + return signal.aborted; + }, + onCancellationRequested(callback: (e?: unknown) => unknown) { + listeners.add(callback as () => void); + if (signal.aborted) { + void callback(); + } + }, + }; +} + /** * Invokes a serving endpoint using the SDK's high-level query API. * Returns a typed QueryEndpointResponse. @@ -35,21 +73,31 @@ export async function stream( client: WorkspaceClient, endpointName: string, body: Record, + signal?: AbortSignal, ): Promise> { const { stream: _stream, ...cleanBody } = body; logger.debug("Streaming from endpoint %s", endpointName); - const response = (await client.apiClient.request({ - path: `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`, - method: "POST", - headers: new Headers({ - "Content-Type": "application/json", - Accept: "text/event-stream", - }), - payload: { ...cleanBody, stream: true }, - raw: true, - })) as { contents: ReadableStream }; + const context = signal + ? new Context({ + cancellationToken: cancellationTokenFromAbortSignal(signal), + }) + : undefined; + + const response = (await client.apiClient.request( + { + path: `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`, + method: "POST", + headers: new Headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }), + payload: { ...cleanBody, stream: true }, + raw: true, + }, + context, + )) as { contents: ReadableStream }; if (!response.contents) { throw new Error("Response body is null — streaming not supported"); diff --git a/packages/appkit/src/connectors/serving/tests/client.test.ts b/packages/appkit/src/connectors/serving/tests/client.test.ts index 389585b04..d243621e0 100644 --- a/packages/appkit/src/connectors/serving/tests/client.test.ts +++ b/packages/appkit/src/connectors/serving/tests/client.test.ts @@ -1,3 +1,4 @@ +import { Context } from "@databricks/sdk-experimental"; import { afterEach, describe, expect, test, vi } from "vitest"; import { invoke, stream } from "../client"; @@ -109,6 +110,24 @@ describe("Serving Connector", () => { raw: true, payload: expect.objectContaining({ stream: true }), }), + undefined, + ); + }); + + test("passes SDK Context when AbortSignal is provided", async () => { + const client = createMockClient(); + client.apiClient.request.mockResolvedValue({ + contents: new ReadableStream(), + }); + + const controller = new AbortController(); + await stream(client, "my-endpoint", { messages: [] }, controller.signal); + + expect(client.apiClient.request).toHaveBeenCalledWith( + expect.objectContaining({ + path: "/serving-endpoints/my-endpoint/invocations", + }), + expect.any(Context), ); }); From 1af6f2a9b5f687bad281b80a5c0f81b64ae8111f Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Mon, 4 May 2026 16:38:59 +0200 Subject: [PATCH 03/10] fix(appkit): map tool fields in DatabricksAdapter buildMessages - Forward Message.toolCalls/toolCallId to OpenAI tool_calls/tool_call_id - Encode tool names with same wire map as run() (dots -> __) - Use null assistant content when only tool_calls are present Closes gap for resumed threads and hydrated conversation history. Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 55 ++++++++++++++-- .../src/agents/tests/databricks.test.ts | 65 +++++++++++++++++++ 2 files changed, 114 insertions(+), 6 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index a5e01831e..4fba641d7 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -281,7 +281,7 @@ export class DatabricksAdapter implements AgentAdapter { } const tools = this.buildTools(input.tools, nameToWire); - const messages = this.buildMessages(input.messages); + const messages = this.buildMessages(input.messages, nameToWire); yield { type: "status", status: "running" }; @@ -521,11 +521,54 @@ export class DatabricksAdapter implements AgentAdapter { } } - private buildMessages(messages: AgentInput["messages"]): OpenAIMessage[] { - return messages.map((m) => ({ - role: m.role as OpenAIMessage["role"], - content: m.content, - })); + /** + * Maps AppKit {@link AgentInput} messages into OpenAI-compatible wire messages. + * Preserves multi-turn tool state (`toolCalls` → `tool_calls`, `toolCallId` → + * `tool_call_id`) so resumed threads and hydrated history reach the model. + */ + private buildMessages( + messages: AgentInput["messages"], + nameToWire: Map, + ): OpenAIMessage[] { + const wireToolName = (name: string) => + nameToWire.get(name) ?? name.replace(/\./g, "__"); + + return messages.map((m) => { + let content: string | null = m.content; + if ( + m.role === "assistant" && + m.toolCalls && + m.toolCalls.length > 0 && + (!m.content || m.content.trim() === "") + ) { + content = null; + } + + const out: OpenAIMessage = { + role: m.role as OpenAIMessage["role"], + content, + }; + + if (m.toolCallId) { + out.tool_call_id = m.toolCallId; + } + + if (m.toolCalls && m.toolCalls.length > 0) { + out.tool_calls = m.toolCalls.map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { + name: wireToolName(tc.name), + arguments: + typeof tc.args === "string" + ? tc.args + : JSON.stringify(tc.args ?? {}), + }, + })); + } + + return out; + }); } private buildTools( diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts index 8a835094e..75deed441 100644 --- a/packages/appkit/src/agents/tests/databricks.test.ts +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -282,6 +282,71 @@ describe("DatabricksAdapter", () => { }); }); + test("forwards tool thread fields from input messages to the request body", async () => { + globalThis.fetch = mockFetch([textDelta("Done"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + const threadMessages: Message[] = [ + { id: "1", role: "user", content: "Run SQL", createdAt: new Date() }, + { + id: "2", + role: "assistant", + content: "", + createdAt: new Date(), + toolCalls: [ + { + id: "call_1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }, + ], + }, + { + id: "3", + role: "tool", + content: '{"rows":[]}', + createdAt: new Date(), + toolCallId: "call_1", + }, + ]; + + for await (const _ of adapter.run( + { + messages: threadMessages, + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + const body = JSON.parse(init.body); + + expect(body.messages[1]).toEqual({ + role: "assistant", + content: null, + tool_calls: [ + { + id: "call_1", + type: "function", + function: { + name: "analytics__query", + arguments: JSON.stringify({ query: "SELECT 1" }), + }, + }, + ], + }); + + expect(body.messages[2]).toEqual({ + role: "tool", + content: '{"rows":[]}', + tool_call_id: "call_1", + }); + }); + test("throws on non-ok response", async () => { globalThis.fetch = vi.fn().mockResolvedValue({ ok: false, From c0e301ce93443d90a9fa02b17564c1738573b8a1 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Mon, 4 May 2026 17:39:15 +0200 Subject: [PATCH 04/10] fix(appkit): bound DatabricksAdapter SSE stream buffers (F6) - Cap incomplete SSE tail, each complete line, assistant text, and per-index streamed tool arguments (UTF-16 code unit counts) - Defaults: 1Mi line, 4Mi text, 2Mi tool args; overridable via adapter options - Thread limits through fromServingEndpoint and fromModelServing Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 94 ++++++++++++++++++- .../src/agents/tests/databricks.test.ts | 94 +++++++++++++++++++ 2 files changed, 186 insertions(+), 2 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index 4fba641d7..84f1ea0f4 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -7,6 +7,28 @@ import type { } from "shared"; import { stream as servingStream } from "../connectors/serving/client"; +/** Default cap for a single incomplete SSE line tail (DoS guard). */ +const DEFAULT_MAX_SSE_LINE_CHARS = 1024 * 1024; + +/** Default cap for accumulated assistant text from `delta.content`. */ +const DEFAULT_MAX_STREAM_TEXT_CHARS = 4 * 1024 * 1024; + +/** Default cap for accumulated JSON arguments per streamed tool call index. */ +const DEFAULT_MAX_TOOL_ARGUMENT_CHARS = 2 * 1024 * 1024; + +function throwIfExceedsStreamLimit( + label: string, + currentLength: number, + chunk: string, + max: number, +): void { + if (currentLength + chunk.length > max) { + throw new Error( + `DatabricksAdapter: ${label} exceeds configured limit (${max} UTF-16 code units)`, + ); + } +} + /** * Transport shim: given an OpenAI-compatible request body, returns the raw * SSE byte stream from the serving endpoint. Injected at construction time so @@ -28,6 +50,12 @@ interface RawFetchAdapterOptions { authenticate: () => Promise>; maxSteps?: number; maxTokens?: number; + /** Max length of one SSE line (including an incomplete tail in the buffer). */ + maxSseLineChars?: number; + /** Max total length of assistant `delta.content` across the stream. */ + maxStreamTextChars?: number; + /** Max length of streamed `function.arguments` per tool call index. */ + maxToolArgumentsChars?: number; } /** @@ -40,6 +68,9 @@ interface StreamBodyAdapterOptions { streamBody: StreamBody; maxSteps?: number; maxTokens?: number; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; } type DatabricksAdapterOptions = @@ -70,12 +101,18 @@ interface ServingEndpointOptions { endpointName: string; maxSteps?: number; maxTokens?: number; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; } interface ModelServingOptions { maxSteps?: number; maxTokens?: number; workspaceClient?: WorkspaceClientLike; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; } interface OpenAIMessage { @@ -154,10 +191,19 @@ export class DatabricksAdapter implements AgentAdapter { private streamBody: StreamBody; private maxSteps: number; private maxTokens: number; + private maxSseLineChars: number; + private maxStreamTextChars: number; + private maxToolArgumentsChars: number; constructor(options: DatabricksAdapterOptions) { this.maxSteps = options.maxSteps ?? 10; this.maxTokens = options.maxTokens ?? 4096; + this.maxSseLineChars = + options.maxSseLineChars ?? DEFAULT_MAX_SSE_LINE_CHARS; + this.maxStreamTextChars = + options.maxStreamTextChars ?? DEFAULT_MAX_STREAM_TEXT_CHARS; + this.maxToolArgumentsChars = + options.maxToolArgumentsChars ?? DEFAULT_MAX_TOOL_ARGUMENT_CHARS; if (isStreamBodyOptions(options)) { this.streamBody = options.streamBody; @@ -197,7 +243,15 @@ export class DatabricksAdapter implements AgentAdapter { static async fromServingEndpoint( options: ServingEndpointOptions, ): Promise { - const { workspaceClient, endpointName, maxSteps, maxTokens } = options; + const { + workspaceClient, + endpointName, + maxSteps, + maxTokens, + maxSseLineChars, + maxStreamTextChars, + maxToolArgumentsChars, + } = options; return new DatabricksAdapter({ streamBody: (body, signal) => // Cast through the structural shape: the connector types @@ -211,6 +265,9 @@ export class DatabricksAdapter implements AgentAdapter { ), maxSteps, maxTokens, + maxSseLineChars, + maxStreamTextChars, + maxToolArgumentsChars, }); } @@ -262,6 +319,9 @@ export class DatabricksAdapter implements AgentAdapter { endpointName: resolvedEndpoint, maxSteps: options?.maxSteps, maxTokens: options?.maxTokens, + maxSseLineChars: options?.maxSseLineChars, + maxStreamTextChars: options?.maxStreamTextChars, + maxToolArgumentsChars: options?.maxToolArgumentsChars, }); } @@ -395,7 +455,19 @@ export class DatabricksAdapter implements AgentAdapter { const lines = buffer.split("\n"); buffer = lines.pop() ?? ""; + if (buffer.length > this.maxSseLineChars) { + throw new Error( + `DatabricksAdapter: SSE line buffer exceeds configured limit (${this.maxSseLineChars} UTF-16 code units)`, + ); + } + for (const line of lines) { + if (line.length > this.maxSseLineChars) { + throw new Error( + `DatabricksAdapter: SSE line exceeds configured limit (${this.maxSseLineChars} UTF-16 code units)`, + ); + } + const trimmed = line.trim(); if (!trimmed.startsWith("data: ")) continue; const data = trimmed.slice(6); @@ -412,6 +484,12 @@ export class DatabricksAdapter implements AgentAdapter { if (!delta) continue; if (delta.content) { + throwIfExceedsStreamLimit( + "streamed assistant text", + fullText.length, + delta.content, + this.maxStreamTextChars, + ); fullText += delta.content; yield { type: "message_delta" as const, content: delta.content }; } @@ -421,13 +499,25 @@ export class DatabricksAdapter implements AgentAdapter { const existing = toolCallAccumulator.get(tc.index); if (existing) { if (tc.function?.arguments) { + throwIfExceedsStreamLimit( + "tool call arguments", + existing.arguments.length, + tc.function.arguments, + this.maxToolArgumentsChars, + ); existing.arguments += tc.function.arguments; } } else { + const initial = tc.function?.arguments ?? ""; + if (initial.length > this.maxToolArgumentsChars) { + throw new Error( + `DatabricksAdapter: tool call arguments exceed configured limit (${this.maxToolArgumentsChars} UTF-16 code units)`, + ); + } toolCallAccumulator.set(tc.index, { id: tc.id ?? `call_${tc.index}`, name: tc.function?.name ?? "", - arguments: tc.function?.arguments ?? "", + arguments: initial, }); } } diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts index 75deed441..49f8e91d4 100644 --- a/packages/appkit/src/agents/tests/databricks.test.ts +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -93,6 +93,9 @@ function createAdapter(overrides?: { authenticate?: () => Promise>; maxSteps?: number; maxTokens?: number; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; }) { return new DatabricksAdapter({ endpointUrl: @@ -347,6 +350,97 @@ describe("DatabricksAdapter", () => { }); }); + test("throws when SSE line buffer exceeds maxSseLineChars", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream(["no-newline-", "xxxxxxxxxx"]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxSseLineChars: 12 }); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow(/SSE line buffer exceeds configured limit/); + }); + + test("throws when a complete SSE line exceeds maxSseLineChars", async () => { + const longPayload = `${"x".repeat(30)}\n`; + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream([longPayload]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxSseLineChars: 20 }); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow(/SSE line exceeds configured limit/); + }); + + test("throws when streamed assistant text exceeds maxStreamTextChars", async () => { + globalThis.fetch = mockFetch([ + textDelta("abcde"), + textDelta("f"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter({ maxStreamTextChars: 5 }); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow(/streamed assistant text exceeds configured limit/); + }); + + test("throws when streamed tool arguments exceed maxToolArgumentsChars", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream([ + toolCallDelta(0, "c1", "t", '{"a":"'), + toolCallDelta(0, undefined, undefined, 'xxxx"}'), + sseChunk("[DONE]"), + ]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxToolArgumentsChars: 8 }); + + await expect(async () => { + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: [ + { + name: "t", + description: "x", + parameters: { type: "object", properties: {} }, + }, + ], + threadId: "t1", + }, + { executeTool: vi.fn().mockResolvedValue("ok") }, + )) { + // drain + } + }).rejects.toThrow(/tool call arguments exceed/); + }); + test("throws on non-ok response", async () => { globalThis.fetch = vi.fn().mockResolvedValue({ ok: false, From 39075acc297baa08a23203b59c8040457a21f484 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Mon, 4 May 2026 18:20:53 +0200 Subject: [PATCH 05/10] fix(appkit): wire-format tool names for text-parsed Databricks calls executeToolCalls now maps canonical tool names through nameToWire for messages[].tool_calls while keeping dotted names for tool_call events and executeTool. Test: Llama-json text path asserts second POST uses analytics__query. Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 17 +++-- .../src/agents/tests/databricks.test.ts | 67 +++++++++++++++++++ 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index 84f1ea0f4..0090fc977 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -357,7 +357,7 @@ export class DatabricksAdapter implements AgentAdapter { if (toolCalls.length === 0) { const parsed = parseTextToolCalls(text); if (parsed.length > 0) { - yield* this.executeToolCalls(parsed, messages, context); + yield* this.executeToolCalls(parsed, messages, context, nameToWire); continue; } break; @@ -552,12 +552,16 @@ export class DatabricksAdapter implements AgentAdapter { calls: Array<{ name: string; args: unknown }>, messages: OpenAIMessage[], context: AgentRunContext, + nameToWire: Map, ): AsyncGenerator { + const wireToolName = (name: string) => + nameToWire.get(name) ?? name.replace(/\./g, "__"); + const toolCallObjs: OpenAIToolCall[] = calls.map((c, i) => ({ id: `text_call_${i}`, type: "function" as const, function: { - name: c.name, + name: wireToolName(c.name), arguments: JSON.stringify(c.args), }, })); @@ -568,8 +572,9 @@ export class DatabricksAdapter implements AgentAdapter { tool_calls: toolCallObjs, }); - for (const tc of toolCallObjs) { - const name = tc.function.name; + for (let i = 0; i < toolCallObjs.length; i++) { + const tc = toolCallObjs[i]; + const originalName = calls[i]?.name ?? tc.function.name; let args: unknown; try { args = JSON.parse(tc.function.arguments); @@ -577,10 +582,10 @@ export class DatabricksAdapter implements AgentAdapter { args = {}; } - yield { type: "tool_call", callId: tc.id, name, args }; + yield { type: "tool_call", callId: tc.id, name: originalName, args }; try { - const result = await context.executeTool(name, args); + const result = await context.executeTool(originalName, args); const resultStr = typeof result === "string" ? result : JSON.stringify(result); diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts index 49f8e91d4..01475b66f 100644 --- a/packages/appkit/src/agents/tests/databricks.test.ts +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -221,6 +221,73 @@ describe("DatabricksAdapter", () => { expect(mockAuthenticate).toHaveBeenCalledTimes(2); }); + test("text-parsed tool calls use wire names on follow-up requests", async () => { + const executeTool = vi.fn().mockResolvedValue({ ok: true }); + let callCount = 0; + + const llamaToolJson = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}]'; + + globalThis.fetch = vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + ok: true, + body: createReadableStream([ + textDelta(llamaToolJson), + sseChunk("[DONE]"), + ]), + }); + } + return Promise.resolve({ + ok: true, + body: createReadableStream([textDelta("Done."), sseChunk("[DONE]")]), + }); + }); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + // drain + } + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + + expect(globalThis.fetch).toHaveBeenCalledTimes(2); + const [, secondInit] = (globalThis.fetch as any).mock.calls[1]; + const secondBody = JSON.parse(secondInit.body); + + expect(secondBody.messages[1]).toEqual({ + role: "assistant", + content: null, + tool_calls: [ + { + id: "text_call_0", + type: "function", + function: { + name: "analytics__query", + arguments: JSON.stringify({ query: "SELECT 1" }), + }, + }, + ], + }); + + expect(secondBody.messages[2]).toEqual({ + role: "tool", + content: JSON.stringify({ ok: true }), + tool_call_id: "text_call_0", + }); + }); + test("respects maxSteps limit", async () => { globalThis.fetch = vi.fn().mockImplementation(() => Promise.resolve({ From fb5ac8c8c4ce69bd6e73d51b1d3aa42bbf31d8b2 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Mon, 4 May 2026 19:01:19 +0200 Subject: [PATCH 06/10] feat(appkit): use DATABRICKS_SERVING_ENDPOINT_NAME in fromModelServing Matches the serving plugin env var name so deployments configure one variable. Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 10 +++++----- packages/appkit/src/agents/tests/databricks.test.ts | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index 0090fc977..fc3ce176f 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -274,11 +274,11 @@ export class DatabricksAdapter implements AgentAdapter { /** * Creates a DatabricksAdapter from a Model Serving endpoint name. * Auto-creates a WorkspaceClient internally. Reads the endpoint name - * from the argument or the `DATABRICKS_AGENT_ENDPOINT` env var. + * from the argument or the `DATABRICKS_SERVING_ENDPOINT_NAME` env var. * * @example * ```ts - * // Reads endpoint from DATABRICKS_AGENT_ENDPOINT env var + * // Reads endpoint from DATABRICKS_SERVING_ENDPOINT_NAME env var * const adapter = await DatabricksAdapter.fromModelServing(); * * // Explicit endpoint @@ -296,12 +296,12 @@ export class DatabricksAdapter implements AgentAdapter { options?: ModelServingOptions, ): Promise { const resolvedEndpoint = - endpointName ?? process.env.DATABRICKS_AGENT_ENDPOINT; + endpointName ?? process.env.DATABRICKS_SERVING_ENDPOINT_NAME; if (!resolvedEndpoint) { throw new Error( - "No endpoint name provided and DATABRICKS_AGENT_ENDPOINT env var is not set. " + - "Pass an endpoint name or set the environment variable.", + "No endpoint name provided and DATABRICKS_SERVING_ENDPOINT_NAME env var is not set. " + + "Pass an endpoint name or set DATABRICKS_SERVING_ENDPOINT_NAME.", ); } diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts index 01475b66f..b6a6b0a61 100644 --- a/packages/appkit/src/agents/tests/databricks.test.ts +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -595,8 +595,8 @@ describe("DatabricksAdapter.fromModelServing", () => { process.env = originalEnv; }); - test("reads endpoint from DATABRICKS_AGENT_ENDPOINT env var", async () => { - process.env.DATABRICKS_AGENT_ENDPOINT = "my-model"; + test("reads endpoint from DATABRICKS_SERVING_ENDPOINT_NAME env var", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "my-model"; vi.mock("@databricks/sdk-experimental", () => ({ WorkspaceClient: vi.fn().mockImplementation(() => ({ @@ -609,7 +609,7 @@ describe("DatabricksAdapter.fromModelServing", () => { }); test("throws when no endpoint name and no env var", async () => { - delete process.env.DATABRICKS_AGENT_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; await expect(DatabricksAdapter.fromModelServing()).rejects.toThrow( "No endpoint name provided", @@ -617,7 +617,7 @@ describe("DatabricksAdapter.fromModelServing", () => { }); test("explicit endpoint name takes precedence over env var", async () => { - process.env.DATABRICKS_AGENT_ENDPOINT = "env-model"; + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "env-model"; const apiClient = { request: vi.fn().mockResolvedValue({ From 71dd34908eab015d6146b91c96715bf1c65e3543 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Mon, 4 May 2026 19:14:41 +0200 Subject: [PATCH 07/10] refactor(appkit): export Databricks adapter only via @databricks/appkit/beta Remove ./agents/databricks package export; rely on beta entry and tsdown beta chunk. Update JSDoc example import accordingly. Signed-off-by: MarioCadenas --- packages/appkit/package.json | 5 ----- packages/appkit/src/agents/databricks.ts | 2 +- packages/appkit/src/beta.ts | 1 + packages/appkit/tsdown.config.ts | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/packages/appkit/package.json b/packages/appkit/package.json index f2f8e366e..bddc99a8b 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -33,10 +33,6 @@ "development": "./src/beta.ts", "default": "./dist/beta.js" }, - "./agents/databricks": { - "development": "./src/agents/databricks.ts", - "default": "./dist/agents/databricks.js" - }, "./type-generator": { "types": "./dist/type-generator/index.d.ts", "development": "./src/type-generator/index.ts", @@ -104,7 +100,6 @@ "exports": { ".": "./dist/index.js", "./beta": "./dist/beta.js", - "./agents/databricks": "./dist/agents/databricks.js", "./dist/shared/src/plugin": "./dist/shared/src/plugin.d.ts", "./type-generator": "./dist/type-generator/index.js", "./package.json": "./package.json" diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index fc3ce176f..d958d4023 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -157,7 +157,7 @@ interface DeltaToolCall { * @example Using the factory (recommended) * ```ts * import { createApp, createAgent, agents } from "@databricks/appkit"; - * import { DatabricksAdapter } from "@databricks/appkit/agents/databricks"; + * import { DatabricksAdapter } from "@databricks/appkit/beta"; * import { WorkspaceClient } from "@databricks/sdk-experimental"; * * const adapter = DatabricksAdapter.fromServingEndpoint({ diff --git a/packages/appkit/src/beta.ts b/packages/appkit/src/beta.ts index 57db86362..04e893bf3 100644 --- a/packages/appkit/src/beta.ts +++ b/packages/appkit/src/beta.ts @@ -4,4 +4,5 @@ // // The exports below are auto-generated from each plugin's manifest.json // "stability" field. See tools/generate-plugin-entries.ts. +export { DatabricksAdapter, parseTextToolCalls } from "./agents/databricks"; export * from "./plugins/beta-exports.generated"; diff --git a/packages/appkit/tsdown.config.ts b/packages/appkit/tsdown.config.ts index fdfdd721d..d61e8c534 100644 --- a/packages/appkit/tsdown.config.ts +++ b/packages/appkit/tsdown.config.ts @@ -4,7 +4,7 @@ export default defineConfig([ { publint: true, name: "@databricks/appkit", - entry: ["src/index.ts", "src/beta.ts", "src/agents/databricks.ts"], + entry: ["src/index.ts", "src/beta.ts"], outDir: "dist", hash: false, format: "esm", From c6a8bbb42a9e9c25610864cab995c5134573f987 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 5 May 2026 11:49:53 +0200 Subject: [PATCH 08/10] fix(appkit): error on DatabricksAdapter tool wire name collisions Dots map to double underscores for serving; distinct names can share the same wire string (e.g. foo.bar vs foo__bar). Throw instead of overwriting maps. Add regression test before any HTTP call. Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 5 +++ .../src/agents/tests/databricks.test.ts | 31 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index d958d4023..26b70babf 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -336,6 +336,11 @@ export class DatabricksAdapter implements AgentAdapter { const wireToName = new Map(); for (const tool of input.tools) { const wire = tool.name.replace(/\./g, "__"); + if (wireToName.has(wire) && wireToName.get(wire) !== tool.name) { + throw new Error( + `Tool name collision: '${tool.name}' and '${wireToName.get(wire)}' both map to wire name '${wire}'`, + ); + } nameToWire.set(tool.name, wire); wireToName.set(wire, tool.name); } diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts index b6a6b0a61..20a1b4cc0 100644 --- a/packages/appkit/src/agents/tests/databricks.test.ts +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -153,6 +153,37 @@ describe("DatabricksAdapter", () => { expect(init.headers.Authorization).toBe("Bearer test-token"); }); + test("throws when two tool names map to the same wire format", async () => { + const adapter = createAdapter(); + const conflictingTools: AgentToolDefinition[] = [ + { + name: "foo.bar", + description: "one", + parameters: { type: "object", properties: {} }, + }, + { + name: "foo__bar", + description: "two", + parameters: { type: "object", properties: {} }, + }, + ]; + + await expect(async () => { + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: conflictingTools, + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow( + /Tool name collision: .* both map to wire name 'foo__bar'/, + ); + }); + test("handles structured tool calls and executes them", async () => { const executeTool = vi.fn().mockResolvedValue([{ trip_id: 1 }]); From 784918fc9c9ab422bb1c1144478f18cfbfd92945 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 5 May 2026 11:55:05 +0200 Subject: [PATCH 09/10] refactor(appkit): extract executeSingleTool from DatabricksAdapter run paths Shared parse / tool_call / execute / tool_result / message-append logic between structured tool_calls handling and text-parse executeToolCalls(). Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 125 +++++++++-------------- 1 file changed, 49 insertions(+), 76 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index 26b70babf..e3bf045d0 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -377,48 +377,58 @@ export class DatabricksAdapter implements AgentAdapter { for (const tc of toolCalls) { const wireName = tc.function.name; const originalName = wireToName.get(wireName) ?? wireName; - let args: unknown; - try { - args = JSON.parse(tc.function.arguments); - } catch { - args = {}; - } - - yield { type: "tool_call", callId: tc.id, name: originalName, args }; - - try { - const result = await context.executeTool(originalName, args); - const resultStr = - typeof result === "string" ? result : JSON.stringify(result); - - yield { type: "tool_result", callId: tc.id, result }; - - messages.push({ - role: "tool", - content: resultStr, - tool_call_id: tc.id, - }); - } catch (error) { - const errMsg = - error instanceof Error ? error.message : "Tool execution failed"; - - yield { - type: "tool_result", - callId: tc.id, - result: null, - error: errMsg, - }; - - messages.push({ - role: "tool", - content: JSON.stringify({ error: errMsg }), - tool_call_id: tc.id, - }); - } + yield* this.executeSingleTool(tc, originalName, messages, context); } } } + /** Parse wire arguments, emit tool_call / tool_result, append tool messages. */ + private async *executeSingleTool( + tc: OpenAIToolCall, + originalName: string, + messages: OpenAIMessage[], + context: AgentRunContext, + ): AsyncGenerator { + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name: originalName, args }; + + try { + const result = await context.executeTool(originalName, args); + const resultStr = + typeof result === "string" ? result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + private async *streamCompletion( messages: OpenAIMessage[], tools: OpenAITool[], @@ -580,44 +590,7 @@ export class DatabricksAdapter implements AgentAdapter { for (let i = 0; i < toolCallObjs.length; i++) { const tc = toolCallObjs[i]; const originalName = calls[i]?.name ?? tc.function.name; - let args: unknown; - try { - args = JSON.parse(tc.function.arguments); - } catch { - args = {}; - } - - yield { type: "tool_call", callId: tc.id, name: originalName, args }; - - try { - const result = await context.executeTool(originalName, args); - const resultStr = - typeof result === "string" ? result : JSON.stringify(result); - - yield { type: "tool_result", callId: tc.id, result }; - - messages.push({ - role: "tool", - content: resultStr, - tool_call_id: tc.id, - }); - } catch (error) { - const errMsg = - error instanceof Error ? error.message : "Tool execution failed"; - - yield { - type: "tool_result", - callId: tc.id, - result: null, - error: errMsg, - }; - - messages.push({ - role: "tool", - content: JSON.stringify({ error: errMsg }), - tool_call_id: tc.id, - }); - } + yield* this.executeSingleTool(tc, originalName, messages, context); } } From 6dac3578e8f7f01ef7637064e752b130b54d25e5 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 5 May 2026 12:04:24 +0200 Subject: [PATCH 10/10] fix(appkit): harden DatabricksAdapter stream and parsers (review P2) - Yield status:error then rethrow when streamBody rejects - AbortSignal.timeout(120s) for raw fetch when no runner signal - Replace harmful Llama-array regex with indexOf/lastIndexOf slice - Cap Python-style text parser input (64KiB); narrow SSE JSON with unknown guards - console.debug malformed SSE JSON and reader teardown failures - JSDoc: executeTool errors may reach the LLM Includes regression tests. Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 170 ++++++++++++------ .../src/agents/tests/databricks.test.ts | 139 ++++++++++++++ packages/shared/src/agent.ts | 1 + 3 files changed, 257 insertions(+), 53 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index e3bf045d0..3b902c2ef 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -16,6 +16,39 @@ const DEFAULT_MAX_STREAM_TEXT_CHARS = 4 * 1024 * 1024; /** Default cap for accumulated JSON arguments per streamed tool call index. */ const DEFAULT_MAX_TOOL_ARGUMENT_CHARS = 2 * 1024 * 1024; +/** Cap text length before running Python-style tool-call regex (ReDoS guard). */ +const PYTHON_STYLE_TOOL_PARSE_MAX_INPUT = 64 * 1024; + +/** Fallback HTTP timeout when the raw fetch adapter path receives no AbortSignal from the runner. */ +const RAW_FETCH_DEFAULT_TIMEOUT_MS = 120_000; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function extractLlamaToolJsonSlice(text: string): string | undefined { + const start = text.indexOf("[{"); + if (start < 0) return undefined; + const endBracket = text.lastIndexOf("}]"); + if (endBracket < start) return undefined; + return text.slice(start, endBracket + 2); +} + +/** OpenAI SSE payload: `{ choices: [{ delta }] }`. */ +function openAiChoicesDelta(parsed: unknown): unknown { + if (!isRecord(parsed)) return undefined; + const choices = parsed.choices; + if (!Array.isArray(choices) || choices.length < 1) return undefined; + const first = choices[0]; + if (!isRecord(first)) return undefined; + return first.delta; +} + +function isStreamingDeltaToolCall(value: unknown): value is DeltaToolCall { + if (!isRecord(value)) return false; + return typeof value.index === "number"; +} + function throwIfExceedsStreamLimit( label: string, currentLength: number, @@ -210,6 +243,8 @@ export class DatabricksAdapter implements AgentAdapter { } else { const { endpointUrl, authenticate } = options; this.streamBody = async (body, signal) => { + const fetchSignal = + signal ?? AbortSignal.timeout(RAW_FETCH_DEFAULT_TIMEOUT_MS); const authHeaders = await authenticate(); const response = await fetch(endpointUrl, { method: "POST", @@ -218,7 +253,7 @@ export class DatabricksAdapter implements AgentAdapter { ...authHeaders, }, body: JSON.stringify(body), - signal, + signal: fetchSignal, }); if (!response.ok) { const errorText = await response.text().catch(() => "Unknown error"); @@ -448,7 +483,15 @@ export class DatabricksAdapter implements AgentAdapter { body.tools = tools; } - const responseBody = await this.streamBody(body, context.signal); + let responseBody: ReadableStream; + try { + responseBody = await this.streamBody(body, context.signal); + } catch (err) { + const msg = err instanceof Error ? err.message : "Stream request failed"; + yield { type: "status", status: "error", error: msg }; + throw err; + } + const reader = responseBody.getReader(); const decoder = new TextDecoder(); @@ -488,53 +531,61 @@ export class DatabricksAdapter implements AgentAdapter { const data = trimmed.slice(6); if (data === "[DONE]") continue; - let parsed: any; + let parsed: unknown; try { parsed = JSON.parse(data); - } catch { + } catch (parseErr) { + console.debug( + "[DatabricksAdapter] malformed SSE data line JSON", + { line: `${data.slice(0, 256)}${data.length > 256 ? "…" : ""}` }, + parseErr, + ); continue; } - const delta = parsed.choices?.[0]?.delta; - if (!delta) continue; + const deltaUnknown = openAiChoicesDelta(parsed); + if (!isRecord(deltaUnknown)) continue; - if (delta.content) { + if (typeof deltaUnknown.content === "string") { + const content = deltaUnknown.content; throwIfExceedsStreamLimit( "streamed assistant text", fullText.length, - delta.content, + content, this.maxStreamTextChars, ); - fullText += delta.content; - yield { type: "message_delta" as const, content: delta.content }; + fullText += content; + yield { type: "message_delta" as const, content }; } - if (delta.tool_calls) { - for (const tc of delta.tool_calls as DeltaToolCall[]) { - const existing = toolCallAccumulator.get(tc.index); - if (existing) { - if (tc.function?.arguments) { - throwIfExceedsStreamLimit( - "tool call arguments", - existing.arguments.length, - tc.function.arguments, - this.maxToolArgumentsChars, - ); - existing.arguments += tc.function.arguments; - } - } else { - const initial = tc.function?.arguments ?? ""; - if (initial.length > this.maxToolArgumentsChars) { - throw new Error( - `DatabricksAdapter: tool call arguments exceed configured limit (${this.maxToolArgumentsChars} UTF-16 code units)`, - ); - } - toolCallAccumulator.set(tc.index, { - id: tc.id ?? `call_${tc.index}`, - name: tc.function?.name ?? "", - arguments: initial, - }); + const toolCallsRaw = deltaUnknown.tool_calls; + if (!Array.isArray(toolCallsRaw)) continue; + + for (const tc of toolCallsRaw) { + if (!isStreamingDeltaToolCall(tc)) continue; + const existing = toolCallAccumulator.get(tc.index); + if (existing) { + if (tc.function?.arguments) { + throwIfExceedsStreamLimit( + "tool call arguments", + existing.arguments.length, + tc.function.arguments, + this.maxToolArgumentsChars, + ); + existing.arguments += tc.function.arguments; } + } else { + const initial = tc.function?.arguments ?? ""; + if (initial.length > this.maxToolArgumentsChars) { + throw new Error( + `DatabricksAdapter: tool call arguments exceed configured limit (${this.maxToolArgumentsChars} UTF-16 code units)`, + ); + } + toolCallAccumulator.set(tc.index, { + id: tc.id ?? `call_${tc.index}`, + name: tc.function?.name ?? "", + arguments: initial, + }); } } } @@ -542,13 +593,19 @@ export class DatabricksAdapter implements AgentAdapter { } finally { try { await reader.cancel(); - } catch { - // Best-effort: reader may already be closed or the stream errored. + } catch (cancelErr) { + console.debug( + "[DatabricksAdapter] reader.cancel() failed during teardown", + cancelErr, + ); } try { reader.releaseLock(); - } catch { - // Lock may already be released after cancel. + } catch (unlockErr) { + console.debug( + "[DatabricksAdapter] reader.releaseLock() failed during teardown", + unlockErr, + ); } } @@ -684,27 +741,30 @@ export function parseTextToolCalls( return []; } +function isLlamaToolJsonItem(value: unknown): value is Record< + string, + unknown +> & { + name: string; +} { + if (!isRecord(value)) return false; + return typeof value.name === "string"; +} + function tryParseLlamaJsonToolCalls( text: string, ): Array<{ name: string; args: unknown }> { - const match = text.match(/\[\s*\{[\s\S]*\}\s*\]/); - if (!match) return []; + const slice = extractLlamaToolJsonSlice(text); + if (!slice) return []; try { - const parsed = JSON.parse(match[0]); + const parsed: unknown = JSON.parse(slice); if (!Array.isArray(parsed)) return []; - return parsed - .filter( - (item: any) => - typeof item === "object" && - item !== null && - typeof item.name === "string", - ) - .map((item: any) => ({ - name: item.name, - args: item.parameters ?? item.arguments ?? item.args ?? {}, - })); + return parsed.filter(isLlamaToolJsonItem).map((item) => ({ + name: item.name, + args: item.parameters ?? item.arguments ?? item.args ?? {}, + })); } catch { return []; } @@ -713,6 +773,10 @@ function tryParseLlamaJsonToolCalls( function tryParsePythonStyleToolCalls( text: string, ): Array<{ name: string; args: unknown }> { + if (text.length > PYTHON_STYLE_TOOL_PARSE_MAX_INPUT) { + return []; + } + const pattern = /\[?([a-zA-Z_][\w.]*)\(([^)]*)\)\]?/g; const results: Array<{ name: string; args: unknown }> = []; diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts index 20a1b4cc0..fd51bc0fc 100644 --- a/packages/appkit/src/agents/tests/databricks.test.ts +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -557,6 +557,138 @@ describe("DatabricksAdapter", () => { } }).rejects.toThrow("Databricks API error (401): Unauthorized"); }); + + test("yields error status then throws when injected streamBody fails", async () => { + const adapter = new DatabricksAdapter({ + streamBody: async () => Promise.reject(new Error("serving_unreachable")), + maxSteps: 1, + }); + + const events: AgentEvent[] = []; + await expect(async () => { + for await (const ev of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + }).rejects.toThrow("serving_unreachable"); + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ + type: "status", + status: "error", + error: "serving_unreachable", + }); + }); + + test("yields tool_result with error when executeTool rejects", async () => { + const executeTool = vi.fn().mockRejectedValue(new Error("tool_denied")); + + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream([ + toolCallDelta( + 0, + "call_fail", + "analytics__query", + '{"query":"SELECT 2"}', + ), + sseChunk("[DONE]"), + ]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxSteps: 1 }); + const events: AgentEvent[] = []; + + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + events.push(ev); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "call_fail", + name: "analytics.query", + args: { query: "SELECT 2" }, + }); + + expect(events).toContainEqual({ + type: "tool_result", + callId: "call_fail", + result: null, + error: "tool_denied", + }); + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 2", + }); + }); + + test("uses AbortSignal.timeout for raw fetch when context has no signal", async () => { + globalThis.fetch = mockFetch([textDelta("Hello"), sseChunk("[DONE]")]); + + const ac = new AbortController(); + const timeoutSpy = vi + .spyOn(AbortSignal, "timeout") + .mockReturnValue(ac.signal); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn(), signal: undefined }, + )) { + // drain + } + + expect(timeoutSpy).toHaveBeenCalledWith(120_000); + timeoutSpy.mockRestore(); + }); + + test("logs and skips malformed JSON in SSE lines", async () => { + const debugSpy = vi.spyOn(console, "debug").mockImplementation(() => {}); + globalThis.fetch = mockFetch([ + sseChunk("{not-json-truncated"), + textDelta("ok"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + + expect( + debugSpy.mock.calls.some(([msg]) => { + return typeof msg === "string" && msg.includes("malformed SSE"); + }), + ).toBe(true); + expect( + events.some((e) => e.type === "message_delta" && e.content === "ok"), + ).toBe(true); + debugSpy.mockRestore(); + }); }); describe("DatabricksAdapter.fromServingEndpoint", () => { @@ -740,4 +872,11 @@ describe("parseTextToolCalls", () => { { name: "lakebase.query", args: { text: "SELECT 1" } }, ]); }); + + test("returns empty when Python-style fallback text exceeds size cap", () => { + const cap = 64 * 1024; + const filler = "x".repeat(cap); + const suffix = "[analytics.query(query='SELECT 1')]"; + expect(parseTextToolCalls(`${filler}${suffix}`)).toEqual([]); + }); }); diff --git a/packages/shared/src/agent.ts b/packages/shared/src/agent.ts index c4f76b294..ef532c7c7 100644 --- a/packages/shared/src/agent.ts +++ b/packages/shared/src/agent.ts @@ -200,6 +200,7 @@ export interface AgentInput { } export interface AgentRunContext { + /** Tool implementations should sanitize failure text — errors become `tool_result.error` and can flow back into the LLM transcript. */ executeTool: (name: string, args: unknown) => Promise; signal?: AbortSignal; }