Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,21 +937,38 @@ def _export_diffusers_checkpoint(

print(f" Saved to: {component_export_dir}")

# Step 5: For pipelines, also save the model_index.json
# Step 5: For pipelines, also save model_index.json
if is_diffusers_pipe:
model_index_path = export_dir / "model_index.json"
if hasattr(pipe, "config") and pipe.config is not None:
# Save a simplified model_index.json that points to the exported components
is_partial_export = components is not None

# For full export, preserve original model_index.json when possible.
# For partial export, skip this to avoid listing non-exported components.
if not is_partial_export:
source_path = getattr(pipe, "name_or_path", None) or getattr(
getattr(pipe, "config", None), "_name_or_path", None
)
if source_path:
candidate_model_index = Path(source_path) / "model_index.json"
if candidate_model_index.exists():
with open(candidate_model_index) as file:
model_index = json.load(file)
with open(model_index_path, "w") as file:
json.dump(model_index, file, indent=4)

# Full-export fallback to Diffusers-native config serialization.
# Partial export skips this for the same reason as above.
if not is_partial_export and not model_index_path.exists() and hasattr(pipe, "save_config"):
pipe.save_config(export_dir)

# Last resort: synthesize a minimal model_index.json from exported components.
if not model_index_path.exists() and hasattr(pipe, "config") and pipe.config is not None:
model_index = {
"_class_name": type(pipe).__name__,
"_diffusers_version": diffusers.__version__,
}
# Add component class names for all components
# Use the base library name (e.g., "diffusers", "transformers") instead of
# the full module path, as expected by diffusers pipeline loading
for name, comp in all_components.items():
module = type(comp).__module__
# Extract base library name (first part of module path)
library = module.split(".")[0]
model_index[name] = [library, type(comp).__name__]

Expand Down