diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index a5af5e97d..8691a2db2 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -80,6 +80,7 @@
     "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
     "nvfp4": mtq.NVFP4_DEFAULT_CFG,
     "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
+    "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
     "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
     "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
     "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 0d99d44f0..90e4f89c2 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -45,7 +45,7 @@
 )
 from modelopt.torch.utils import clear_cuda_cache

-from ..quantization.nn import SequentialQuantizer, TensorQuantizer
+from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer
 from .model_config import (
     KV_CACHE_FP8,
     KV_CACHE_INT8,
@@ -299,16 +299,19 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         return get_scaling_factor(weight_quantizer[0])

     quantization_format = get_quantization_format(module)
-    # If NVFP4, we need to return quantized per_block scaling factors
-    if quantization_format in [
+
+    # Handle NVFP4 variants (static or dynamic)
+    is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+    if is_nvfp4_static or quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        # Calibrate weight quantizer if amax is not set
-        module_name = f"{type(module).__name__}.{weight_name}"
-        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+        # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
+        if not is_nvfp4_static:
+            module_name = f"{type(module).__name__}.{weight_name}"
+            _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
             # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
@@ -318,9 +321,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
             weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
                 weight_quantizer
             )
-        return NVFP4QTensor.get_weights_scaling_factor(
+        # Unified method handles both static and dynamic quantizers
+        return NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
+            weight_quantizer,
             weight,
-            weight_quantizer.block_sizes[-1],
             weight_scaling_factor_2.to(weight.device),
         )[0]
@@ -343,27 +347,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     quantization_format = get_quantization_format(module)

-    # Calibrate weight quantizer if amax is not set for all NVFP4 variants
-    if quantization_format in [
+    is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+    if is_nvfp4_static or quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        weight = getattr(module, weight_name)
-        module_name = f"{type(module).__name__}.{weight_name}"
-        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+        # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
+        if not is_nvfp4_static:
+            weight = getattr(module, weight_name)
+            module_name = f"{type(module).__name__}.{weight_name}"
+            _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

-    if quantization_format in [
-        QUANTIZATION_NVFP4,
-        QUANTIZATION_NVFP4_AWQ,
-        QUANTIZATION_NVFP4_SVDQUANT,
-    ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
-    elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
-        # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
-        # This is because the kernel dequantizes weight to fp8, which is in range 448.
-        return weight_quantizer._amax.float() / 448.0
+        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+            # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+            # This is because the kernel dequantizes weight to fp8, which is in range 448.
+            return weight_quantizer._amax.float() / 448.0
+        else:
+            # Unified method handles both static and dynamic quantizers
+            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
@@ -735,7 +738,7 @@ def process_layer_quant_config(layer_config_dict):
             layer_config = {"quant_algo": "W8A16"}
         elif v == "int8_sq":
             layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
-        elif v == "nvfp4":
+        elif v in ["nvfp4", "nvfp4_static"]:
             layer_config = {
                 "quant_algo": "NVFP4",
                 "group_size": block_size_value,
@@ -1339,6 +1342,18 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
         for module in modules:
             module.weight_quantizer[-1].amax = weight_amax

+    # Handle NVFP4StaticQuantizer: unify global_amax for fused layers
+    elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer):
+        global_amax_list = [
+            m.weight_quantizer.global_amax
+            for m in modules
+            if m.weight_quantizer.global_amax is not None
+        ]
+        if global_amax_list:
+            unified_global_amax = torch.max(torch.stack(global_amax_list))
+            for module in modules:
+                module.weight_quantizer.global_amax = unified_global_amax
+
     elif (
         modules[0].weight_quantizer.is_enabled
         and modules[0].weight_quantizer.amax is not None
@@ -1423,6 +1438,22 @@ def get_quant_config(
             if block_size == 0:
                 block_size = get_weight_block_size(module)

+            # Static NVFP4 uses pre-computed per-block scales from MSE calibration
+            if quantization_format == QUANTIZATION_NVFP4:
+                weight_quantizer = getattr(module, "weight_quantizer", None)
+                if weight_quantizer is None:
+                    # Try to get from first weight attribute
+                    for wn in weight_names:
+                        weight_quantizer = getattr(
+                            module, quantizer_attr_names(wn).weight_quantizer, None
+                        )
+                        if weight_quantizer is not None:
+                            break
+                if weight_quantizer is not None:
+                    is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+                    if is_static:
+                        quantization_format = "nvfp4_static"
+
             # Construct per layer config dictionary
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 61bebb51d..ac66eac96 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -50,7 +50,11 @@
 from torch.distributed.fsdp import FSDPModule

 from modelopt.torch.quantization import set_quantizer_by_cfg_context
-from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
+from modelopt.torch.quantization.nn import (
+    NVFP4StaticQuantizer,
+    SequentialQuantizer,
+    TensorQuantizer,
+)
 from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor
 from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
@@ -502,11 +506,18 @@ def _export_quantized_weight(
         weight, _ = maybe_transpose_expert_weight_dimensions(
             weight, is_bmm_expert_weight=is_bmm_expert_weight
         )
-        weight_scale = NVFP4QTensor.get_weights_scaling_factor(
-            weight,
-            block_size=block_size,
-            weights_scaling_factor_2=weight_scale_2,
-        )[0]
+
+        # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
+        # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
+        is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+
+        if not is_nvfp4_static:
+            # For dynamic NVFP4, compute scales from weights
+            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                block_size=block_size,
+                weights_scaling_factor_2=weight_scale_2,
+            )[0]

         quantized_weight = to_quantized_weight(
             weight.to(dtype),
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index e1b48ee60..956218e60 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -388,6 +388,29 @@
     "algorithm": "max",
 }

+NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        "fp8_scale_sweep": True,
+    },
+}
+
+
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
index 2ff1b17e9..6ff31424c 100644
--- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
+++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -52,12 +52,87 @@ def get_e2m1_bounds(cls, device):
             cls.e2m1_bounds_on_device[device] = e2m1_bounds.to(device)
         return cls.e2m1_bounds_on_device[device]

+    @classmethod
+    def _is_static_quantizer(cls, weight_quantizer) -> bool:
+        """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
+        return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None
+
     @classmethod
     def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
-        """Returns per tensor weight scaling factor from the weight_quantizer amax."""
-        # Assert that weight_quantizer has attribute amax
-        assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax"
-        return weight_quantizer._amax.float() / (6.0 * 448.0)
+        """Returns per tensor weight scaling factor from the weight_quantizer.
+
+        Handles both static NVFP4 quantizers (using global_amax) and
+        dynamic quantizers (using _amax).
+
+        Args:
+            weight_quantizer: The weight quantizer (static or dynamic).
+
+        Returns:
+            The global scaling factor as a float tensor.
+        """
+        if cls._is_static_quantizer(weight_quantizer):
+            return weight_quantizer.global_amax.float() / (6.0 * 448.0)
+        else:
+            assert hasattr(weight_quantizer, "_amax"), (
+                "Weight quantizer does not have attribute amax"
+            )
+            return weight_quantizer._amax.float() / (6.0 * 448.0)
+
+    @classmethod
+    def get_weights_scaling_factor_from_quantizer(
+        cls,
+        weight_quantizer,
+        weight: torch.Tensor,
+        weights_scaling_factor_2: torch.Tensor | None = None,
+        keep_high_precision: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Returns quantized per block weight scaling factor from quantizer.
+
+        Handles both static NVFP4 quantizers (with pre-computed per-block amax)
+        and dynamic quantizers (computing from the weight tensor).
+
+        Args:
+            weight_quantizer: The weight quantizer (static or dynamic).
+            weight: The weight tensor (used for shape in static, values in dynamic).
+            weights_scaling_factor_2: Optional pre-computed global scale.
+            keep_high_precision: Whether to keep scales in high precision.
+
+        Returns:
+            Tuple of (per_block_scale, weights_scaling_factor_2).
+ """ + block_size = weight_quantizer.block_sizes[-1] + + if weights_scaling_factor_2 is None: + weights_scaling_factor_2 = cls.get_weights_scaling_factor_2_from_quantizer( + weight_quantizer + ) + + if cls._is_static_quantizer(weight_quantizer): + # Static path: use pre-computed per-block amax values from quantizer + global_amax = weight_quantizer.global_amax.float() + per_block_amax = weight_quantizer._amax.float() + + # Compute scales in float + per_block_scale_max = global_amax / 6.0 + per_block_scale = per_block_amax / 6.0 + per_block_scale[per_block_scale == 0] = 1.0 + + # Reshape per_block_scale to match weight's block structure + num_blocks_per_row = weight.shape[-1] // block_size + expected_shape = (*weight.shape[:-1], num_blocks_per_row) + per_block_scale = per_block_scale.view(expected_shape) + + # Quantize scales to FP8 + if not keep_high_precision: + per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to( + torch.float8_e4m3fn + ) + return per_block_scale, weights_scaling_factor_2 + else: + # Dynamic path: compute from weight tensor + return cls.get_weights_scaling_factor( + weight, block_size, weights_scaling_factor_2, keep_high_precision + ) @classmethod def get_weights_scaling_factor( @@ -67,7 +142,11 @@ def get_weights_scaling_factor( weights_scaling_factor_2: torch.Tensor | None = None, keep_high_precision: bool = False, ): - """Returns quantized per block weight scaling factor.""" + """Returns quantized per block weight scaling factor from weight tensor. + + This is the dynamic path that computes scales directly from the weight values. + For quantizers with pre-computed amax, use get_weights_scaling_factor_from_quantizer. + """ if weights_scaling_factor_2 is None: weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input)