diff --git a/.changeset/nine-zoos-buy.md b/.changeset/nine-zoos-buy.md new file mode 100644 index 000000000..d27a99422 --- /dev/null +++ b/.changeset/nine-zoos-buy.md @@ -0,0 +1,5 @@ +--- +'@hyperdx/api': patch +--- + +feat: Add OpenAI provider support for AI assistance diff --git a/packages/api/package.json b/packages/api/package.json index 740fac1df..50b376493 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -8,6 +8,7 @@ }, "dependencies": { "@ai-sdk/anthropic": "^3.0.58", + "@ai-sdk/openai": "^3.0.47", "@esm2cjs/p-queue": "^7.3.0", "@hyperdx/common-utils": "^0.16.1", "@hyperdx/node-opentelemetry": "^0.9.0", diff --git a/packages/api/src/config.ts b/packages/api/src/config.ts index a0c85135d..6ba78de6c 100644 --- a/packages/api/src/config.ts +++ b/packages/api/src/config.ts @@ -53,6 +53,7 @@ export const AI_PROVIDER = env.AI_PROVIDER as string; // 'anthropic' | 'openai' export const AI_API_KEY = env.AI_API_KEY as string; export const AI_BASE_URL = env.AI_BASE_URL as string; export const AI_MODEL_NAME = env.AI_MODEL_NAME as string; +export const AI_REQUEST_HEADERS = env.AI_REQUEST_HEADERS as string; // Legacy Anthropic-specific configuration (backward compatibility) export const ANTHROPIC_API_KEY = env.ANTHROPIC_API_KEY as string; diff --git a/packages/api/src/controllers/__tests__/ai.test.ts b/packages/api/src/controllers/__tests__/ai.test.ts new file mode 100644 index 000000000..a2e9e4e0a --- /dev/null +++ b/packages/api/src/controllers/__tests__/ai.test.ts @@ -0,0 +1,252 @@ +import type { LanguageModel } from 'ai'; + +const mockAnthropicModel = { + modelId: 'claude-sonnet-4-5-20250929', +} as unknown as LanguageModel; + +const mockOpenAIModel = { + modelId: 'gpt-4o', +} as unknown as LanguageModel; + +const mockAnthropicFactory = jest.fn((_model?: string) => mockAnthropicModel); +const mockCreateAnthropic = jest.fn( + (_opts?: Record) => mockAnthropicFactory, +); + +const mockOpenAIChatFactory = jest.fn((_model?: string) => mockOpenAIModel); +const mockCreateOpenAI = jest.fn((_opts?: Record) => ({ + chat: mockOpenAIChatFactory, +})); + +jest.mock('@ai-sdk/anthropic', () => ({ + createAnthropic: (opts: Record) => mockCreateAnthropic(opts), +})); + +jest.mock('@ai-sdk/openai', () => ({ + createOpenAI: (opts: Record) => mockCreateOpenAI(opts), +})); + +jest.mock('@/utils/logger', () => ({ + __esModule: true, + default: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }, +})); + +const mockConfig: Record = { __esModule: true }; + +jest.mock('@/config', () => mockConfig); + +function setConfig(overrides: Record) { + Object.keys(mockConfig).forEach(k => { + if (k !== '__esModule') delete mockConfig[k]; + }); + Object.assign(mockConfig, overrides); +} + +import { getAIModel } from '@/controllers/ai'; + +beforeEach(() => { + setConfig({}); + jest.clearAllMocks(); +}); + +describe('getAIModel', () => { + describe('provider routing', () => { + it('throws when no provider is configured', () => { + expect(() => getAIModel()).toThrow( + 'No AI provider configured. Set AI_PROVIDER and AI_API_KEY environment variables.', + ); + }); + + it('throws on unknown provider', () => { + setConfig({ AI_PROVIDER: 'gemini' }); + expect(() => getAIModel()).toThrow( + 'Unknown AI provider: gemini. Currently supported: anthropic, openai', + ); + }); + + it('routes to anthropic when AI_PROVIDER=anthropic', () => { + setConfig({ + AI_PROVIDER: 'anthropic', + AI_API_KEY: 'sk-test', + }); + const model = getAIModel(); + expect(model).toBe(mockAnthropicModel); + expect(mockCreateAnthropic).toHaveBeenCalledTimes(1); + }); + + it('routes to openai when AI_PROVIDER=openai', () => { + setConfig({ + AI_PROVIDER: 'openai', + AI_API_KEY: 'sk-test', + AI_MODEL_NAME: 'gpt-4o', + }); + const model = getAIModel(); + expect(model).toBe(mockOpenAIModel); + expect(mockCreateOpenAI).toHaveBeenCalledTimes(1); + }); + }); + + describe('legacy anthropic support', () => { + it('falls back to anthropic when ANTHROPIC_API_KEY is set without AI_PROVIDER', () => { + setConfig({ + ANTHROPIC_API_KEY: 'sk-ant-legacy', + }); + const model = getAIModel(); + expect(model).toBe(mockAnthropicModel); + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: 'sk-ant-legacy' }), + ); + }); + }); +}); + +describe('anthropic provider', () => { + it('throws when no API key is set', () => { + setConfig({ AI_PROVIDER: 'anthropic' }); + expect(() => getAIModel()).toThrow( + 'No API key defined for Anthropic. Set AI_API_KEY or ANTHROPIC_API_KEY.', + ); + }); + + it('uses AI_API_KEY over ANTHROPIC_API_KEY', () => { + setConfig({ + AI_PROVIDER: 'anthropic', + AI_API_KEY: 'sk-new', + ANTHROPIC_API_KEY: 'sk-old', + }); + getAIModel(); + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: 'sk-new' }), + ); + }); + + it('passes baseURL when AI_BASE_URL is set', () => { + setConfig({ + AI_PROVIDER: 'anthropic', + AI_API_KEY: 'sk-test', + AI_BASE_URL: 'https://custom.endpoint.com', + }); + getAIModel(); + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: 'sk-test', + baseURL: 'https://custom.endpoint.com', + }), + ); + }); + + it('uses default model when AI_MODEL_NAME is not set', () => { + setConfig({ + AI_PROVIDER: 'anthropic', + AI_API_KEY: 'sk-test', + }); + getAIModel(); + expect(mockAnthropicFactory).toHaveBeenCalledWith( + 'claude-sonnet-4-5-20250929', + ); + }); + + it('uses custom model name when AI_MODEL_NAME is set', () => { + setConfig({ + AI_PROVIDER: 'anthropic', + AI_API_KEY: 'sk-test', + AI_MODEL_NAME: 'claude-3-haiku-20240307', + }); + getAIModel(); + expect(mockAnthropicFactory).toHaveBeenCalledWith( + 'claude-3-haiku-20240307', + ); + }); +}); + +describe('openai provider', () => { + it('throws when no API key is set', () => { + setConfig({ AI_PROVIDER: 'openai' }); + expect(() => getAIModel()).toThrow( + 'No API key defined for OpenAI provider. Set AI_API_KEY.', + ); + }); + + it('throws when no model name is set', () => { + setConfig({ + AI_PROVIDER: 'openai', + AI_API_KEY: 'sk-test', + }); + expect(() => getAIModel()).toThrow( + 'No model name configured for OpenAI provider. Set AI_MODEL_NAME', + ); + }); + + it('creates provider with minimal config', () => { + setConfig({ + AI_PROVIDER: 'openai', + AI_API_KEY: 'sk-test', + AI_MODEL_NAME: 'gpt-4o', + }); + getAIModel(); + expect(mockCreateOpenAI).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: 'sk-test' }), + ); + expect(mockOpenAIChatFactory).toHaveBeenCalledWith('gpt-4o'); + }); + + it('passes baseURL when AI_BASE_URL is set', () => { + setConfig({ + AI_PROVIDER: 'openai', + AI_API_KEY: 'sk-test', + AI_MODEL_NAME: 'gpt-4o', + AI_BASE_URL: 'https://proxy.example.com/v1', + }); + getAIModel(); + expect(mockCreateOpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: 'sk-test', + baseURL: 'https://proxy.example.com/v1', + }), + ); + }); + + describe('AI_REQUEST_HEADERS', () => { + it('passes parsed headers to createOpenAI', () => { + setConfig({ + AI_PROVIDER: 'openai', + AI_API_KEY: 'sk-test', + AI_MODEL_NAME: 'gpt-4o', + AI_REQUEST_HEADERS: '{"X-Custom":"val1","X-Other":"val2"}', + }); + getAIModel(); + expect(mockCreateOpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + headers: { 'X-Custom': 'val1', 'X-Other': 'val2' }, + }), + ); + }); + + it('throws when AI_REQUEST_HEADERS is invalid JSON', () => { + setConfig({ + AI_PROVIDER: 'openai', + AI_API_KEY: 'sk-test', + AI_MODEL_NAME: 'gpt-4o', + AI_REQUEST_HEADERS: '{bad', + }); + expect(() => getAIModel()).toThrow( + 'AI_REQUEST_HEADERS is not valid JSON', + ); + }); + + it('omits headers when AI_REQUEST_HEADERS is not set', () => { + setConfig({ + AI_PROVIDER: 'openai', + AI_API_KEY: 'sk-test', + AI_MODEL_NAME: 'gpt-4o', + }); + getAIModel(); + const call = mockCreateOpenAI.mock.calls[0]?.[0]; + expect(call?.headers).toBeUndefined(); + }); + }); +}); diff --git a/packages/api/src/controllers/ai.ts b/packages/api/src/controllers/ai.ts index 2dced8623..aa2c4680c 100644 --- a/packages/api/src/controllers/ai.ts +++ b/packages/api/src/controllers/ai.ts @@ -1,4 +1,5 @@ import { createAnthropic } from '@ai-sdk/anthropic'; +import { createOpenAI } from '@ai-sdk/openai'; import { ClickhouseClient } from '@hyperdx/common-utils/dist/clickhouse/node'; import { getMetadata, @@ -16,6 +17,7 @@ import z from 'zod'; import * as config from '@/config'; import { ISource } from '@/models/source'; +import { parseJSON } from '@/utils/common'; import { Api500Error } from '@/utils/errors'; import logger from '@/utils/logger'; @@ -60,14 +62,11 @@ export function getAIModel(): LanguageModel { return getAnthropicModel(); case 'openai': - throw new Error( - `Provider '${provider}' is not yet supported. Currently only 'anthropic' is available. ` + - 'Support for additional providers can be added in the future.', - ); + return getOpenAIModel(); default: throw new Error( - `Unknown AI provider: ${provider}. Currently supported: anthropic`, + `Unknown AI provider: ${provider}. Currently supported: anthropic, openai`, ); } } @@ -367,3 +366,38 @@ function getAnthropicModel(): LanguageModel { return anthropic(modelName); } + +/** + * Configure OpenAI-compatible model. + * Works with any OpenAI Chat Completions-compatible endpoint + * (e.g. Azure OpenAI, OpenRouter, LiteLLM proxies). + */ +function getOpenAIModel(): LanguageModel { + const apiKey = config.AI_API_KEY; + + if (!apiKey) { + throw new Error('No API key defined for OpenAI provider. Set AI_API_KEY.'); + } + + if (!config.AI_MODEL_NAME) { + throw new Error( + 'No model name configured for OpenAI provider. Set AI_MODEL_NAME ' + + '(e.g. "gpt-4o", "claude-sonnet-4-5-20250929" for LiteLLM proxies).', + ); + } + + const headers: Record = config.AI_REQUEST_HEADERS + ? parseJSON>( + config.AI_REQUEST_HEADERS, + 'AI_REQUEST_HEADERS', + ) + : {}; + + const openai = createOpenAI({ + apiKey, + ...(config.AI_BASE_URL && { baseURL: config.AI_BASE_URL }), + ...(Object.keys(headers).length > 0 && { headers }), + }); + + return openai.chat(config.AI_MODEL_NAME); +} diff --git a/packages/api/src/routers/api/ai.ts b/packages/api/src/routers/api/ai.ts index 2c748ccf2..b8070c3a0 100644 --- a/packages/api/src/routers/api/ai.ts +++ b/packages/api/src/routers/api/ai.ts @@ -110,7 +110,9 @@ ${JSON.stringify(allFieldsWithKeys.slice(0, 200).map(f => ({ field: f.key, type: return res.json(chartConfig); } catch (err) { if (err instanceof APICallError) { - throw new Api500Error(`AI Provider Error: ${err.message}`); + throw new Api500Error( + `AI Provider Error. Status: ${err.statusCode}. Message: ${err.message}`, + ); } throw err; } diff --git a/packages/api/src/utils/common.ts b/packages/api/src/utils/common.ts index 7cde89296..882d2f343 100644 --- a/packages/api/src/utils/common.ts +++ b/packages/api/src/utils/common.ts @@ -27,6 +27,14 @@ export const tryJSONStringify = (json: Json) => { return result; }; +export function parseJSON(raw: string, label: string): T { + try { + return JSON.parse(raw) as T; + } catch (e) { + throw new Error(`${label} is not valid JSON: ${(e as Error).message}`); + } +} + export const truncateString = (str: string, length: number) => { if (str.length > length) { return str.substring(0, length) + '...'; diff --git a/yarn.lock b/yarn.lock index 372f7119f..66180c355 100644 --- a/yarn.lock +++ b/yarn.lock @@ -44,6 +44,18 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/openai@npm:^3.0.47": + version: 3.0.48 + resolution: "@ai-sdk/openai@npm:3.0.48" + dependencies: + "@ai-sdk/provider": "npm:3.0.8" + "@ai-sdk/provider-utils": "npm:4.0.21" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/2aaf85fba5ec98e892a41b276a6d2dff9f728f1f149293f437f15daae06a894886730cc9a6d9be39f3f8ba1679d07635c9e9dac5b0013ac3549658a4ffad6638 + languageName: node + linkType: hard + "@ai-sdk/provider-utils@npm:4.0.20": version: 4.0.20 resolution: "@ai-sdk/provider-utils@npm:4.0.20" @@ -57,6 +69,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider-utils@npm:4.0.21": + version: 4.0.21 + resolution: "@ai-sdk/provider-utils@npm:4.0.21" + dependencies: + "@ai-sdk/provider": "npm:3.0.8" + "@standard-schema/spec": "npm:^1.1.0" + eventsource-parser: "npm:^3.0.6" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/70d19cfefad32865f098d0e6e53342408363929d30151eeb072f90c2d7661b24e4a9bd9ac39d75352aeeee8dffeb931c49b85cafcd953c67336c895497c71cfa + languageName: node + linkType: hard + "@ai-sdk/provider@npm:3.0.8": version: 3.0.8 resolution: "@ai-sdk/provider@npm:3.0.8" @@ -4264,6 +4289,7 @@ __metadata: resolution: "@hyperdx/api@workspace:packages/api" dependencies: "@ai-sdk/anthropic": "npm:^3.0.58" + "@ai-sdk/openai": "npm:^3.0.47" "@esm2cjs/p-queue": "npm:^7.3.0" "@hyperdx/common-utils": "npm:^0.16.1" "@hyperdx/node-opentelemetry": "npm:^0.9.0"