Description
Before submitting: I have searched through existing and past issues and confirmed this bug has not been reported.
Describe the bug
When using modelopt.torch.export.export_hf_checkpoint() to export a model quantized with NVFP4, the resulting safetensors files contain FP8 weights (float8_e4m3fn) instead of NVFP4 packed weights (uint8), resulting in 6.12 bits/param instead of the expected ~4.0 bits/param.
Impact: Blocker for vLLM deployment with true 4-bit quantization. Models are falsely labeled as NVFP4 but consume 50% more memory than expected.
Root Cause: In modelopt/torch/export/unified_export_hf.py at line 729, the _export_transformers_checkpoint function calls model.state_dict() which triggers PyTorch's tensor serialization. This decompresses the NVFP4 packed uint8 data back to FP8 (float8_e4m3fn). The quantization works correctly in memory, but the export path loses the 4-bit packing during serialization.
Evidence:
- Measured size: 5.7GB for 8B model = 6.12 bits/param (should be ~4.0-4.5)
- Tensor dtypes: 224 float8_e4m3fn weight tensors (should be uint8)
- FP8 export works correctly: the same code with FP8 produces 8.58 bits/param (correct)
- ONNX export works: NVFP4 → ONNX → TensorRT preserves true 4-bit quantization
Steps/Code to reproduce bug
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from pathlib import Path
# 1. Load model
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    trust_remote_code=True,
)
# 2. Quantize to NVFP4
print("Quantizing to NVFP4...")
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=None)
# 3. Export using official method
output_dir = Path("/tmp/test_nvfp4_export")
print("Exporting...")
export_hf_checkpoint(
    model,
    dtype=None,  # Let modelopt preserve quantized dtype
    export_dir=str(output_dir),
)

Verification Code (Proof of Bug)
from safetensors import safe_open
from collections import Counter
model_dir = Path("/tmp/test_nvfp4_export")
dtypes = Counter()
for safetensor_file in sorted(model_dir.glob("*.safetensors")):
    with safe_open(safetensor_file, framework="pt") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            dtype_str = str(tensor.dtype).replace("torch.", "")
            dtypes[dtype_str] += 1

print("Dtype distribution:")
for dtype, count in sorted(dtypes.items(), key=lambda x: -x[1]):
    print(f"  {dtype}: {count} tensors")

Output:
Dtype distribution:
uint8: 224 tensors # Scales/zero-points, not weights!
float8_e4m3fn: 224 tensors # ❌ Weights are FP8, not NVFP4!
float32: 224 tensors
float16: 67 tensors
Size Calculation
total_gb = 5.7
params = 8e9
bits_per_param = (total_gb * 1024**3 * 8) / params
print(f"Bits per parameter: {bits_per_param:.2f}")
# Output: Bits per parameter: 6.12
# Expected for NVFP4: ~4.0-4.5 bits/param

Expected behavior
The exported model should contain:
- Weight dtype: uint8 (NVFP4 packed format)
- File size: ~4.3 GB for an 8B model (4.0-4.5 bits/param; see the rough arithmetic after this list)
- True 4-bit quantization: 2 NVFP4 values packed per byte
- Consistency: Same behavior as FP8 export, which correctly preserves float8_e4m3fn dtype
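For reference, a back-of-the-envelope estimate of the expected checkpoint size (our own arithmetic, assuming 2 FP4 values packed per uint8 byte plus one 1-byte block scale per group of 16 weights; the exact overhead depends on ModelOpt's layout):

params = 8e9          # ~8B weight parameters
group_size = 16       # NVFP4 block size reported in hf_quant_config.json

packed_weight_bytes = params * 0.5       # two 4-bit values per byte
block_scale_bytes = params / group_size  # one 1-byte scale per 16-weight block

total_bytes = packed_weight_bytes + block_scale_bytes
print(f"Expected size: {total_bytes / 1024**3:.2f} GiB")       # ~4.19 GiB
print(f"Expected bits/param: {total_bytes * 8 / params:.2f}")  # ~4.50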
Who can help?
@jenchen13 @ajrasane @cjluo-nv (based on NVFP4-related issue assignments)
System information
- Container used: None (native installation)
- OS: Ubuntu 22.04, Linux kernel 6.8
- CPU architecture: aarch64 (ARM64)
- GPU name: NVIDIA DGX Spark Blackwell GB10
- GPU memory size: 128GB unified memory
- Number of GPUs: 1
- Library versions:
- Python: 3.12
- ModelOpt version: 0.41.0 (latest as of Feb 8, 2026)
- CUDA: 13.0
- PyTorch: 2.11.0a0
- Transformers: 4.57.1
- TensorRT-LLM: Not used for this bug
- ONNXRuntime: Not used for this bug
- TensorRT: 10.x (present on system)
Additional Technical Details
Root Cause Analysis
After reviewing the source code in modelopt/torch/export/unified_export_hf.py:
- Line 513 (_export_quantized_weight): correctly calls to_quantized_weight(), which creates NVFP4 packed data via NVFP4QTensor.quantize()[0]._quantized_data
- Line 555: correctly sets the module weight to the packed NVFP4 tensor
- Line 729 (_export_transformers_checkpoint): calls model.state_dict() to get the quantized weights
- Problem: model.state_dict() triggers PyTorch's tensor serialization, which decompresses/converts the NVFP4 packed uint8 data back to FP8 (float8_e4m3fn)
The NVFP4 quantization works correctly in memory, but the export path loses the 4-bit packing during state_dict() serialization.
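To see what the weights look like at the moment state_dict() is called inside the export, one can temporarily wrap torch.nn.Module.state_dict with a logging shim. This is our own diagnostic sketch (not ModelOpt code), meant to be run as a continuation of the reproduction script above; the destination/args check is just a heuristic so that only the outermost call is reported:

import torch
from collections import Counter

_orig_state_dict = torch.nn.Module.state_dict

def _logging_state_dict(self, *args, **kwargs):
    sd = _orig_state_dict(self, *args, **kwargs)
    # Recursive submodule calls pass destination=...; only log the top-level call.
    if not args and kwargs.get("destination") is None:
        counts = Counter(str(v.dtype) for v in sd.values() if isinstance(v, torch.Tensor))
        print(f"state_dict() on {type(self).__name__}: {dict(counts)}")
    return sd

torch.nn.Module.state_dict = _logging_state_dict
try:
    export_hf_checkpoint(model, dtype=None, export_dir=str(output_dir))
finally:
    torch.nn.Module.state_dict = _orig_state_dict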
Code References
In modelopt/torch/export/unified_export_hf.py:
Quantization (works correctly):
# Line 513-520 in _export_quantized_weight()
quantized_weight = to_quantized_weight(
    weight.to(dtype),
    weight_scale,
    quantization_format,  # QUANTIZATION_NVFP4
    weight_scale_2,
    block_size,
)
# This creates proper NVFP4 packed data

Export (loses NVFP4 packing):
# Line 729-740 in _export_transformers_checkpoint()
if accelerator is not None:
    quantized_state_dict = accelerator.get_state_dict(model)
else:
    quantized_state_dict = model.state_dict()  # ❌ Decompresses NVFP4 here!
quantized_state_dict = postprocess_state_dict(
    quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
)
return quantized_state_dict, quant_config

In modelopt/torch/export/quant_utils.py:
NVFP4 Quantization (works correctly):
# Line 920-938 in to_quantized_weight()
if quantization in [
    QUANTIZATION_NVFP4,
    QUANTIZATION_NVFP4_AWQ,
    QUANTIZATION_W4A8_NVFP4_FP8,
    QUANTIZATION_NVFP4_SVDQUANT,
]:
    assert block_size is not None
    assert weights_scaling_factor2 is not None
    return NVFP4QTensor.quantize(
        weight,
        block_size,
        weights_scaling_factor,
        weights_scaling_factor2.view(-1, 1, 1)
        if weights_scaling_factor2.dim() != 0
        else weights_scaling_factor2,
    )[0]._quantized_data  # ✓ Returns packed uint8 NVFP4 data

The issue is that this packed uint8 data gets converted back to FP8 somewhere in the state_dict() → model.save_pretrained() pipeline.
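As a control experiment (our own standalone snippet, independent of ModelOpt), a plain uint8 parameter registered on a module keeps its dtype through state_dict(), which may help narrow down where in that pipeline the conversion actually happens:

import torch

class Dummy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Simulate an NVFP4-style packed payload: raw uint8 data as a frozen parameter.
        packed = torch.randint(0, 256, (16, 16), dtype=torch.uint8)
        self.weight = torch.nn.Parameter(packed, requires_grad=False)

m = Dummy()
print(m.state_dict()["weight"].dtype)  # torch.uint8 -- dtype survives a plain state_dict()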
Attempted Fixes (All Failed)
We attempted three different approaches to bypass the bug:
- Direct parameter export via named_parameters(): tried extracting weights without state_dict(), but got FP16 (16.06 bits/param). The NVFP4 data is stored in private TensorQuantizer objects and is not accessible via standard PyTorch APIs (see the sketch after this list).
- Intercept NVIDIA's internal pipeline: tried importing _process_quantized_modules() to capture data before compression. Failed with an ImportError; private functions are not exposed.
- Context manager extraction via quantize_weight(): tried using the module-level context manager to access the packed data. Still got FP16 (16.00 bits/param).
Conclusion: NVFP4 data is encapsulated in private objects and cannot be accessed without modifying nvidia-modelopt source code.
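For reference, a minimal sketch of the first attempt (our own code): it simply walks named_parameters() and tallies dtypes and bytes, which is why it only ever sees the FP16 master weights after mtq.quantize():

import torch
from collections import Counter

def summarize_parameters(model: torch.nn.Module) -> None:
    """Tally parameter dtypes and estimate bits/param from named_parameters()."""
    dtypes = Counter()
    total_bytes = 0
    total_params = 0
    for _name, param in model.named_parameters():
        dtypes[str(param.dtype)] += 1
        total_bytes += param.numel() * param.element_size()
        total_params += param.numel()
    print(dict(dtypes))
    print(f"{total_bytes * 8 / total_params:.2f} bits/param")  # ~16 after mtq.quantize()

summarize_parameters(model)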
Working Workaround
Users must use the ONNX/TensorRT-LLM path for true NVFP4:
- Quantize to NVFP4 with modelopt
- Export to ONNX (not HuggingFace format)
- Build TensorRT engine
- Use TensorRT-LLM/Edge-LLM runtime
Limitations of workaround:
- No direct vLLM support
- Requires TensorRT engine build (memory intensive: 3-4x model size during build)
- Not compatible with HuggingFace ecosystem
- Limits deployment flexibility
Comparison: FP8 Export Works Correctly
FP8 export works correctly:
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=None)
export_hf_checkpoint(model, dtype=None, export_dir=str(output_dir))
# Result: 8.58 bits/param (correct for FP8 + overhead)

Only NVFP4 export is broken. This proves the bug is specific to NVFP4 handling, not a general export issue.
Config Metadata Shows NVFP4 (But Weights Don't Match)
The exported model's hf_quant_config.json correctly states NVFP4:
{
  "producer": {
    "name": "modelopt",
    "version": "0.41.0"
  },
  "quantization": {
    "quant_algo": "NVFP4",
    "kv_cache_quant_algo": null,
    "group_size": 16,
    "exclude_modules": ["lm_head"]
  }
}

But the actual weights are FP8, making this metadata incorrect and misleading.
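A small consistency check that makes the mismatch explicit (our own sketch, reusing the export directory from the reproduction above):

import json
from collections import Counter
from pathlib import Path
from safetensors import safe_open

model_dir = Path("/tmp/test_nvfp4_export")

# What the metadata claims
quant_cfg = json.loads((model_dir / "hf_quant_config.json").read_text())
print("quant_algo:", quant_cfg["quantization"]["quant_algo"])  # "NVFP4"

# What the weight shards actually contain
weight_dtypes = Counter()
for shard in sorted(model_dir.glob("*.safetensors")):
    with safe_open(shard, framework="pt") as f:
        for key in f.keys():
            if key.endswith(".weight"):
                weight_dtypes[str(f.get_tensor(key).dtype)] += 1
print("weight dtypes:", dict(weight_dtypes))  # float8_e4m3fn dominates; expected torch.uint8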
Suggested Fix
The export_hf_checkpoint() function should preserve the NVFP4 packed uint8 format when saving to safetensors, similar to how it successfully preserves FP8.
Possible approaches:
- Direct export: capture ._quantized_data before state_dict() decompression (see the sketch after this list)
- Custom state_dict hook: override tensor retrieval to return the packed data
- Safetensors metadata: Mark tensors as packed NVFP4 to prevent decompression
- Documentation: If HuggingFace export of NVFP4 is unsupported, clearly document that NVFP4 → ONNX/TensorRT-LLM is the only supported path
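To illustrate the first two approaches, a rough sketch of what a post-processing step could look like. This is not working code against ModelOpt's internals: get_packed_nvfp4_weight() is a hypothetical accessor standing in for however the packed ._quantized_data would actually be retrieved.

import torch

def get_packed_nvfp4_weight(module: torch.nn.Module) -> torch.Tensor | None:
    """Hypothetical accessor: return the packed uint8 NVFP4 payload for a quantized
    module, or None if the module is not NVFP4-quantized. A real implementation
    would have to reach into ModelOpt's quantizer state."""
    return getattr(module, "_packed_nvfp4_weight", None)  # placeholder attribute name

def repack_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """Swap FP8-decompressed weight entries back to their packed uint8 form."""
    for name, module in model.named_modules():
        packed = get_packed_nvfp4_weight(module)
        key = f"{name}.weight"
        if packed is not None and key in state_dict:
            assert packed.dtype == torch.uint8
            state_dict[key] = packed
    return state_dict

# Usage idea: call repack_state_dict(model, model.state_dict()) between
# state_dict() collection and safetensors serialization in the export path.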
Note: We have NVIDIA hardware (DGX Spark Blackwell GB10) and are happy to test patches or provide additional debugging information.