diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0a968e084..ca80cb450 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -937,21 +937,38 @@ def _export_diffusers_checkpoint( print(f" Saved to: {component_export_dir}") - # Step 5: For pipelines, also save the model_index.json + # Step 5: For pipelines, also save model_index.json if is_diffusers_pipe: model_index_path = export_dir / "model_index.json" - if hasattr(pipe, "config") and pipe.config is not None: - # Save a simplified model_index.json that points to the exported components + is_partial_export = components is not None + + # For full export, preserve original model_index.json when possible. + # For partial export, skip this to avoid listing non-exported components. + if not is_partial_export: + source_path = getattr(pipe, "name_or_path", None) or getattr( + getattr(pipe, "config", None), "_name_or_path", None + ) + if source_path: + candidate_model_index = Path(source_path) / "model_index.json" + if candidate_model_index.exists(): + with open(candidate_model_index) as file: + model_index = json.load(file) + with open(model_index_path, "w") as file: + json.dump(model_index, file, indent=4) + + # Full-export fallback to Diffusers-native config serialization. + # Partial export skips this for the same reason as above. + if not is_partial_export and not model_index_path.exists() and hasattr(pipe, "save_config"): + pipe.save_config(export_dir) + + # Last resort: synthesize a minimal model_index.json from exported components. + if not model_index_path.exists() and hasattr(pipe, "config") and pipe.config is not None: model_index = { "_class_name": type(pipe).__name__, "_diffusers_version": diffusers.__version__, } - # Add component class names for all components - # Use the base library name (e.g., "diffusers", "transformers") instead of - # the full module path, as expected by diffusers pipeline loading for name, comp in all_components.items(): module = type(comp).__module__ - # Extract base library name (first part of module path) library = module.split(".")[0] model_index[name] = [library, type(comp).__name__]