Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ venv/
**.pickle
**.tar.gz
**.nemo

# Ignore experiment run
examples/diffusers/quantization/experiment_run
22 changes: 22 additions & 0 deletions examples/diffusers/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@
"algorithm": "max",
}

NVFP4_ASYMMETRIC_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
"bias": {-1: None, "type": "static", "method": "mean"},
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
"bias": {-1: None, "type": "dynamic", "method": "mean"}, #bias must be dynamic
},
"*output_quantizer": {"enable": False},
"default": {"enable": False},
},
"algorithm": "max",
}

NVFP4_FP8_MHA_CONFIG = {
"quant_cfg": {
"**weight_quantizer": {
Expand Down
15 changes: 15 additions & 0 deletions examples/diffusers/quantization/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
filter_func_default,
filter_func_flux_dev,
filter_func_ltx_video,
filter_func_qwen_image,
filter_func_wan_video,
)

Expand All @@ -46,6 +47,7 @@ class ModelType(str, Enum):
LTX2 = "ltx-2"
WAN22_T2V_14b = "wan2.2-t2v-14b"
WAN22_T2V_5b = "wan2.2-t2v-5b"
QWEN_IMAGE_2512 = "qwen-image-2512"


def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
Expand All @@ -69,6 +71,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: filter_func_ltx_video,
ModelType.WAN22_T2V_14b: filter_func_wan_video,
ModelType.WAN22_T2V_5b: filter_func_wan_video,
ModelType.QWEN_IMAGE_2512: filter_func_qwen_image,
}

return filter_func_map.get(model_type, filter_func_default)
Expand All @@ -86,6 +89,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: "Lightricks/LTX-2",
ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
ModelType.QWEN_IMAGE_2512: "Qwen/Qwen-Image-2512",
}

MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
Expand All @@ -99,6 +103,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: None,
ModelType.WAN22_T2V_14b: WanPipeline,
ModelType.WAN22_T2V_5b: WanPipeline,
ModelType.QWEN_IMAGE_2512: DiffusionPipeline,
}

# Shared dataset configurations
Expand Down Expand Up @@ -208,6 +213,16 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
),
},
},
ModelType.QWEN_IMAGE_2512: {
"backbone": "transformer",
"dataset": _SD_PROMPTS_DATASET,
"inference_extra_args": {
"height": 1024,
"width": 1024,
"guidance_scale": 4.0,
"negative_prompt": "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。", # noqa: RUF001
},
},
}


Expand Down
22 changes: 14 additions & 8 deletions examples/diffusers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@
FP8_DEFAULT_CONFIG,
INT8_DEFAULT_CONFIG,
NVFP4_DEFAULT_CONFIG,
NVFP4_ASYMMETRIC_CONFIG,
NVFP4_FP8_MHA_CONFIG,
reset_set_int8_config,
set_quant_config_attr,
)
from diffusers import DiffusionPipeline
from models_utils import MODEL_DEFAULTS, ModelType, get_model_filter_func, parse_extra_params
from onnx_utils.export import generate_fp8_scales, modelopt_export_sd
# from onnx_utils.export import generate_fp8_scales, modelopt_export_sd
from pipeline_manager import PipelineManager
from quantize_config import (
CalibrationConfig,
Expand Down Expand Up @@ -133,7 +134,7 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
if self.model_config.model_type.value.startswith("flux"):
quant_config = NVFP4_FP8_MHA_CONFIG
else:
quant_config = NVFP4_DEFAULT_CONFIG
quant_config = NVFP4_ASYMMETRIC_CONFIG
else:
raise NotImplementedError(f"Unknown format {self.config.format}")
if self.config.quantize_mha:
Expand Down Expand Up @@ -228,8 +229,12 @@ def save_checkpoint(self, backbone: torch.nn.Module) -> None:
return

ckpt_path = self.config.quantized_torch_ckpt_path
ckpt_path.mkdir(parents=True, exist_ok=True)
target_path = ckpt_path / "backbone.pt"
if ckpt_path.suffix == ".pt":
target_path = ckpt_path
target_path.parent.mkdir(parents=True, exist_ok=True)
else:
ckpt_path.mkdir(parents=True, exist_ok=True)
target_path = ckpt_path / "backbone.pt"
self.logger.info(f"Saving backbone to {target_path}")
mto.save(backbone, str(target_path))

Expand Down Expand Up @@ -260,7 +265,8 @@ def export_onnx(
self.logger.info(
"Detected quantizing conv layers in backbone. Generating FP8 scales..."
)
generate_fp8_scales(backbone)
# TODO: needs a fix, commenting out for now
# generate_fp8_scales(backbone)
self.logger.info("Preparing models for export...")
pipe.to("cpu")
torch.cuda.empty_cache()
Expand All @@ -269,9 +275,9 @@ def export_onnx(
backbone.eval()
with torch.no_grad():
self.logger.info("Exporting to ONNX...")
modelopt_export_sd(
backbone, str(self.config.onnx_dir), model_type.value, quant_format.value
)
# modelopt_export_sd(
# backbone, str(self.config.onnx_dir), model_type.value, quant_format.value
# )

self.logger.info("ONNX export completed successfully")

Expand Down
8 changes: 8 additions & 0 deletions examples/diffusers/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ def filter_func_default(name: str) -> bool:
return pattern.match(name) is not None


def filter_func_qwen_image(name: str) -> bool:
"""Qwen-Image filter: disable the 5 standalone modules outside transformer blocks (time_text_embed covers 2 sublayers)."""
pattern = re.compile(
r".*(time_text_embed|img_in|txt_in|norm_out|proj_out).*"
)
return pattern.match(name) is not None


def check_conv_and_mha(backbone, if_fp4, quantize_mha):
for name, module in backbone.named_modules():
if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)) and if_fp4:
Expand Down