diff --git a/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts b/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts
index 10a919a3..4cc12ca2 100644
--- a/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts
+++ b/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts
@@ -77,7 +77,7 @@ import type {
 } from "@workglow/ai";
 import { CallbackStatus } from "./HFT_CallbackStatus";
 import { HTF_CACHE_NAME } from "./HFT_Constants";
-import { HfTransformersOnnxModelRecord } from "./HFT_ModelSchema";
+import { HfTransformersOnnxModelConfig } from "./HFT_ModelSchema";
 
 const pipelines = new Map();
 
@@ -93,12 +93,12 @@ export function clearPipelineCache(): void {
  * @param progressScaleMax - Maximum progress value for download phase (100 for download-only, 10 for download+run)
  */
 const getPipeline = async (
-  model: HfTransformersOnnxModelRecord,
+  model: HfTransformersOnnxModelConfig,
   onProgress: (progress: number, message?: string, details?: any) => void,
   options: PretrainedModelOptions = {},
   progressScaleMax: number = 10
 ) => {
-  const cacheKey = `${model.model_id}:${model.providerConfig.pipeline}`;
+  const cacheKey = `${model.providerConfig.modelPath}:${model.providerConfig.pipeline}`;
   if (pipelines.has(cacheKey)) {
     return pipelines.get(cacheKey);
   }
@@ -433,7 +433,7 @@
 export const HFT_Download: AiProviderRunFn<
   DownloadModelTaskExecuteInput,
   DownloadModelTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   // Download the model by creating a pipeline
   // Use 100 as progressScaleMax since this is download-only (0-100%)
@@ -451,11 +451,12 @@
 export const HFT_Unload: AiProviderRunFn<
   UnloadModelTaskExecuteInput,
   UnloadModelTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   // Delete the pipeline from the in-memory map
-  if (pipelines.has(model!.model_id)) {
-    pipelines.delete(model!.model_id);
+  const cacheKey = `${model!.providerConfig.modelPath}:${model!.providerConfig.pipeline}`;
+  if (pipelines.has(cacheKey)) {
+    pipelines.delete(cacheKey);
     onProgress(50, "Pipeline removed from memory");
   }
 
@@ -515,7 +516,7 @@ const deleteModelCache = async (modelPath: string): Promise<void> => {
 export const HFT_TextEmbedding: AiProviderRunFn<
   TextEmbeddingTaskExecuteInput,
   TextEmbeddingTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const generateEmbedding: FeatureExtractionPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -546,7 +547,7 @@
 export const HFT_TextClassification: AiProviderRunFn<
   TextClassificationTaskExecuteInput,
   TextClassificationTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   if (model?.providerConfig?.pipeline === "zero-shot-classification") {
     if (
@@ -602,7 +603,7 @@
 export const HFT_TextLanguageDetection: AiProviderRunFn<
   TextLanguageDetectionTaskExecuteInput,
   TextLanguageDetectionTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const TextClassification: TextClassificationPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -632,7 +633,7 @@ export const HFT_TextLanguageDetection: AiProviderRunFn<
 export const HFT_TextNamedEntityRecognition: AiProviderRunFn<
   TextNamedEntityRecognitionTaskExecuteInput,
   TextNamedEntityRecognitionTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const textNamedEntityRecognition: TokenClassificationPipeline = await getPipeline(
     model!,
@@ -663,7 +664,7 @@ export const HFT_TextNamedEntityRecognition: AiProviderRunFn<
 export const HFT_TextFillMask: AiProviderRunFn<
   TextFillMaskTaskExecuteInput,
   TextFillMaskTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const unmasker: FillMaskPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -691,7 +692,7 @@ export const HFT_TextFillMask: AiProviderRunFn<
 export const HFT_TextGeneration: AiProviderRunFn<
   TextGenerationTaskExecuteInput,
   TextGenerationTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -724,7 +725,7 @@ export const HFT_TextGeneration: AiProviderRunFn<
 export const HFT_TextTranslation: AiProviderRunFn<
   TextTranslationTaskExecuteInput,
   TextTranslationTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const translate: TranslationPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -758,7 +759,7 @@ export const HFT_TextTranslation: AiProviderRunFn<
 export const HFT_TextRewriter: AiProviderRunFn<
   TextRewriterTaskExecuteInput,
   TextRewriterTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -798,7 +799,7 @@ export const HFT_TextRewriter: AiProviderRunFn<
 export const HFT_TextSummary: AiProviderRunFn<
   TextSummaryTaskExecuteInput,
   TextSummaryTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const generateSummary: SummarizationPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -829,7 +830,7 @@ export const HFT_TextSummary: AiProviderRunFn<
 export const HFT_TextQuestionAnswer: AiProviderRunFn<
   TextQuestionAnswerTaskExecuteInput,
   TextQuestionAnswerTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   // Get the question answering pipeline
   const generateAnswer: QuestionAnsweringPipeline = await getPipeline(model!, onProgress, {
@@ -860,7 +861,7 @@ export const HFT_TextQuestionAnswer: AiProviderRunFn<
 export const HFT_ImageSegmentation: AiProviderRunFn<
   ImageSegmentationTaskExecuteInput,
   ImageSegmentationTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const segmenter: ImageSegmentationPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -893,7 +894,7 @@ export const HFT_ImageSegmentation: AiProviderRunFn<
 export const HFT_ImageToText: AiProviderRunFn<
   ImageToTextTaskExecuteInput,
   ImageToTextTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const captioner: ImageToTextPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -917,7 +918,7 @@ export const HFT_ImageToText: AiProviderRunFn<
 export const HFT_BackgroundRemoval: AiProviderRunFn<
   BackgroundRemovalTaskExecuteInput,
   BackgroundRemovalTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const remover: BackgroundRemovalPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -940,7 +941,7 @@ export const HFT_BackgroundRemoval: AiProviderRunFn<
 export const HFT_ImageEmbedding: AiProviderRunFn<
   ImageEmbeddingTaskExecuteInput,
   ImageEmbeddingTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   const embedder: ImageFeatureExtractionPipeline = await getPipeline(model!, onProgress, {
     abort_signal: signal,
@@ -960,7 +961,7 @@ export const HFT_ImageEmbedding: AiProviderRunFn<
 export const HFT_ImageClassification: AiProviderRunFn<
   ImageClassificationTaskExecuteInput,
   ImageClassificationTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   if (model?.providerConfig?.pipeline === "zero-shot-image-classification") {
     if (!input.categories || !Array.isArray(input.categories) || input.categories.length === 0) {
@@ -1015,7 +1016,7 @@ export const HFT_ImageClassification: AiProviderRunFn<
 export const HFT_ObjectDetection: AiProviderRunFn<
   ObjectDetectionTaskExecuteInput,
   ObjectDetectionTaskExecuteOutput,
-  HfTransformersOnnxModelRecord
+  HfTransformersOnnxModelConfig
 > = async (input, model, onProgress, signal) => {
   if (model?.providerConfig?.pipeline === "zero-shot-object-detection") {
     if (!input.labels || !Array.isArray(input.labels) || input.labels.length === 0) {
@@ -1070,7 +1071,6 @@ function imageToBase64(image: RawImage): string {
   return (image as any).toBase64?.() || "";
 }
 
-
 /**
  * Create a text streamer for a given tokenizer and update progress function
  * @param tokenizer - The tokenizer to use for the streamer
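The cache-key change in getPipeline (and the matching one in HFT_Unload) is the behavioral piece of this file: an inline HfTransformersOnnxModelConfig is not guaranteed to carry a model_id, so the pipeline cache is now keyed on providerConfig.modelPath plus the pipeline name. A minimal sketch of the keying behavior, assuming providerConfig carries string-valued modelPath and pipeline fields; the makePipelineCacheKey helper name is illustrative, not part of the patch:

```ts
import type { HfTransformersOnnxModelConfig } from "./HFT_ModelSchema";

// Two configs that point at the same weights and pipeline share one cached
// pipeline instance, even if their descriptive fields (title, model_id) differ.
function makePipelineCacheKey(model: HfTransformersOnnxModelConfig): string {
  return `${model.providerConfig.modelPath}:${model.providerConfig.pipeline}`;
}
```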
diff --git a/packages/ai-provider/src/hf-transformers/common/HFT_ModelSchema.ts b/packages/ai-provider/src/hf-transformers/common/HFT_ModelSchema.ts
index 9798b505..12022df1 100644
--- a/packages/ai-provider/src/hf-transformers/common/HFT_ModelSchema.ts
+++ b/packages/ai-provider/src/hf-transformers/common/HFT_ModelSchema.ts
@@ -4,7 +4,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import { ModelSchema } from "@workglow/ai";
+import { ModelConfigSchema, ModelRecordSchema } from "@workglow/ai";
 import { DataPortSchemaObject, FromSchema } from "@workglow/util";
 
 import { HF_TRANSFORMERS_ONNX, PipelineUseCase, QuantizationDataType } from "./HFT_Constants";
@@ -89,14 +89,26 @@ export const HfTransformersOnnxModelSchema = {
   additionalProperties: true,
 } as const satisfies DataPortSchemaObject;
 
-const ExtendedModelSchema = {
+const ExtendedModelRecordSchema = {
   type: "object",
   properties: {
-    ...ModelSchema.properties,
+    ...ModelRecordSchema.properties,
     ...HfTransformersOnnxModelSchema.properties,
   },
-  required: [...ModelSchema.required, ...HfTransformersOnnxModelSchema.required],
+  required: [...ModelRecordSchema.required, ...HfTransformersOnnxModelSchema.required],
   additionalProperties: false,
 } as const satisfies DataPortSchemaObject;
 
-export type HfTransformersOnnxModelRecord = FromSchema<typeof ExtendedModelSchema>;
+export type HfTransformersOnnxModelRecord = FromSchema<typeof ExtendedModelRecordSchema>;
+
+const ExtendedModelConfigSchema = {
+  type: "object",
+  properties: {
+    ...ModelConfigSchema.properties,
+    ...HfTransformersOnnxModelSchema.properties,
+  },
+  required: [...ModelConfigSchema.required, ...HfTransformersOnnxModelSchema.required],
+  additionalProperties: false,
+} as const satisfies DataPortSchemaObject;
+
+export type HfTransformersOnnxModelConfig = FromSchema<typeof ExtendedModelConfigSchema>;
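The record/config split above is the pattern each provider now follows: the record type keeps the strict repository requirements, while the config type only needs provider, providerConfig, and whatever the provider schema itself requires. A hedged sketch of an inline config, assuming pipeline and modelPath are the provider-required fields (the model path is illustrative):

```ts
import { HF_TRANSFORMERS_ONNX } from "./HFT_Constants";
import type { HfTransformersOnnxModelConfig } from "./HFT_ModelSchema";

// No model_id, title, description, tasks, or metadata needed at the config
// level; the cast hedges on any extra fields HfTransformersOnnxModelSchema
// may itself mark as required.
const inlineModel = {
  provider: HF_TRANSFORMERS_ONNX,
  providerConfig: {
    pipeline: "text-generation",
    modelPath: "onnx-community/example-model", // illustrative path
  },
} as HfTransformersOnnxModelConfig;
```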
diff --git a/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts b/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts
index e1ce6f21..cb7bfd5b 100644
--- a/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts
+++ b/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts
@@ -53,7 +53,7 @@ import type {
   UnloadModelTaskExecuteOutput,
 } from "@workglow/ai";
 import { PermanentJobError } from "@workglow/job-queue";
-import { TFMPModelRecord } from "./TFMP_ModelSchema";
+import { TFMPModelConfig } from "./TFMP_ModelSchema";
 
 interface TFMPWasmFileset {
   /** The path to the Wasm loader script. */
@@ -82,7 +82,7 @@ const wasm_reference_counts = new Map();
  * Helper function to get a WASM task for a model
  */
 const getWasmTask = async (
-  model: TFMPModelRecord,
+  model: TFMPModelConfig,
   onProgress: (progress: number, message?: string, details?: any) => void,
   signal: AbortSignal
 ): Promise<TFMPWasmFileset> => {
@@ -213,7 +213,7 @@ const optionsMatch = (opts1: Record<string, any>, opts2: Record<string, any>
 const getModelTask = async (
-  model: TFMPModelRecord,
+  model: TFMPModelConfig,
   options: Record<string, any>,
   onProgress: (progress: number, message?: string, details?: any) => void,
   signal: AbortSignal,
@@ -264,7 +264,7 @@ const getModelTask = async (
 export const TFMP_Download: AiProviderRunFn<
   DownloadModelTaskExecuteInput,
   DownloadModelTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   let task: TextEmbedder | TextClassifier | LanguageDetector;
   switch (model?.providerConfig.pipeline) {
@@ -298,7 +298,7 @@ export const TFMP_Download: AiProviderRunFn<
 export const TFMP_TextEmbedding: AiProviderRunFn<
   TextEmbeddingTaskExecuteInput,
   TextEmbeddingTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const textEmbedder = await getModelTask(model!, {}, onProgress, signal, TextEmbedder);
   const result = textEmbedder.embed(input.text);
@@ -321,7 +321,7 @@ export const TFMP_TextEmbedding: AiProviderRunFn<
 export const TFMP_TextClassification: AiProviderRunFn<
   TextClassificationTaskExecuteInput,
   TextClassificationTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const TextClassification = await getModelTask(
     model!,
@@ -358,7 +358,7 @@ export const TFMP_TextClassification: AiProviderRunFn<
 export const TFMP_TextLanguageDetection: AiProviderRunFn<
   TextLanguageDetectionTaskExecuteInput,
   TextLanguageDetectionTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const maxLanguages = input.maxLanguages === 0 ? -1 : input.maxLanguages;
@@ -402,7 +402,7 @@ export const TFMP_TextLanguageDetection: AiProviderRunFn<
 export const TFMP_Unload: AiProviderRunFn<
   UnloadModelTaskExecuteInput,
   UnloadModelTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const modelPath = model!.providerConfig.modelPath;
   onProgress(10, "Unloading model");
@@ -442,7 +442,7 @@ export const TFMP_Unload: AiProviderRunFn<
 export const TFMP_ImageSegmentation: AiProviderRunFn<
   ImageSegmentationTaskExecuteInput,
   ImageSegmentationTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const imageSegmenter = await getModelTask(model!, {}, onProgress, signal, ImageSegmenter);
   const result = imageSegmenter.segment(input.image as any);
@@ -475,7 +475,7 @@ export const TFMP_ImageSegmentation: AiProviderRunFn<
 export const TFMP_ImageEmbedding: AiProviderRunFn<
   ImageEmbeddingTaskExecuteInput,
   ImageEmbeddingTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const imageEmbedder = await getModelTask(model!, {}, onProgress, signal, ImageEmbedder);
   const result = imageEmbedder.embed(input.image as any);
@@ -497,7 +497,7 @@ export const TFMP_ImageEmbedding: AiProviderRunFn<
 export const TFMP_ImageClassification: AiProviderRunFn<
   ImageClassificationTaskExecuteInput,
   ImageClassificationTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const imageClassifier = await getModelTask(
     model!,
@@ -530,7 +530,7 @@ export const TFMP_ImageClassification: AiProviderRunFn<
 export const TFMP_ObjectDetection: AiProviderRunFn<
   ObjectDetectionTaskExecuteInput,
   ObjectDetectionTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const objectDetector = await getModelTask(
     model!,
@@ -569,7 +569,7 @@ export const TFMP_ObjectDetection: AiProviderRunFn<
 export const TFMP_GestureRecognizer: AiProviderRunFn<
   GestureRecognizerTaskExecuteInput,
   GestureRecognizerTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const gestureRecognizer = await getModelTask(
     model!,
@@ -621,7 +621,7 @@ export const TFMP_GestureRecognizer: AiProviderRunFn<
 export const TFMP_HandLandmarker: AiProviderRunFn<
   HandLandmarkerTaskExecuteInput,
   HandLandmarkerTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const handLandmarker = await getModelTask(
     model!,
@@ -669,7 +669,7 @@ export const TFMP_HandLandmarker: AiProviderRunFn<
 export const TFMP_FaceDetector: AiProviderRunFn<
   FaceDetectorTaskExecuteInput,
   FaceDetectorTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const faceDetector = await getModelTask(
     model!,
@@ -714,7 +714,7 @@ export const TFMP_FaceDetector: AiProviderRunFn<
 export const TFMP_FaceLandmarker: AiProviderRunFn<
   FaceLandmarkerTaskExecuteInput,
   FaceLandmarkerTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const faceLandmarker = await getModelTask(
     model!,
@@ -770,7 +770,7 @@ export const TFMP_FaceLandmarker: AiProviderRunFn<
 export const TFMP_PoseLandmarker: AiProviderRunFn<
   PoseLandmarkerTaskExecuteInput,
   PoseLandmarkerTaskExecuteOutput,
-  TFMPModelRecord
+  TFMPModelConfig
 > = async (input, model, onProgress, signal) => {
   const poseLandmarker = await getModelTask(
     model!,
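With the TFMP signatures switched to TFMPModelConfig, a run function can be driven directly by an inline object. A usage sketch, with assumptions flagged: the pipeline string stands in for a TextPipelineTask value from TFMP_Constants, and the model URL is made up:

```ts
import { TENSORFLOW_MEDIAPIPE } from "../common/TFMP_Constants";
import { TFMP_TextEmbedding } from "./TFMP_JobRunFns";
import type { TFMPModelConfig } from "./TFMP_ModelSchema";

async function embedInline(text: string) {
  // The cast hedges on any additional fields TFMPModelSchema marks as required.
  const model = {
    provider: TENSORFLOW_MEDIAPIPE,
    providerConfig: {
      pipeline: "text-embedding", // stand-in for a TextPipelineTask value
      modelPath: "https://example.com/text_embedder.tflite", // made-up URL
    },
  } as TFMPModelConfig;

  return TFMP_TextEmbedding(
    { text },
    model,
    (progress, message) => console.log(progress, message),
    new AbortController().signal
  );
}
```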
diff --git a/packages/ai-provider/src/tf-mediapipe/common/TFMP_ModelSchema.ts b/packages/ai-provider/src/tf-mediapipe/common/TFMP_ModelSchema.ts
index f142ee0c..f90a229e 100644
--- a/packages/ai-provider/src/tf-mediapipe/common/TFMP_ModelSchema.ts
+++ b/packages/ai-provider/src/tf-mediapipe/common/TFMP_ModelSchema.ts
@@ -4,7 +4,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import { ModelSchema } from "@workglow/ai";
+import { ModelConfigSchema, ModelRecordSchema } from "@workglow/ai";
 import { DataPortSchemaObject, FromSchema } from "@workglow/util";
 
 import { TENSORFLOW_MEDIAPIPE, TextPipelineTask } from "../common/TFMP_Constants";
@@ -42,14 +42,26 @@ export const TFMPModelSchema = {
   additionalProperties: true,
 } as const satisfies DataPortSchemaObject;
 
-const ExtendedModelSchema = {
+const ExtendedModelRecordSchema = {
   type: "object",
   properties: {
-    ...ModelSchema.properties,
+    ...ModelRecordSchema.properties,
     ...TFMPModelSchema.properties,
   },
-  required: [...ModelSchema.required, ...TFMPModelSchema.required],
+  required: [...ModelRecordSchema.required, ...TFMPModelSchema.required],
   additionalProperties: false,
 } as const satisfies DataPortSchemaObject;
 
-export type TFMPModelRecord = FromSchema<typeof ExtendedModelSchema>;
+export type TFMPModelRecord = FromSchema<typeof ExtendedModelRecordSchema>;
+
+const ExtendedModelConfigSchema = {
+  type: "object",
+  properties: {
+    ...ModelConfigSchema.properties,
+    ...TFMPModelSchema.properties,
+  },
+  required: [...ModelConfigSchema.required, ...TFMPModelSchema.required],
+  additionalProperties: false,
+} as const satisfies DataPortSchemaObject;
+
+export type TFMPModelConfig = FromSchema<typeof ExtendedModelConfigSchema>;
diff --git a/packages/ai/src/job/AiJob.ts b/packages/ai/src/job/AiJob.ts
index e5a2cc34..3bbe8aae 100644
--- a/packages/ai/src/job/AiJob.ts
+++ b/packages/ai/src/job/AiJob.ts
@@ -12,7 +12,7 @@ import {
   PermanentJobError,
 } from "@workglow/job-queue";
 import { TaskInput, TaskOutput } from "@workglow/task-graph";
-import { getGlobalModelRepository } from "../model/ModelRegistry";
+import type { ModelConfig } from "../model/ModelSchema";
 import { getAiProviderRegistry } from "../provider/AiProviderRegistry";
 
 /**
@@ -21,7 +21,7 @@ import { getAiProviderRegistry } from "../provider/AiProviderRegistry";
 export interface AiJobInput<Input extends TaskInput = TaskInput> {
   taskType: string;
   aiProvider: string;
-  taskInput: Input & { model: string };
+  taskInput: Input & { model: ModelConfig };
 }
 
 /**
@@ -62,11 +62,7 @@ export class AiJob<
         `No run function found for task type ${input.taskType} and model provider ${input.aiProvider}`
       );
     }
-    const modelName = input.taskInput.model;
-    const model = await getGlobalModelRepository().findByName(modelName);
-    if (modelName && !model) {
-      throw new PermanentJobError(`Model ${modelName} not found`);
-    }
+    const model = input.taskInput.model;
     if (context.signal?.aborted) {
       throw new AbortSignalJobError("Job aborted");
     }
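The AiJob change is the heart of the patch: the queue payload now carries the resolved ModelConfig itself, so workers execute without repository access (and the payload stays closer to structured-clone friendly, per the TODO in AiTask below). A sketch of the new payload shape; the provider id and the prompt field are illustrative:

```ts
import type { ModelConfig } from "../model/ModelSchema";

const model: ModelConfig = {
  provider: "example-provider", // illustrative provider id
  providerConfig: { pipeline: "text-generation", modelPath: "example/model" },
};

// Matches the updated AiJobInput: taskInput.model is the config object itself,
// and aiProvider is read straight off of it for routing.
const jobInput = {
  taskType: "text-generation",
  aiProvider: model.provider,
  taskInput: { prompt: "Hello", model },
};
```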
diff --git a/packages/ai/src/model/InMemoryModelRepository.ts b/packages/ai/src/model/InMemoryModelRepository.ts
index fbbf5afb..72fc78a9 100644
--- a/packages/ai/src/model/InMemoryModelRepository.ts
+++ b/packages/ai/src/model/InMemoryModelRepository.ts
@@ -6,7 +6,7 @@
 
 import { InMemoryTabularRepository } from "@workglow/storage";
 import { ModelRepository } from "./ModelRepository";
-import { ModelPrimaryKeyNames, ModelSchema } from "./ModelSchema";
+import { ModelPrimaryKeyNames, ModelRecordSchema } from "./ModelSchema";
 
 /**
  * In-memory implementation of a model repository.
@@ -14,6 +14,6 @@ import { ModelPrimaryKeyNames, ModelSchema } from "./ModelSchema";
  */
 export class InMemoryModelRepository extends ModelRepository {
   constructor() {
-    super(new InMemoryTabularRepository(ModelSchema, ModelPrimaryKeyNames));
+    super(new InMemoryTabularRepository(ModelRecordSchema, ModelPrimaryKeyNames));
   }
 }
diff --git a/packages/ai/src/model/ModelRepository.ts b/packages/ai/src/model/ModelRepository.ts
index 8469a246..8818273b 100644
--- a/packages/ai/src/model/ModelRepository.ts
+++ b/packages/ai/src/model/ModelRepository.ts
@@ -7,7 +7,7 @@
 
 import { type TabularRepository } from "@workglow/storage";
 import { EventEmitter, EventParameters } from "@workglow/util";
-import { ModelPrimaryKeyNames, ModelRecord, ModelSchema } from "./ModelSchema";
+import { ModelPrimaryKeyNames, ModelRecord, ModelRecordSchema } from "./ModelSchema";
 
 /**
  * Events that can be emitted by the ModelRepository
@@ -38,11 +38,11 @@ export class ModelRepository {
   /**
    * Repository for storing and managing Model instances
    */
   protected readonly modelTabularRepository: TabularRepository<
-    typeof ModelSchema,
+    typeof ModelRecordSchema,
     typeof ModelPrimaryKeyNames
   >;
   constructor(
-    modelTabularRepository: TabularRepository<typeof ModelSchema, typeof ModelPrimaryKeyNames>
+    modelTabularRepository: TabularRepository<typeof ModelRecordSchema, typeof ModelPrimaryKeyNames>
   ) {
     this.modelTabularRepository = modelTabularRepository;
   }
diff --git a/packages/ai/src/model/ModelSchema.ts b/packages/ai/src/model/ModelSchema.ts
index a65adb36..7df3fd6b 100644
--- a/packages/ai/src/model/ModelSchema.ts
+++ b/packages/ai/src/model/ModelSchema.ts
@@ -6,7 +6,14 @@
 
 import { DataPortSchemaObject, FromSchema } from "@workglow/util";
 
-export const ModelSchema = {
+/**
+ * A model configuration suitable for task/job inputs.
+ *
+ * @remarks
+ * This is intentionally less strict than {@link ModelRecord} so jobs can carry only the
+ * provider configuration required to execute, without requiring access to a model repository.
+ */
+export const ModelConfigSchema = {
   type: "object",
   properties: {
     model_id: { type: "string" },
@@ -17,10 +24,24 @@ export const ModelSchema = {
     providerConfig: { type: "object", default: {} },
     metadata: { type: "object", default: {} },
   },
+  required: ["provider", "providerConfig"],
+  format: "model",
+  additionalProperties: false,
+} as const satisfies DataPortSchemaObject;
+
+/**
+ * A fully-specified model record suitable for persistence in a repository.
+ */
+export const ModelRecordSchema = {
+  type: "object",
+  properties: {
+    ...ModelConfigSchema.properties,
+  },
   required: ["model_id", "tasks", "provider", "title", "description", "providerConfig", "metadata"],
   format: "model",
   additionalProperties: false,
 } as const satisfies DataPortSchemaObject;
 
-export type ModelRecord = FromSchema<typeof ModelSchema>;
+export type ModelConfig = FromSchema<typeof ModelConfigSchema>;
+export type ModelRecord = FromSchema<typeof ModelRecordSchema>;
 
 export const ModelPrimaryKeyNames = ["model_id"] as const;
diff --git a/packages/ai/src/provider/AiProviderRegistry.ts b/packages/ai/src/provider/AiProviderRegistry.ts
index 633ca407..04f30487 100644
--- a/packages/ai/src/provider/AiProviderRegistry.ts
+++ b/packages/ai/src/provider/AiProviderRegistry.ts
@@ -6,7 +6,7 @@
 
 import { TaskInput, TaskOutput } from "@workglow/task-graph";
 import { globalServiceRegistry, WORKER_MANAGER } from "@workglow/util";
-import { ModelRecord } from "../model/ModelSchema";
+import type { ModelConfig } from "../model/ModelSchema";
 
 /**
  * Type for the run function for the AiJob
@@ -14,7 +14,7 @@ import { ModelRecord } from "../model/ModelSchema";
 export type AiProviderRunFn<
   Input extends TaskInput = TaskInput,
   Output extends TaskOutput = TaskOutput,
-  Model extends ModelRecord = ModelRecord,
+  Model extends ModelConfig = ModelConfig,
 > = (
   input: Input,
   model: Model | undefined,
@@ -53,7 +53,7 @@ export class AiProviderRegistry {
   >(modelProvider: string, taskType: string) {
     const workerFn: AiProviderRunFn = async (
       input: Input,
-      model: ModelRecord | undefined,
+      model: ModelConfig | undefined,
       update_progress: (progress: number, message?: string, details?: any) => void,
       signal?: AbortSignal
     ) => {
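Run functions now receive that config straight from the payload. A minimal registration sketch, assuming AiProviderRunFn and getAiProviderRegistry are exported from @workglow/ai the way the provider packages and the test below use them; the provider and task names here are made up:

```ts
import { getAiProviderRegistry, type AiProviderRunFn } from "@workglow/ai";

const echoRunFn: AiProviderRunFn<{ text: string }, { text: string }> = async (
  input,
  model,
  onProgress
) => {
  // `model` is the ModelConfig carried in the job payload; it may be
  // undefined for tasks that do not need a model.
  onProgress(50, model ? `using provider ${model.provider}` : "no model");
  return { text: input.text };
};

getAiProviderRegistry().registerRunFn("echo-provider", "echo", echoRunFn);
```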
" with model " + modelLabel : "" }`; super(input, config); } @@ -70,14 +78,14 @@ export class AiTask< * @returns The AiJobInput to submit to the queue */ protected override async getJobInput(input: Input): Promise> { - if (typeof input.model !== "string") { - console.error("AiTask: Model is not a string", input); + if (Array.isArray(input.model)) { + console.error("AiTask: Model is an array", input); throw new TaskConfigurationError( - "AiTask: Model is not a string, only create job for single model tasks" + "AiTask: Model is an array, only create job for single model tasks" ); } const runtype = (this.constructor as any).runtype ?? (this.constructor as any).type; - const model = await this.getModelForInput(input as AiSingleTaskInput); + const model = await this.getModelConfigForInput(input as AiSingleTaskInput); // TODO: if the queue is not memory based, we need to convert to something that can structure clone to the queue // const registeredQueue = await this.resolveQueue(input); @@ -86,10 +94,38 @@ export class AiTask< return { taskType: runtype, aiProvider: model.provider, - taskInput: input as Input & { model: string }, + taskInput: { ...(input as any), model } as Input & { model: ModelConfig }, }; } + /** + * Resolves a model configuration for the given input. + * + * @remarks + * - If `input.model` is a string, it is resolved via the global model repository. + * - If `input.model` is already a config object, it is used directly. + */ + protected async getModelConfigForInput(input: AiSingleTaskInput): Promise { + const modelValue = input.model; + if (!modelValue) throw new TaskConfigurationError("AiTask: No model found"); + if (typeof modelValue === "string") { + const modelname = modelValue; + if (this.modelCache && this.modelCache.name === modelname) { + return this.modelCache.model; + } + const model = await getGlobalModelRepository().findByName(modelname); + if (!model) { + throw new TaskConfigurationError(`AiTask: No model ${modelname} found`); + } + this.modelCache = { name: modelname, model }; + return model; + } + if (typeof modelValue === "object") { + return modelValue; + } + throw new TaskConfigurationError("AiTask: Invalid model value"); + } + /** * Creates a new Job instance for direct execution (without a queue). * @param input - The task input @@ -116,6 +152,9 @@ export class AiTask< protected async getModelForInput(input: AiSingleTaskInput): Promise { const modelname = input.model; if (!modelname) throw new TaskConfigurationError("AiTask: No model name found"); + if (typeof modelname !== "string") { + throw new TaskConfigurationError("AiTask: Model name is not a string"); + } if (this.modelCache && this.modelCache.name === modelname) { return this.modelCache.model; } @@ -132,6 +171,9 @@ export class AiTask< const model = await this.getModelForInput(input as AiSingleTaskInput); return model.provider; } + if (typeof input.model === "object" && input.model !== null && !Array.isArray(input.model)) { + return (input.model as ModelConfig).provider; + } return undefined; } @@ -159,11 +201,24 @@ export class AiTask< for (const [key, propSchema] of modelTaskProperties) { let requestedModels = Array.isArray(input[key]) ? 
input[key] : [input[key]]; for (const model of requestedModels) { - const foundModel = taskModels?.find((m) => m.model_id === model); - if (!foundModel) { - throw new TaskConfigurationError( - `AiTask: Missing model for '${key}' named '${model}' for task '${this.type}'` - ); + if (typeof model === "string") { + const foundModel = taskModels?.find((m) => m.model_id === model); + if (!foundModel) { + throw new TaskConfigurationError( + `AiTask: Missing model for '${key}' named '${model}' for task '${this.type}'` + ); + } + } else if (typeof model === "object" && model !== null) { + // Inline configs are accepted without requiring repository access. + // If 'tasks' is provided, do a best-effort compatibility check. + const tasks = (model as ModelConfig).tasks; + if (Array.isArray(tasks) && tasks.length > 0 && !tasks.includes(this.type)) { + throw new TaskConfigurationError( + `AiTask: Inline model for '${key}' is not compatible with task '${this.type}'` + ); + } + } else { + throw new TaskConfigurationError(`AiTask: Invalid model for '${key}'`); } } } @@ -177,9 +232,17 @@ export class AiTask< for (const [key, propSchema] of modelPlainProperties) { let requestedModels = Array.isArray(input[key]) ? input[key] : [input[key]]; for (const model of requestedModels) { - const foundModel = await getGlobalModelRepository().findByName(model); - if (!foundModel) { - throw new TaskConfigurationError(`AiTask: Missing model for "${key}" named "${model}"`); + if (typeof model === "string") { + const foundModel = await getGlobalModelRepository().findByName(model); + if (!foundModel) { + throw new TaskConfigurationError( + `AiTask: Missing model for "${key}" named "${model}"` + ); + } + } else if (typeof model === "object" && model !== null) { + // Inline configs are accepted without requiring repository access. + } else { + throw new TaskConfigurationError(`AiTask: Invalid model for "${key}"`); } } } @@ -206,13 +269,21 @@ export class AiTask< const taskModels = await getGlobalModelRepository().findModelsByTask(this.type); for (const [key, propSchema] of modelTaskProperties) { let requestedModels = Array.isArray(input[key]) ? input[key] : [input[key]]; - let usingModels = requestedModels.filter((model: string) => + const requestedStrings = requestedModels.filter( + (m: unknown): m is string => typeof m === "string" + ); + const requestedInline = requestedModels.filter( + (m: unknown): m is ModelConfig => typeof m === "object" && m !== null + ); + + const usingStrings = requestedStrings.filter((model: string) => taskModels?.find((m) => m.model_id === model) ); + const combined: (string | ModelConfig)[] = [...requestedInline, ...usingStrings]; + // we alter input to be the models that were found for this kind of input - usingModels = usingModels.length > 1 ? usingModels : usingModels[0]; - (input as any)[key] = usingModels; + (input as any)[key] = combined.length > 1 ? 
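On the task side, both addressing modes funnel through getModelConfigForInput: a string still resolves through the global repository (with the existing cache), while an object passes straight through. An illustrative call site; TextGenerationTask, its prompt field, and the model name are stand-ins for a concrete AiTask subclass and real ids:

```ts
// By name: resolved via getGlobalModelRepository(), exactly as before.
const byName = new TextGenerationTask({ prompt: "Hi", model: "onnx:gpt2:q8" });

// Inline: no repository lookup; the config rides along into the job payload.
const byConfig = new TextGenerationTask({
  prompt: "Hi",
  model: {
    provider: "example-provider",
    providerConfig: { pipeline: "text-generation", modelPath: "gpt2" },
  },
});
```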
diff --git a/packages/ai/src/task/base/AiTaskSchemas.ts b/packages/ai/src/task/base/AiTaskSchemas.ts
index 4229bbd5..67d2fcaf 100644
--- a/packages/ai/src/task/base/AiTaskSchemas.ts
+++ b/packages/ai/src/task/base/AiTaskSchemas.ts
@@ -11,7 +11,7 @@ import {
   FromSchemaOptions,
   JsonSchema,
 } from "@workglow/util";
-import { ModelSchema } from "../../model/ModelSchema";
+import { ModelConfigSchema } from "../../model/ModelSchema";
 
 export type TypedArray =
   | Float64Array
@@ -223,7 +223,7 @@ export function TypeModelByDetail<
     throw new Error("Invalid semantic value");
   }
   return {
-    ...ModelSchema,
+    ...ModelConfigSchema,
     ...options,
     format: semantic,
   } as const satisfies JsonSchema;
diff --git a/packages/ai/src/task/base/AiVisionTask.ts b/packages/ai/src/task/base/AiVisionTask.ts
index 59af1839..16fbb7b4 100644
--- a/packages/ai/src/task/base/AiVisionTask.ts
+++ b/packages/ai/src/task/base/AiVisionTask.ts
@@ -12,15 +12,15 @@ import { JobQueueTaskConfig, TaskInput, type TaskOutput } from "@workglow/task-g
 import { convertImageDataToUseableForm, ImageDataSupport } from "@workglow/util";
 
 import { AiJobInput } from "../../job/AiJob";
-import type { ModelRecord } from "../../model/ModelSchema";
+import type { ModelConfig } from "../../model/ModelSchema";
 import { AiTask } from "./AiTask";
 
 export interface AiVisionTaskSingleInput extends TaskInput {
-  model: string;
+  model: string | ModelConfig;
 }
 
 export interface AiVisionArrayTaskInput extends TaskInput {
-  model: string | ModelRecord | (string | ModelRecord)[];
+  model: string | ModelConfig | (string | ModelConfig)[];
 }
 
 /**
diff --git a/packages/test/src/binding/IndexedDbModelRepository.ts b/packages/test/src/binding/IndexedDbModelRepository.ts
index c411bc19..3c9acede 100644
--- a/packages/test/src/binding/IndexedDbModelRepository.ts
+++ b/packages/test/src/binding/IndexedDbModelRepository.ts
@@ -4,7 +4,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import { ModelPrimaryKeyNames, ModelRepository, ModelSchema } from "@workglow/ai";
+import { ModelPrimaryKeyNames, ModelRecordSchema, ModelRepository } from "@workglow/ai";
 import { IndexedDbTabularRepository } from "@workglow/storage";
 
 /**
@@ -13,6 +13,6 @@ import { IndexedDbTabularRepository } from "@workglow/storage";
  */
 export class IndexedDbModelRepository extends ModelRepository {
   constructor(tableModels: string = "models") {
-    super(new IndexedDbTabularRepository(tableModels, ModelSchema, ModelPrimaryKeyNames));
+    super(new IndexedDbTabularRepository(tableModels, ModelRecordSchema, ModelPrimaryKeyNames));
   }
 }
diff --git a/packages/test/src/binding/PostgresModelRepository.ts b/packages/test/src/binding/PostgresModelRepository.ts
index 756bffcd..667439ab 100644
--- a/packages/test/src/binding/PostgresModelRepository.ts
+++ b/packages/test/src/binding/PostgresModelRepository.ts
@@ -4,7 +4,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import { ModelPrimaryKeyNames, ModelRepository, ModelSchema } from "@workglow/ai";
+import { ModelPrimaryKeyNames, ModelRecordSchema, ModelRepository } from "@workglow/ai";
 import { PostgresTabularRepository } from "@workglow/storage";
 import { Pool } from "pg";
 
@@ -14,6 +14,6 @@ import { Pool } from "pg";
  */
 export class PostgresModelRepository extends ModelRepository {
   constructor(db: Pool, tableModels: string = "aimodel") {
-    super(new PostgresTabularRepository(db, tableModels, ModelSchema, ModelPrimaryKeyNames));
+    super(new PostgresTabularRepository(db, tableModels, ModelRecordSchema, ModelPrimaryKeyNames));
   }
 }
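The storage bindings all stay on the strict schema: repositories persist full ModelRecord rows even though tasks now accept bare configs. A seeding sketch against the in-memory binding, assuming InMemoryModelRepository is exported from @workglow/ai; field values are illustrative, and addModel is the same API the pre-patch test used:

```ts
import { InMemoryModelRepository } from "@workglow/ai";

async function seed() {
  const repo = new InMemoryModelRepository();
  // ModelRecordSchema requires the full field set, unlike ModelConfigSchema.
  await repo.addModel({
    model_id: "example:model:v1",
    title: "example",
    description: "an example record",
    tasks: ["text-generation"],
    provider: "example-provider",
    providerConfig: { pipeline: "text-generation", modelPath: "example/model" },
    metadata: {},
  });
}
```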
diff --git a/packages/test/src/binding/SqliteModelRepository.ts b/packages/test/src/binding/SqliteModelRepository.ts
index d9398719..ebf6351f 100644
--- a/packages/test/src/binding/SqliteModelRepository.ts
+++ b/packages/test/src/binding/SqliteModelRepository.ts
@@ -4,7 +4,7 @@
 * SPDX-License-Identifier: Apache-2.0
 */
 
-import { ModelPrimaryKeyNames, ModelRepository, ModelSchema } from "@workglow/ai";
+import { ModelPrimaryKeyNames, ModelRecordSchema, ModelRepository } from "@workglow/ai";
 import { SqliteTabularRepository } from "@workglow/storage";
 
 /**
@@ -13,6 +13,8 @@ import { SqliteTabularRepository } from "@workglow/storage";
  */
 export class SqliteModelRepository extends ModelRepository {
   constructor(dbOrPath: string, tableModels: string = "aimodel") {
-    super(new SqliteTabularRepository(dbOrPath, tableModels, ModelSchema, ModelPrimaryKeyNames));
+    super(
+      new SqliteTabularRepository(dbOrPath, tableModels, ModelRecordSchema, ModelPrimaryKeyNames)
+    );
   }
 }
diff --git a/packages/test/src/test/ai-provider/AiProviderRegistry.test.ts b/packages/test/src/test/ai-provider/AiProviderRegistry.test.ts
index 9ff9934c..61e140c7 100644
--- a/packages/test/src/test/ai-provider/AiProviderRegistry.test.ts
+++ b/packages/test/src/test/ai-provider/AiProviderRegistry.test.ts
@@ -9,7 +9,6 @@ import {
   AiJobInput,
   AiProviderRegistry,
   getAiProviderRegistry,
-  getGlobalModelRepository,
   setAiProviderRegistry,
 } from "@workglow/ai";
 import { JobQueueClient, JobQueueServer, RateLimiter } from "@workglow/job-queue";
@@ -157,7 +156,7 @@ describe("AiProviderRegistry", () => {
     });
     aiProviderRegistry.registerRunFn(TEST_PROVIDER, "text-generation", mockRunFn);
 
-    const model = await getGlobalModelRepository().addModel({
+    const model = {
       model_id: "test:test-model:v1",
       title: "test-model",
       description: "test-model",
       tasks: ["text-generation"],
       provider: TEST_PROVIDER,
       providerConfig: {
         pipeline: "text-generation",
         modelPath: "test-model",
       },
       metadata: {},
-    });
+    };
 
     const controller = new AbortController();
     const job = new AiJob({
@@ -176,7 +175,7 @@ describe("AiProviderRegistry", () => {
       input: {
         aiProvider: TEST_PROVIDER,
         taskType: "text-generation",
-        taskInput: { text: "test", model: "test:test-model:v1" },
+        taskInput: { text: "test", model },
      },
    });
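The updated test exercises the full path: an object literal replaces the repository round trip and is handed to AiJob unchanged. The same literal can be checked against the new type at compile time; this sketch assumes ModelConfig is re-exported from @workglow/ai, and uses a plain string where the test uses its TEST_PROVIDER constant:

```ts
import type { ModelConfig } from "@workglow/ai";

// The record-level fields (model_id, title, tasks, ...) are optional on
// ModelConfig, so the full test literal still satisfies the looser type.
const model = {
  model_id: "test:test-model:v1",
  title: "test-model",
  description: "test-model",
  tasks: ["text-generation"],
  provider: "test-provider", // stand-in for TEST_PROVIDER
  providerConfig: { pipeline: "text-generation", modelPath: "test-model" },
  metadata: {},
} satisfies ModelConfig;
```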