diff --git a/src/orchestrator/batch.ts b/src/orchestrator/batch.ts index 239c6ae..5febfdd 100644 --- a/src/orchestrator/batch.ts +++ b/src/orchestrator/batch.ts @@ -35,6 +35,7 @@ export interface CompareOptions { judgeModel: string answeringModel: string sampling?: SamplingConfig + questionIds?: string[] force?: boolean } @@ -146,7 +147,7 @@ export class BatchManager { } async createManifest(options: CompareOptions): Promise { - const { providers, benchmark, judgeModel, answeringModel, sampling } = options + const { providers, benchmark, judgeModel, answeringModel, sampling, questionIds } = options const compareId = generateCompareId() logger.info(`Loading benchmark: ${benchmark}`) @@ -155,7 +156,37 @@ export class BatchManager { const allQuestions = benchmarkInstance.getQuestions() let targetQuestionIds: string[] - if (sampling) { + if (questionIds && questionIds.length > 0) { + // Validate that all provided IDs exist in the benchmark + const allQuestionIdsSet = new Set(allQuestions.map((q) => q.questionId)) + const validIds: string[] = [] + const invalidIds: string[] = [] + + for (const id of questionIds) { + if (allQuestionIdsSet.has(id)) { + validIds.push(id) + } else { + invalidIds.push(id) + } + } + + if (invalidIds.length > 0) { + logger.warn(`Invalid question IDs (will be skipped): ${invalidIds.join(", ")}`) + } + + if (validIds.length === 0) { + throw new Error( + `All provided questionIds are invalid. No matching questions found in benchmark "${benchmark}". ` + + `Invalid IDs: ${invalidIds.join(", ")}` + ) + } + + targetQuestionIds = validIds + logger.info( + `Using explicit questionIds: ${validIds.length} valid questions` + + (invalidIds.length > 0 ? ` (${invalidIds.length} invalid skipped)` : "") + ) + } else if (sampling) { targetQuestionIds = selectQuestionsBySampling(allQuestions, sampling) } else { targetQuestionIds = allQuestions.map((q) => q.questionId) diff --git a/src/orchestrator/index.ts b/src/orchestrator/index.ts index 64578bb..f19123b 100644 --- a/src/orchestrator/index.ts +++ b/src/orchestrator/index.ts @@ -213,8 +213,35 @@ export class Orchestrator { effectiveLimit = limit if (questionIds && questionIds.length > 0) { - logger.info(`Using explicit questionIds: ${questionIds.length} questions`) - targetQuestionIds = questionIds + // Validate that all provided IDs exist in the benchmark + const allQuestionIdsSet = new Set(allQuestions.map((q) => q.questionId)) + const validIds: string[] = [] + const invalidIds: string[] = [] + + for (const id of questionIds) { + if (allQuestionIdsSet.has(id)) { + validIds.push(id) + } else { + invalidIds.push(id) + } + } + + if (invalidIds.length > 0) { + logger.warn(`Invalid question IDs (will be skipped): ${invalidIds.join(", ")}`) + } + + if (validIds.length === 0) { + throw new Error( + `All provided questionIds are invalid. No matching questions found in benchmark "${benchmarkName}". ` + + `Invalid IDs: ${invalidIds.join(", ")}` + ) + } + + targetQuestionIds = validIds + logger.info( + `Using explicit questionIds: ${validIds.length} valid questions` + + (invalidIds.length > 0 ? ` (${invalidIds.length} invalid skipped)` : "") + ) } else if (sampling) { logger.info(`Using sampling mode: ${sampling.mode}`) targetQuestionIds = selectQuestionsBySampling(allQuestions, sampling) diff --git a/src/server/routes/benchmarks.ts b/src/server/routes/benchmarks.ts index af0b377..6860315 100644 --- a/src/server/routes/benchmarks.ts +++ b/src/server/routes/benchmarks.ts @@ -128,6 +128,75 @@ export async function handleBenchmarksRoutes(req: Request, url: URL): Promise() + const patternResults: Record = {} + + for (const pattern of patterns) { + const trimmed = pattern.trim() + if (!trimmed) continue + + const expanded: string[] = [] + + // Pattern 1: Conversation ID (e.g., "conv-26") - expand to all questions + // Check if pattern ends with a number and doesn't have -q or -session suffix + if (/^[a-zA-Z]+-\d+$/.test(trimmed)) { + const matchingQuestions = allQuestions.filter((q) => + q.questionId.startsWith(trimmed + "-q") + ) + matchingQuestions.forEach((q) => { + expanded.push(q.questionId) + expandedIds.add(q.questionId) + }) + } + // Pattern 2: Session ID (e.g., "conv-26-session_1" or "001be529-session-0") + // Find all questions that reference this session + else if (trimmed.includes("-session")) { + const matchingQuestions = allQuestions.filter((q) => + q.haystackSessionIds.includes(trimmed) + ) + matchingQuestions.forEach((q) => { + expanded.push(q.questionId) + expandedIds.add(q.questionId) + }) + } + // Pattern 3: Direct question ID - add as-is if it exists + else { + const exactMatch = allQuestions.find((q) => q.questionId === trimmed) + if (exactMatch) { + expanded.push(trimmed) + expandedIds.add(trimmed) + } + } + + patternResults[pattern] = expanded + } + + return json({ + expandedIds: Array.from(expandedIds), + patternResults, + }) + } catch (e) { + return json({ error: e instanceof Error ? e.message : "Failed to expand IDs" }, 400) + } + } + // GET /api/models - List available models if (method === "GET" && pathname === "/api/models") { const openai = listModelsByProvider("openai").map((alias) => ({ diff --git a/src/server/routes/compare.ts b/src/server/routes/compare.ts index c31589f..7c81f75 100644 --- a/src/server/routes/compare.ts +++ b/src/server/routes/compare.ts @@ -146,7 +146,8 @@ export async function handleCompareRoutes(req: Request, url: URL): Promise { // Only await manifest creation - this is fast diff --git a/src/server/routes/runs.ts b/src/server/routes/runs.ts index 1aaab7b..5002cfd 100644 --- a/src/server/routes/runs.ts +++ b/src/server/routes/runs.ts @@ -190,12 +190,14 @@ export async function handleRunsRoutes(req: Request, url: URL): Promise + } | null>(null) const compareIdInputRef = useRef(null) useEffect(() => { @@ -72,6 +83,70 @@ export default function NewComparePage() { } } + async function validateQuestionIds( + benchmark: string, + questionIdsInput: string + ): Promise<{ + valid: string[] + invalid: string[] + total: number + expanded: string[] + patternResults: Record + }> { + // Parse input: split by comma, trim, remove duplicates + const inputPatterns = questionIdsInput + .split(",") + .map((id) => id.trim()) + .filter((id) => id.length > 0) + const uniquePatterns = [...new Set(inputPatterns)] + + // Call pattern expansion endpoint + const expansionResult = await expandQuestionIdPatterns(benchmark, uniquePatterns) + const expandedIds = expansionResult.expandedIds + + // Fetch all questions to validate expanded IDs exist + const allQuestionIds = new Set() + let page = 1 + let hasMore = true + + while (hasMore) { + const response = await getBenchmarkQuestions(benchmark, { + page, + limit: 100, + }) + response.questions.forEach((q) => allQuestionIds.add(q.questionId)) + hasMore = page < response.pagination.totalPages + page++ + } + + // Validate expanded IDs + const valid: string[] = [] + const invalid: string[] = [] + + expandedIds.forEach((id) => { + if (allQuestionIds.has(id)) { + valid.push(id) + } else { + invalid.push(id) + } + }) + + // Find patterns that didn't expand to anything + const patternsWithNoResults = uniquePatterns.filter( + (pattern) => + !expansionResult.patternResults[pattern] || + expansionResult.patternResults[pattern].length === 0 + ) + + return { + valid, + invalid: [...invalid, ...patternsWithNoResults], + total: uniquePatterns.length, + expanded: expandedIds, + patternResults: expansionResult.patternResults, + } + } + function generateCompareId() { const now = new Date() const date = now.toISOString().slice(0, 10).replace(/-/g, "") @@ -93,6 +168,7 @@ export default function NewComparePage() { const compareId = form.compareId || generateCompareId() let sampling: SamplingConfig | undefined + let questionIds: string[] | undefined if (form.selectionMode === "full") { sampling = { mode: "full" } } else if (form.selectionMode === "sample") { @@ -107,6 +183,20 @@ export default function NewComparePage() { mode: "limit", limit: parseInt(form.limit), } + } else if (form.selectionMode === "questionIds") { + if (!form.questionIds.trim()) { + setError("Please enter at least one pattern or question ID") + return + } + + // Require validation before submission (MANDATORY) + if (!questionIdValidation || questionIdValidation.invalid.length > 0) { + setError("Please validate patterns before starting the comparison") + return + } + + // Use the expanded question IDs from validation + questionIds = questionIdValidation.expanded } try { @@ -120,6 +210,7 @@ export default function NewComparePage() { judgeModel: form.judgeModel, answeringModel: form.answeringModel, sampling, + questionIds, }) router.push(`/compare`) @@ -273,14 +364,22 @@ export default function NewComparePage() { Question Selection
- {(["full", "sample", "limit"] as SelectionMode[]).map((mode) => { + {(["full", "sample", "limit", "questionIds"] as SelectionMode[]).map((mode) => { const isSelected = form.selectionMode === mode - const labels = { full: "Full", sample: "Sample", limit: "Limit" } + const labels = { + full: "Full", + sample: "Sample", + limit: "Limit", + questionIds: "IDs", + } return (
)} + + {form.selectionMode === "questionIds" && ( +
+
+ +