48 changes: 24 additions & 24 deletions packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts
@@ -77,7 +77,7 @@ import type {
} from "@workglow/ai";
import { CallbackStatus } from "./HFT_CallbackStatus";
import { HTF_CACHE_NAME } from "./HFT_Constants";
-import { HfTransformersOnnxModelRecord } from "./HFT_ModelSchema";
+import { HfTransformersOnnxModelConfig } from "./HFT_ModelSchema";

const pipelines = new Map<string, any>();

@@ -93,12 +93,12 @@ export function clearPipelineCache(): void {
* @param progressScaleMax - Maximum progress value for download phase (100 for download-only, 10 for download+run)
*/
const getPipeline = async (
-model: HfTransformersOnnxModelRecord,
+model: HfTransformersOnnxModelConfig,
onProgress: (progress: number, message?: string, details?: any) => void,
options: PretrainedModelOptions = {},
progressScaleMax: number = 10
) => {
-const cacheKey = `${model.model_id}:${model.providerConfig.pipeline}`;
+const cacheKey = `${model.providerConfig.modelPath}:${model.providerConfig.pipeline}`;
Review comment: Pipeline cache key ignores device and dtype configuration

The pipeline cache key changed from model.model_id:pipeline to model.providerConfig.modelPath:pipeline. However, the pipeline is created with additional configuration options including dType (dtype) and device (lines 372, 376). Since these options aren't included in the cache key, two model configs with the same modelPath and pipeline but different dType or device settings will incorrectly share the same cached pipeline. This could cause models to run on the wrong device (e.g., CPU instead of WebGPU) or with the wrong precision, leading to incorrect behavior or errors.

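A possible follow-up, sketched below: fold the device and dtype options into the cache key so that two configs sharing a modelPath and pipeline but differing in device or precision get distinct pipelines. The `device` and `dtype` field names on providerConfig are assumptions based on the description above, not confirmed against the schema.

```ts
// Sketch only: build the cache key from every option that affects pipeline
// construction. `device` and `dtype` are assumed to live on providerConfig;
// adjust to wherever the real options are read from.
const cacheKey = [
  model.providerConfig.modelPath,
  model.providerConfig.pipeline,
  model.providerConfig.device ?? "default",
  model.providerConfig.dtype ?? "default",
].join(":");
```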

if (pipelines.has(cacheKey)) {
return pipelines.get(cacheKey);
}
@@ -433,7 +433,7 @@ const getPipeline = async (
export const HFT_Download: AiProviderRunFn<
DownloadModelTaskExecuteInput,
DownloadModelTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
// Download the model by creating a pipeline
// Use 100 as progressScaleMax since this is download-only (0-100%)
@@ -451,11 +451,12 @@ export const HFT_Download: AiProviderRunFn<
export const HFT_Unload: AiProviderRunFn<
UnloadModelTaskExecuteInput,
UnloadModelTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
// Delete the pipeline from the in-memory map
-if (pipelines.has(model!.model_id)) {
-pipelines.delete(model!.model_id);
+const cacheKey = `${model!.providerConfig.modelPath}:${model!.providerConfig.pipeline}`;
+if (pipelines.has(cacheKey)) {
+pipelines.delete(cacheKey);
onProgress(50, "Pipeline removed from memory");
}

@@ -515,7 +516,7 @@ const deleteModelCache = async (modelPath: string): Promise<void> => {
export const HFT_TextEmbedding: AiProviderRunFn<
TextEmbeddingTaskExecuteInput,
TextEmbeddingTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const generateEmbedding: FeatureExtractionPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -546,7 +547,7 @@ export const HFT_TextEmbedding: AiProviderRunFn<
export const HFT_TextClassification: AiProviderRunFn<
TextClassificationTaskExecuteInput,
TextClassificationTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
if (model?.providerConfig?.pipeline === "zero-shot-classification") {
if (
@@ -602,7 +603,7 @@ export const HFT_TextClassification: AiProviderRunFn<
export const HFT_TextLanguageDetection: AiProviderRunFn<
TextLanguageDetectionTaskExecuteInput,
TextLanguageDetectionTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const TextClassification: TextClassificationPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -632,7 +633,7 @@ export const HFT_TextLanguageDetection: AiProviderRunFn<
export const HFT_TextNamedEntityRecognition: AiProviderRunFn<
TextNamedEntityRecognitionTaskExecuteInput,
TextNamedEntityRecognitionTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const textNamedEntityRecognition: TokenClassificationPipeline = await getPipeline(
model!,
@@ -663,7 +664,7 @@ export const HFT_TextNamedEntityRecognition: AiProviderRunFn<
export const HFT_TextFillMask: AiProviderRunFn<
TextFillMaskTaskExecuteInput,
TextFillMaskTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const unmasker: FillMaskPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -691,7 +692,7 @@ export const HFT_TextFillMask: AiProviderRunFn<
export const HFT_TextGeneration: AiProviderRunFn<
TextGenerationTaskExecuteInput,
TextGenerationTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -724,7 +725,7 @@ export const HFT_TextGeneration: AiProviderRunFn<
export const HFT_TextTranslation: AiProviderRunFn<
TextTranslationTaskExecuteInput,
TextTranslationTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const translate: TranslationPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -758,7 +759,7 @@ export const HFT_TextTranslation: AiProviderRunFn<
export const HFT_TextRewriter: AiProviderRunFn<
TextRewriterTaskExecuteInput,
TextRewriterTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -798,7 +799,7 @@ export const HFT_TextRewriter: AiProviderRunFn<
export const HFT_TextSummary: AiProviderRunFn<
TextSummaryTaskExecuteInput,
TextSummaryTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const generateSummary: SummarizationPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -829,7 +830,7 @@ export const HFT_TextSummary: AiProviderRunFn<
export const HFT_TextQuestionAnswer: AiProviderRunFn<
TextQuestionAnswerTaskExecuteInput,
TextQuestionAnswerTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
// Get the question answering pipeline
const generateAnswer: QuestionAnsweringPipeline = await getPipeline(model!, onProgress, {
@@ -860,7 +861,7 @@ export const HFT_TextQuestionAnswer: AiProviderRunFn<
export const HFT_ImageSegmentation: AiProviderRunFn<
ImageSegmentationTaskExecuteInput,
ImageSegmentationTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const segmenter: ImageSegmentationPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -893,7 +894,7 @@ export const HFT_ImageSegmentation: AiProviderRunFn<
export const HFT_ImageToText: AiProviderRunFn<
ImageToTextTaskExecuteInput,
ImageToTextTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const captioner: ImageToTextPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -917,7 +918,7 @@ export const HFT_ImageToText: AiProviderRunFn<
export const HFT_BackgroundRemoval: AiProviderRunFn<
BackgroundRemovalTaskExecuteInput,
BackgroundRemovalTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const remover: BackgroundRemovalPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -940,7 +941,7 @@ export const HFT_BackgroundRemoval: AiProviderRunFn<
export const HFT_ImageEmbedding: AiProviderRunFn<
ImageEmbeddingTaskExecuteInput,
ImageEmbeddingTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
const embedder: ImageFeatureExtractionPipeline = await getPipeline(model!, onProgress, {
abort_signal: signal,
@@ -960,7 +961,7 @@ export const HFT_ImageEmbedding: AiProviderRunFn<
export const HFT_ImageClassification: AiProviderRunFn<
ImageClassificationTaskExecuteInput,
ImageClassificationTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
if (model?.providerConfig?.pipeline === "zero-shot-image-classification") {
if (!input.categories || !Array.isArray(input.categories) || input.categories.length === 0) {
@@ -1015,7 +1016,7 @@ export const HFT_ImageClassification: AiProviderRunFn<
export const HFT_ObjectDetection: AiProviderRunFn<
ObjectDetectionTaskExecuteInput,
ObjectDetectionTaskExecuteOutput,
-HfTransformersOnnxModelRecord
+HfTransformersOnnxModelConfig
> = async (input, model, onProgress, signal) => {
if (model?.providerConfig?.pipeline === "zero-shot-object-detection") {
if (!input.labels || !Array.isArray(input.labels) || input.labels.length === 0) {
@@ -1070,7 +1071,6 @@ function imageToBase64(image: RawImage): string {
return (image as any).toBase64?.() || "";
}


/**
* Create a text streamer for a given tokenizer and update progress function
* @param tokenizer - The tokenizer to use for the streamer
22 changes: 17 additions & 5 deletions packages/ai-provider/src/hf-transformers/common/HFT_ModelSchema.ts
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

-import { ModelSchema } from "@workglow/ai";
+import { ModelConfigSchema, ModelRecordSchema } from "@workglow/ai";
import { DataPortSchemaObject, FromSchema } from "@workglow/util";
import { HF_TRANSFORMERS_ONNX, PipelineUseCase, QuantizationDataType } from "./HFT_Constants";

@@ -89,14 +89,26 @@ export const HfTransformersOnnxModelSchema = {
additionalProperties: true,
} as const satisfies DataPortSchemaObject;

-const ExtendedModelSchema = {
+const ExtendedModelRecordSchema = {
type: "object",
properties: {
-...ModelSchema.properties,
+...ModelRecordSchema.properties,
...HfTransformersOnnxModelSchema.properties,
},
-required: [...ModelSchema.required, ...HfTransformersOnnxModelSchema.required],
+required: [...ModelRecordSchema.required, ...HfTransformersOnnxModelSchema.required],
additionalProperties: false,
} as const satisfies DataPortSchemaObject;

-export type HfTransformersOnnxModelRecord = FromSchema<typeof ExtendedModelSchema>;
+export type HfTransformersOnnxModelRecord = FromSchema<typeof ExtendedModelRecordSchema>;

+const ExtendedModelConfigSchema = {
+type: "object",
+properties: {
+...ModelConfigSchema.properties,
+...HfTransformersOnnxModelSchema.properties,
+},
+required: [...ModelConfigSchema.required, ...HfTransformersOnnxModelSchema.required],
+additionalProperties: false,
+} as const satisfies DataPortSchemaObject;
+
+export type HfTransformersOnnxModelConfig = FromSchema<typeof ExtendedModelConfigSchema>;
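For orientation, a minimal sketch of how the two derived types now differ in practice, assuming ModelRecordSchema contributes persistence fields such as model_id (the field the old HFT_Unload code read) while ModelConfigSchema carries only runtime configuration:

```ts
import {
  HfTransformersOnnxModelConfig,
  HfTransformersOnnxModelRecord,
} from "./HFT_ModelSchema";

// A stored record still exposes model_id; a runtime config is keyed purely
// off providerConfig, which is why the run functions above switched types.
declare const record: HfTransformersOnnxModelRecord;
declare const config: HfTransformersOnnxModelConfig;

const recordKey = record.model_id; // persisted identity (assumed field)
const configKey = `${config.providerConfig.modelPath}:${config.providerConfig.pipeline}`; // runtime cache key
```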
34 changes: 17 additions & 17 deletions packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts
@@ -53,7 +53,7 @@ import type {
UnloadModelTaskExecuteOutput,
} from "@workglow/ai";
import { PermanentJobError } from "@workglow/job-queue";
-import { TFMPModelRecord } from "./TFMP_ModelSchema";
+import { TFMPModelConfig } from "./TFMP_ModelSchema";

interface TFMPWasmFileset {
/** The path to the Wasm loader script. */
@@ -82,7 +82,7 @@ const wasm_reference_counts = new Map<string, number>();
* Helper function to get a WASM task for a model
*/
const getWasmTask = async (
-model: TFMPModelRecord,
+model: TFMPModelConfig,
onProgress: (progress: number, message?: string, details?: any) => void,
signal: AbortSignal
): Promise<TFMPWasmFileset> => {
@@ -213,7 +213,7 @@ const optionsMatch = (opts1: Record<string, unknown>, opts2: Record<string, unkn
};

const getModelTask = async <T extends TaskType>(
-model: TFMPModelRecord,
+model: TFMPModelConfig,
options: Record<string, unknown>,
onProgress: (progress: number, message?: string, details?: any) => void,
signal: AbortSignal,
@@ -264,7 +264,7 @@ export const TFMP_Download: AiProviderRunFn<
export const TFMP_Download: AiProviderRunFn<
DownloadModelTaskExecuteInput,
DownloadModelTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
let task: TextEmbedder | TextClassifier | LanguageDetector;
switch (model?.providerConfig.pipeline) {
@@ -298,7 +298,7 @@ export const TFMP_TextEmbedding: AiProviderRunFn<
export const TFMP_TextEmbedding: AiProviderRunFn<
TextEmbeddingTaskExecuteInput,
TextEmbeddingTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const textEmbedder = await getModelTask(model!, {}, onProgress, signal, TextEmbedder);
const result = textEmbedder.embed(input.text);
@@ -321,7 +321,7 @@ export const TFMP_TextClassification: AiProviderRunFn<
export const TFMP_TextClassification: AiProviderRunFn<
TextClassificationTaskExecuteInput,
TextClassificationTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const TextClassification = await getModelTask(
model!,
@@ -358,7 +358,7 @@ export const TFMP_TextLanguageDetection: AiProviderRunFn<
export const TFMP_TextLanguageDetection: AiProviderRunFn<
TextLanguageDetectionTaskExecuteInput,
TextLanguageDetectionTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const maxLanguages = input.maxLanguages === 0 ? -1 : input.maxLanguages;

@@ -402,7 +402,7 @@ export const TFMP_Unload: AiProviderRunFn<
export const TFMP_Unload: AiProviderRunFn<
UnloadModelTaskExecuteInput,
UnloadModelTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const modelPath = model!.providerConfig.modelPath;
onProgress(10, "Unloading model");
@@ -442,7 +442,7 @@ export const TFMP_ImageSegmentation: AiProviderRunFn<
export const TFMP_ImageSegmentation: AiProviderRunFn<
ImageSegmentationTaskExecuteInput,
ImageSegmentationTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const imageSegmenter = await getModelTask(model!, {}, onProgress, signal, ImageSegmenter);
const result = imageSegmenter.segment(input.image as any);
@@ -475,7 +475,7 @@ export const TFMP_ImageEmbedding: AiProviderRunFn<
export const TFMP_ImageEmbedding: AiProviderRunFn<
ImageEmbeddingTaskExecuteInput,
ImageEmbeddingTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const imageEmbedder = await getModelTask(model!, {}, onProgress, signal, ImageEmbedder);
const result = imageEmbedder.embed(input.image as any);
@@ -497,7 +497,7 @@ export const TFMP_ImageClassification: AiProviderRunFn<
export const TFMP_ImageClassification: AiProviderRunFn<
ImageClassificationTaskExecuteInput,
ImageClassificationTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const imageClassifier = await getModelTask(
model!,
@@ -530,7 +530,7 @@ export const TFMP_ObjectDetection: AiProviderRunFn<
export const TFMP_ObjectDetection: AiProviderRunFn<
ObjectDetectionTaskExecuteInput,
ObjectDetectionTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const objectDetector = await getModelTask(
model!,
@@ -569,7 +569,7 @@ export const TFMP_GestureRecognizer: AiProviderRunFn<
export const TFMP_GestureRecognizer: AiProviderRunFn<
GestureRecognizerTaskExecuteInput,
GestureRecognizerTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const gestureRecognizer = await getModelTask(
model!,
@@ -621,7 +621,7 @@ export const TFMP_HandLandmarker: AiProviderRunFn<
export const TFMP_HandLandmarker: AiProviderRunFn<
HandLandmarkerTaskExecuteInput,
HandLandmarkerTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const handLandmarker = await getModelTask(
model!,
@@ -669,7 +669,7 @@ export const TFMP_FaceDetector: AiProviderRunFn<
export const TFMP_FaceDetector: AiProviderRunFn<
FaceDetectorTaskExecuteInput,
FaceDetectorTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const faceDetector = await getModelTask(
model!,
@@ -714,7 +714,7 @@ export const TFMP_FaceLandmarker: AiProviderRunFn<
export const TFMP_FaceLandmarker: AiProviderRunFn<
FaceLandmarkerTaskExecuteInput,
FaceLandmarkerTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const faceLandmarker = await getModelTask(
model!,
@@ -770,7 +770,7 @@ export const TFMP_PoseLandmarker: AiProviderRunFn<
export const TFMP_PoseLandmarker: AiProviderRunFn<
PoseLandmarkerTaskExecuteInput,
PoseLandmarkerTaskExecuteOutput,
-TFMPModelRecord
+TFMPModelConfig
> = async (input, model, onProgress, signal) => {
const poseLandmarker = await getModelTask(
model!,