1 change: 1 addition & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -80,6 +80,7 @@
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
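For orientation, a minimal sketch of how the new entry would be exercised. The toy model, the calibration loop, and the `--qformat nvfp4_mse` flag name are illustrative assumptions; only the config name and the standard mtq.quantize entry point come from the change itself.

import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

# Toy stand-in model; in hf_ptq.py this would be the loaded HF checkpoint.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

def forward_loop(m):
    # Tiny random-data calibration loop, standing in for real calibration samples.
    for _ in range(4):
        m(torch.randn(2, 64))

# Selected via the qformat-to-config mapping shown above (e.g. `--qformat nvfp4_mse`).
model = mtq.quantize(model, mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)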
79 changes: 55 additions & 24 deletions modelopt/torch/export/quant_utils.py
@@ -45,7 +45,7 @@
)
from modelopt.torch.utils import clear_cuda_cache

from ..quantization.nn import SequentialQuantizer, TensorQuantizer
from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer
from .model_config import (
KV_CACHE_FP8,
KV_CACHE_INT8,
@@ -299,16 +299,19 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
return get_scaling_factor(weight_quantizer[0])

quantization_format = get_quantization_format(module)
# If NVFP4, we need to return quantized per_block scaling factors
if quantization_format in [

# Handle NVFP4 variants (static or dynamic)
is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
if is_nvfp4_static or quantization_format in [
Collaborator: what's the quantization_format for NVFP4StaticQuantizer? do we need is_nvfp4_static here?

QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
# Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
if not is_nvfp4_static:
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
@@ -318,9 +318,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
weight_quantizer
)
return NVFP4QTensor.get_weights_scaling_factor(
# Unified method handles both static and dynamic quantizers
return NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
weight_quantizer,
weight,
weight_quantizer.block_sizes[-1],
weight_scaling_factor_2.to(weight.device),
)[0]

@@ -343,27 +347,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")

quantization_format = get_quantization_format(module)

# Calibrate weight quantizer if amax is not set for all NVFP4 variants
if quantization_format in [
is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
if is_nvfp4_static or quantization_format in [
Collaborator: same question as above.

QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
weight = getattr(module, weight_name)
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
# Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
if not is_nvfp4_static:
weight = getattr(module, weight_name)
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

if quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
# This is because the kernel dequantizes weight to fp8, which is in range 448.
return weight_quantizer._amax.float() / 448.0
if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
# This is because the kernel dequantizes weight to fp8, which is in range 448.
return weight_quantizer._amax.float() / 448.0
else:
# Unified method handles both static and dynamic quantizers
return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)

# SequentialQuantizer is required
if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
@@ -735,7 +738,7 @@ def process_layer_quant_config(layer_config_dict):
layer_config = {"quant_algo": "W8A16"}
elif v == "int8_sq":
layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
elif v == "nvfp4":
elif v in ["nvfp4", "nvfp4_static"]:
layer_config = {
"quant_algo": "NVFP4",
"group_size": block_size_value,
@@ -1339,6 +1342,18 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
for module in modules:
module.weight_quantizer[-1].amax = weight_amax

# Handle NVFP4StaticQuantizer: unify global_amax for fused layers
elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer):
global_amax_list = [
m.weight_quantizer.global_amax
for m in modules
if m.weight_quantizer.global_amax is not None
]
if global_amax_list:
unified_global_amax = torch.max(torch.stack(global_amax_list))
for module in modules:
module.weight_quantizer.global_amax = unified_global_amax

elif (
modules[0].weight_quantizer.is_enabled
and modules[0].weight_quantizer.amax is not None
@@ -1423,6 +1438,22 @@ def get_quant_config(
if block_size == 0:
block_size = get_weight_block_size(module)

# Static NVFP4 uses pre-computed per-block scales from MSE calibration
if quantization_format == QUANTIZATION_NVFP4:
weight_quantizer = getattr(module, "weight_quantizer", None)
if weight_quantizer is None:
# Try to get from first weight attribute
for wn in weight_names:
weight_quantizer = getattr(
module, quantizer_attr_names(wn).weight_quantizer, None
)
if weight_quantizer is not None:
break
if weight_quantizer is not None:
is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
if is_static:
quantization_format = "nvfp4_static"

# Construct per layer config dictionary
layer_config_dict[name + ".quantization"] = quantization_format
layer_config_dict[name + ".awq_block_size"] = block_size
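As background for the NVFP4 scale handling in this file, a small self-contained sketch of the two-level scaling arithmetic (a per-tensor scale plus FP8 per-block scales). It illustrates the math only and is not the library implementation.

import torch

def nvfp4_scales(global_amax: torch.Tensor, per_block_amax: torch.Tensor):
    # Per-tensor scale: global_amax / (6 * 448); dividing the per-block scales by it maps them into FP8 range.
    scale_2 = global_amax.float() / (6.0 * 448.0)
    per_block_scale = per_block_amax.float() / 6.0
    per_block_scale[per_block_scale == 0] = 1.0  # avoid zero scales for all-zero blocks
    # Per-block scales are stored as FP8 (e4m3) relative to the per-tensor scale.
    per_block_scale_fp8 = (per_block_scale / scale_2).to(torch.float8_e4m3fn)
    return per_block_scale_fp8, scale_2

# Example: a block with amax 3.0 under a tensor-wide amax of 6.0 yields an FP8 per-block scale of 224.
per_block, scale_2 = nvfp4_scales(torch.tensor(6.0), torch.tensor([3.0]))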
23 changes: 17 additions & 6 deletions modelopt/torch/export/unified_export_hf.py
@@ -50,7 +50,11 @@
from torch.distributed.fsdp import FSDPModule

from modelopt.torch.quantization import set_quantizer_by_cfg_context
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.nn import (
NVFP4StaticQuantizer,
SequentialQuantizer,
TensorQuantizer,
)
from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor
from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names

@@ -502,11 +506,18 @@ def _export_quantized_weight(
weight, _ = maybe_transpose_expert_weight_dimensions(
weight, is_bmm_expert_weight=is_bmm_expert_weight
)
weight_scale = NVFP4QTensor.get_weights_scaling_factor(
weight,
block_size=block_size,
weights_scaling_factor_2=weight_scale_2,
)[0]

# Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
# For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)

if not is_nvfp4_static:
Collaborator: do we need to handle the else condition?

# For dynamic NVFP4, compute scales from weights
weight_scale = NVFP4QTensor.get_weights_scaling_factor(
weight,
block_size=block_size,
weights_scaling_factor_2=weight_scale_2,
)[0]
Comment on lines +509 to +520
coderabbitai bot (Contributor) commented on Feb 6, 2026:

⚠️ Potential issue | 🟠 Major

Potential shape mismatch for BMM-style expert weights with static NVFP4.

When is_bmm_expert_weight is True, the weight is transposed at Lines 506–508 (e.g., from (E, in_dim, out_dim) to (E, out_dim, in_dim)). The dynamic path (Lines 516–520) correctly recomputes weight_scale from the transposed weight. However, the static path skips recomputation and uses the scale that was computed by get_weight_scaling_factor (Line 461) from the untransposed weight.

Since the static path in NVFP4QTensor.get_weights_scaling_factor_from_quantizer (nvfp4_tensor.py Line 121–123) reshapes the per-block scale using the weight's original shape, the scale would have shape (*untransposed_shape[:-1], num_blocks) which won't match the transposed weight layout expected by to_quantized_weight.

This would fail with a shape error if static NVFP4 quantizers are ever used with Llama4TextExperts or GptOssExperts. If that combination is currently not expected, a guard would prevent a confusing error later:

Proposed guard
         # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
         # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
         is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
 
+        if is_nvfp4_static and is_bmm_expert_weight:
+            raise NotImplementedError(
+                "Static NVFP4 quantization is not yet supported for BMM-style expert weights "
+                "(Llama4TextExperts, GptOssExperts). Use dynamic NVFP4 quantization instead."
+            )
+
         if not is_nvfp4_static:
🤖 Prompt for AI Agents
In `@modelopt/torch/export/unified_export_hf.py` around lines 509–520: the
static NVFP4 path uses a scale computed from the untransposed weight, which will
mismatch when is_bmm_expert_weight is True; update the branch that checks
is_nvfp4_static to either (1) recompute/reshape weight_scale using the
transposed weight by calling NVFP4QTensor.get_weights_scaling_factor (or
get_weights_scaling_factor_from_quantizer) on the transposed weight before
calling to_quantized_weight, or (2) add an explicit guard that raises a clear
error when is_bmm_expert_weight and isinstance(weight_quantizer,
NVFP4StaticQuantizer) to prevent misuse; modify the code around the
is_nvfp4_static check (referencing is_bmm_expert_weight, weight_quantizer,
NVFP4StaticQuantizer, NVFP4QTensor.get_weights_scaling_factor[_from_quantizer],
and to_quantized_weight) accordingly.

Collaborator: @Fridah-nv could you fix this?

coderabbitai bot (Contributor): Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!


quantized_weight = to_quantized_weight(
weight.to(dtype),
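To make the shape mismatch flagged in the bot review above concrete, a shape-only walk-through; the dimensions (4 experts, 128x256 weights, block size 16) are invented for illustration.

import torch

E, in_dim, out_dim, block_size = 4, 128, 256, 16
weight = torch.randn(E, in_dim, out_dim)                      # stored layout before the BMM transpose
static_scale = torch.empty(E, in_dim, out_dim // block_size)  # scale shaped from the untransposed weight: (4, 128, 16)

weight_t = weight.transpose(-2, -1)                           # (E, out_dim, in_dim) = (4, 256, 128) after the transpose
expected_scale_shape = (E, out_dim, in_dim // block_size)     # what quantizing weight_t would need: (4, 256, 8)
assert static_scale.shape != expected_scale_shape             # hence the shape error described in the comment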
23 changes: 23 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -388,6 +388,29 @@
"algorithm": "max",
}

NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"fp8_scale_sweep": True,
},
}
Comment on lines +391 to +411
Contributor (coderabbitai bot):

⚠️ Potential issue | 🟠 Major

Add NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG to the choices set.

The new config is not registered in the choices set (line 649), which is documented as containing all "supported quantization format names" and is used in algorithms.py:126 for config discovery. All other similar NVFP4 configs are included in choices. Add the new config name to maintain consistency and ensure it's discoverable through the documented API.

🤖 Prompt for AI Agents
In `@modelopt/torch/quantization/config.py` around lines 391–411: the new
NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG config is defined but not added to the
exported choices set; update the choices collection (the variable named choices)
to include "NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG" alongside the other NVFP4
entries so it becomes discoverable by algorithms.py (which expects all supported
quantization format names in choices). Locate the choices definition and append
the new config name in the same style/ordering as the other NVFP4_* entries.
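A one-line sketch of the registration the bot asks for; whether `choices` is built as a literal or mutated at import time is not visible in this diff, so treat the exact form as an assumption.

# config.py (sketch): make the new config discoverable alongside the other NVFP4_* entries
choices.add("NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG")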



NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
89 changes: 84 additions & 5 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -52,12 +52,87 @@ def get_e2m1_bounds(cls, device):
cls.e2m1_bounds_on_device[device] = e2m1_bounds.to(device)
return cls.e2m1_bounds_on_device[device]

@classmethod
def _is_static_quantizer(cls, weight_quantizer) -> bool:
"""Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None

@classmethod
def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
Collaborator: could you add a unittest?

"""Returns per tensor weight scaling factor from the weight_quantizer amax."""
# Assert that weight_quantizer has attribute amax
assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax"
return weight_quantizer._amax.float() / (6.0 * 448.0)
"""Returns per tensor weight scaling factor from the weight_quantizer.

Handles both static NVFP4 quantizers (using global_amax) and
dynamic quantizers (using _amax).

Args:
weight_quantizer: The weight quantizer (static or dynamic).

Returns:
The global scaling factor as a float tensor.
"""
if cls._is_static_quantizer(weight_quantizer):
return weight_quantizer.global_amax.float() / (6.0 * 448.0)
else:
assert hasattr(weight_quantizer, "_amax"), (
"Weight quantizer does not have attribute amax"
)
return weight_quantizer._amax.float() / (6.0 * 448.0)

@classmethod
def get_weights_scaling_factor_from_quantizer(
cls,
weight_quantizer,
weight: torch.Tensor,
weights_scaling_factor_2: torch.Tensor | None = None,
keep_high_precision: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns quantized per block weight scaling factor from quantizer.

Handles both static NVFP4 quantizers (with pre-computed per-block amax)
and dynamic quantizers (computing from weight tensor).

Args:
weight_quantizer: The weight quantizer (static or dynamic).
weight: The weight tensor (used for shape in static, values in dynamic).
weights_scaling_factor_2: Optional pre-computed global scale.
keep_high_precision: Whether to keep scales in high precision.

Returns:
Tuple of (per_block_scale, weights_scaling_factor_2).
"""
block_size = weight_quantizer.block_sizes[-1]

if weights_scaling_factor_2 is None:
weights_scaling_factor_2 = cls.get_weights_scaling_factor_2_from_quantizer(
weight_quantizer
)

if cls._is_static_quantizer(weight_quantizer):
Collaborator: so in the static NVFP4 quant case, you use _amax for the per-block amax and global_amax for the per-tensor amax?

# Static path: use pre-computed per-block amax values from quantizer
global_amax = weight_quantizer.global_amax.float()
per_block_amax = weight_quantizer._amax.float()

# Compute scales in float
per_block_scale_max = global_amax / 6.0
per_block_scale = per_block_amax / 6.0
per_block_scale[per_block_scale == 0] = 1.0

# Reshape per_block_scale to match weight's block structure
num_blocks_per_row = weight.shape[-1] // block_size
expected_shape = (*weight.shape[:-1], num_blocks_per_row)
per_block_scale = per_block_scale.view(expected_shape)

# Quantize scales to FP8
if not keep_high_precision:
per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
torch.float8_e4m3fn
)
return per_block_scale, weights_scaling_factor_2
else:
# Dynamic path: compute from weight tensor
return cls.get_weights_scaling_factor(
weight, block_size, weights_scaling_factor_2, keep_high_precision
)

@classmethod
def get_weights_scaling_factor(
@@ -67,7 +142,11 @@ def get_weights_scaling_factor(
weights_scaling_factor_2: torch.Tensor | None = None,
keep_high_precision: bool = False,
):
"""Returns quantized per block weight scaling factor."""
"""Returns quantized per block weight scaling factor from weight tensor.

This is the dynamic path that computes scales directly from the weight values.
For quantizers with pre-computed amax, use get_weights_scaling_factor_from_quantizer.
"""
if weights_scaling_factor_2 is None:
weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input)

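In response to the unit-test request earlier in this file's discussion, a rough pytest-style sketch for get_weights_scaling_factor_2_from_quantizer. The quantizers are stubbed with SimpleNamespace because the real NVFP4StaticQuantizer setup is not shown here, so treat the fixtures as assumptions.

import types
import torch
from modelopt.torch.quantization.qtensor import NVFP4QTensor

def test_scaling_factor_2_dynamic_quantizer():
    # Dynamic path: global_amax is None, so the method falls back to _amax / (6 * 448).
    quantizer = types.SimpleNamespace(_amax=torch.tensor(6.0 * 448.0), global_amax=None)
    sf2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(quantizer)
    assert torch.isclose(sf2, torch.tensor(1.0))

def test_scaling_factor_2_static_quantizer():
    # Static path: global_amax is set and takes precedence over the per-block _amax.
    quantizer = types.SimpleNamespace(
        _amax=torch.tensor([1.0, 2.0]),          # per-block amax, unused by this method
        global_amax=torch.tensor(2.0 * 6.0 * 448.0),
    )
    sf2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(quantizer)
    assert torch.isclose(sf2, torch.tensor(2.0))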