support static NVFP4 HF export #858
base: main
Changes from all commits
55c2d35
642f99b
519dc2a
e0606cb
9725c34
@@ -45,7 +45,7 @@
 )
 from modelopt.torch.utils import clear_cuda_cache

-from ..quantization.nn import SequentialQuantizer, TensorQuantizer
+from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer
 from .model_config import (
     KV_CACHE_FP8,
     KV_CACHE_INT8,
@@ -299,16 +299,19 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         return get_scaling_factor(weight_quantizer[0])

     quantization_format = get_quantization_format(module)
-    # If NVFP4, we need to return quantized per_block scaling factors
-    if quantization_format in [
+    # Handle NVFP4 variants (static or dynamic)
+    is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+    if is_nvfp4_static or quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        # Calibrate weight quantizer if amax is not set
-        module_name = f"{type(module).__name__}.{weight_name}"
-        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+        # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
+        if not is_nvfp4_static:
+            module_name = f"{type(module).__name__}.{weight_name}"
+            _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
             # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
@@ -318,9 +321,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
             weight_quantizer
         )
-        return NVFP4QTensor.get_weights_scaling_factor(
+        # Unified method handles both static and dynamic quantizers
+        return NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
+            weight_quantizer,
             weight,
-            weight_quantizer.block_sizes[-1],
             weight_scaling_factor_2.to(weight.device),
         )[0]
@@ -343,27 +347,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     quantization_format = get_quantization_format(module)

-    # Calibrate weight quantizer if amax is not set for all NVFP4 variants
-    if quantization_format in [
+    is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+    if is_nvfp4_static or quantization_format in [
Collaborator: same
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        weight = getattr(module, weight_name)
-        module_name = f"{type(module).__name__}.{weight_name}"
-        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+        # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
+        if not is_nvfp4_static:
+            weight = getattr(module, weight_name)
+            module_name = f"{type(module).__name__}.{weight_name}"
+            _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

-    if quantization_format in [
-        QUANTIZATION_NVFP4,
-        QUANTIZATION_NVFP4_AWQ,
-        QUANTIZATION_NVFP4_SVDQUANT,
-    ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
-    elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
-        # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
-        # This is because the kernel dequantizes weight to fp8, which is in range 448.
-        return weight_quantizer._amax.float() / 448.0
+        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+            # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+            # This is because the kernel dequantizes weight to fp8, which is in range 448.
+            return weight_quantizer._amax.float() / 448.0
+        else:
+            # Unified method handles both static and dynamic quantizers
+            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
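To make the amax/448 comment above concrete, here is a short numeric sketch (all values invented; only the 6.0 and 448.0 constants come from the code), showing why the stored per-block scale always fits in FP8's range:

```python
# Illustrative numbers for W4A8-NVFP4-FP8: FP4 weight blocks carry FP8 scales,
# and those scales must stay representable after dividing by the global scale.
amax = 2.0                              # assumed per-tensor weight amax
wsf_2 = amax / 448.0                    # global scale, as in the code above

per_block_amax = 1.5                    # assumed amax of one 16-value block
per_block_scale = per_block_amax / 6.0  # 6.0 = largest FP4 (E2M1) magnitude

stored_scale = per_block_scale / wsf_2  # value the kernel keeps in FP8
# stored_scale = (448 / 6) * (per_block_amax / amax) <= 448 / 6,
# since per_block_amax <= amax by definition.
assert stored_scale <= 448.0 / 6.0
```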
@@ -735,7 +738,7 @@ def process_layer_quant_config(layer_config_dict):
             layer_config = {"quant_algo": "W8A16"}
         elif v == "int8_sq":
             layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
-        elif v == "nvfp4":
+        elif v in ["nvfp4", "nvfp4_static"]:
             layer_config = {
                 "quant_algo": "NVFP4",
                 "group_size": block_size_value,
@@ -1339,6 +1342,18 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
         for module in modules:
             module.weight_quantizer[-1].amax = weight_amax

+    # Handle NVFP4StaticQuantizer: unify global_amax for fused layers
+    elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer):
+        global_amax_list = [
+            m.weight_quantizer.global_amax
+            for m in modules
+            if m.weight_quantizer.global_amax is not None
+        ]
+        if global_amax_list:
+            unified_global_amax = torch.max(torch.stack(global_amax_list))
+            for module in modules:
+                module.weight_quantizer.global_amax = unified_global_amax
+
     elif (
         modules[0].weight_quantizer.is_enabled
         and modules[0].weight_quantizer.amax is not None
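The intent of the new elif branch: layers fused at export (for example q/k/v projections) must encode their per-block FP8 scales against one shared global range. A minimal sketch with a stand-in quantizer class (StubStaticQuantizer is invented; only the torch.max/torch.stack reduction mirrors the diff):

```python
import torch

class StubStaticQuantizer:
    """Stand-in for NVFP4StaticQuantizer; only global_amax matters here."""

    def __init__(self, global_amax: float):
        self.global_amax = torch.tensor(global_amax)

# Quantizers of three layers that will be fused into one exported tensor.
quantizers = [StubStaticQuantizer(1.0), StubStaticQuantizer(2.5), StubStaticQuantizer(0.7)]

unified = torch.max(torch.stack([q.global_amax for q in quantizers]))
for q in quantizers:
    q.global_amax = unified  # every member now encodes scales against amax = 2.5

assert all(float(q.global_amax) == 2.5 for q in quantizers)
```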
@@ -1423,6 +1438,22 @@ def get_quant_config(
             if block_size == 0:
                 block_size = get_weight_block_size(module)

+            # Static NVFP4 uses pre-computed per-block scales from MSE calibration
+            if quantization_format == QUANTIZATION_NVFP4:
+                weight_quantizer = getattr(module, "weight_quantizer", None)
+                if weight_quantizer is None:
+                    # Try to get from first weight attribute
+                    for wn in weight_names:
+                        weight_quantizer = getattr(
+                            module, quantizer_attr_names(wn).weight_quantizer, None
+                        )
+                        if weight_quantizer is not None:
+                            break
+                if weight_quantizer is not None:
+                    is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+                    if is_static:
+                        quantization_format = "nvfp4_static"
+
             # Construct per layer config dictionary
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size
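Tying this hunk to the process_layer_quant_config change above: a hypothetical module whose weight quantizer is an NVFP4StaticQuantizer would yield entries like these (the layer name and block size are invented):

```python
# Hypothetical entries emitted by get_quant_config for a static-NVFP4 layer:
layer_config_dict = {
    "model.layers.0.mlp.down_proj.quantization": "nvfp4_static",
    "model.layers.0.mlp.down_proj.awq_block_size": 16,
}
# process_layer_quant_config then maps "nvfp4_static" to the same
# {"quant_algo": "NVFP4", "group_size": 16, ...} layer config as plain "nvfp4",
# so downstream consumers see an ordinary NVFP4 layer.
```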
@@ -50,7 +50,11 @@
 from torch.distributed.fsdp import FSDPModule

 from modelopt.torch.quantization import set_quantizer_by_cfg_context
-from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
+from modelopt.torch.quantization.nn import (
+    NVFP4StaticQuantizer,
+    SequentialQuantizer,
+    TensorQuantizer,
+)
 from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor
 from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
@@ -502,11 +506,18 @@ def _export_quantized_weight(
         weight, _ = maybe_transpose_expert_weight_dimensions(
             weight, is_bmm_expert_weight=is_bmm_expert_weight
         )
-        weight_scale = NVFP4QTensor.get_weights_scaling_factor(
-            weight,
-            block_size=block_size,
-            weights_scaling_factor_2=weight_scale_2,
-        )[0]
+
+        # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
+        # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
+        is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+
+        if not is_nvfp4_static:
|
Collaborator: do we need to handle the else condition?
+            # For dynamic NVFP4, compute scales from weights
+            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                block_size=block_size,
+                weights_scaling_factor_2=weight_scale_2,
+            )[0]
Comment on lines +509 to +520

Contributor: Potential shape mismatch for BMM-style expert weights with static NVFP4. When is_bmm_expert_weight is true, the weight is transposed before the per-block scales are computed here. Since the static path skips that recomputation and reuses the weight_scale derived from the untransposed weight, this would fail with a shape error if static NVFP4 quantizers are ever used with Llama4TextExperts or GptOssExperts. If that combination is currently not expected, a guard would prevent a confusing error later.

Proposed guard:

     # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
     # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
     is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+    if is_nvfp4_static and is_bmm_expert_weight:
+        raise NotImplementedError(
+            "Static NVFP4 quantization is not yet supported for BMM-style expert weights "
+            "(Llama4TextExperts, GptOssExperts). Use dynamic NVFP4 quantization instead."
+        )
+
     if not is_nvfp4_static:

Collaborator: @Fridah-nv could you fix this?
         quantized_weight = to_quantized_weight(
             weight.to(dtype),
@@ -388,6 +388,29 @@
     "algorithm": "max",
 }

+NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        "fp8_scale_sweep": True,
+    },
+}
Comment on lines +391 to +411

Contributor: Verification (rg/sed checks against modelopt/torch/quantization/config.py): the file defines NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG at lines 391-411 alongside the other NVFP4_*_CFG definitions, and a module-level choices set (around lines 645-680) that enumerates the recognized quantization configs and is referenced for validation elsewhere in modelopt/torch/quantization. Nothing in the surrounding comments marks the new config as experimental or intentionally excluded.

Add NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG to the choices set. The new config is not registered in the choices set, so it should either be added next to the other NVFP4 configs or explicitly documented as excluded.
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
@@ -52,12 +52,87 @@ def get_e2m1_bounds(cls, device):
         cls.e2m1_bounds_on_device[device] = e2m1_bounds.to(device)
         return cls.e2m1_bounds_on_device[device]

+    @classmethod
+    def _is_static_quantizer(cls, weight_quantizer) -> bool:
+        """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
+        return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None
+
     @classmethod
     def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
|
Collaborator: could you add a unittest?
| """Returns per tensor weight scaling factor from the weight_quantizer amax.""" | ||
| # Assert that weight_quantizer has attribute amax | ||
| assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax" | ||
| return weight_quantizer._amax.float() / (6.0 * 448.0) | ||
| """Returns per tensor weight scaling factor from the weight_quantizer. | ||
|
|
||
| Handles both static NVFP4 quantizers (using global_amax) and | ||
| dynamic quantizers (using _amax). | ||
|
|
||
| Args: | ||
| weight_quantizer: The weight quantizer (static or dynamic). | ||
|
|
||
| Returns: | ||
| The global scaling factor as a float tensor. | ||
| """ | ||
| if cls._is_static_quantizer(weight_quantizer): | ||
| return weight_quantizer.global_amax.float() / (6.0 * 448.0) | ||
| else: | ||
| assert hasattr(weight_quantizer, "_amax"), ( | ||
| "Weight quantizer does not have attribute amax" | ||
| ) | ||
| return weight_quantizer._amax.float() / (6.0 * 448.0) | ||
|
|
||
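The 6.0 * 448.0 divisor combines the largest FP4 (E2M1) magnitude, 6, with the largest float8_e4m3fn value, 448. A quick check, with the amax chosen so the result is exactly 1:

```python
import torch

amax = torch.tensor(2688.0)           # assumed global amax; 2688 = 6 * 448
wsf_2 = amax.float() / (6.0 * 448.0)  # per-tensor scale, as in the method above
print(wsf_2)                          # tensor(1.)
# Interpretation: an FP4 payload (|x| <= 6) times an FP8 block scale (<= 448)
# times wsf_2 reaches magnitudes up to amax, and no further.
```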
+    @classmethod
+    def get_weights_scaling_factor_from_quantizer(
+        cls,
+        weight_quantizer,
+        weight: torch.Tensor,
+        weights_scaling_factor_2: torch.Tensor | None = None,
+        keep_high_precision: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Returns quantized per block weight scaling factor from quantizer.
+
+        Handles both static NVFP4 quantizers (with pre-computed per-block amax)
+        and dynamic quantizers (computing from weight tensor).
+
+        Args:
+            weight_quantizer: The weight quantizer (static or dynamic).
+            weight: The weight tensor (used for shape in static, values in dynamic).
+            weights_scaling_factor_2: Optional pre-computed global scale.
+            keep_high_precision: Whether to keep scales in high precision.
+
+        Returns:
+            Tuple of (per_block_scale, weights_scaling_factor_2).
+        """
+        block_size = weight_quantizer.block_sizes[-1]
+
+        if weights_scaling_factor_2 is None:
+            weights_scaling_factor_2 = cls.get_weights_scaling_factor_2_from_quantizer(
+                weight_quantizer
+            )
+
+        if cls._is_static_quantizer(weight_quantizer):
|
Collaborator: so in the static NVFP4 quant case you use _amax for the per-block amax and global_amax for the per-tensor amax?
+            # Static path: use pre-computed per-block amax values from quantizer
+            global_amax = weight_quantizer.global_amax.float()
+            per_block_amax = weight_quantizer._amax.float()
+
+            # Compute scales in float
+            per_block_scale_max = global_amax / 6.0
+            per_block_scale = per_block_amax / 6.0
+            per_block_scale[per_block_scale == 0] = 1.0
+
+            # Reshape per_block_scale to match weight's block structure
+            num_blocks_per_row = weight.shape[-1] // block_size
+            expected_shape = (*weight.shape[:-1], num_blocks_per_row)
+            per_block_scale = per_block_scale.view(expected_shape)
+
+            # Quantize scales to FP8
+            if not keep_high_precision:
+                per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
+                    torch.float8_e4m3fn
+                )
+            return per_block_scale, weights_scaling_factor_2
+        else:
+            # Dynamic path: compute from weight tensor
+            return cls.get_weights_scaling_factor(
+                weight, block_size, weights_scaling_factor_2, keep_high_precision
+            )
+
     @classmethod
     def get_weights_scaling_factor(
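A self-contained walk-through of the static path's arithmetic, with invented values (a 2x32 weight and block_size 16, so four blocks):

```python
import torch

global_amax = torch.tensor(6.0)                      # assumed per-tensor amax
per_block_amax = torch.tensor([3.0, 6.0, 1.5, 0.0])  # assumed per-block amax

per_block_scale_max = global_amax / 6.0              # 1.0
per_block_scale = per_block_amax / 6.0               # [0.50, 1.00, 0.25, 0.00]
per_block_scale[per_block_scale == 0] = 1.0          # zero blocks get scale 1

per_block_scale = per_block_scale.view(2, 32 // 16)  # (2, 2) block grid for a 2x32 weight

fp8_scale = (per_block_scale * 448.0 / per_block_scale_max).to(torch.float8_e4m3fn)
print(fp8_scale.float())  # [[224., 448.], [112., 448.]]
```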
@@ -67,7 +142,11 @@ def get_weights_scaling_factor(
         weights_scaling_factor_2: torch.Tensor | None = None,
         keep_high_precision: bool = False,
     ):
-        """Returns quantized per block weight scaling factor."""
+        """Returns quantized per block weight scaling factor from weight tensor.
+
+        This is the dynamic path that computes scales directly from the weight values.
+        For quantizers with pre-computed amax, use get_weights_scaling_factor_from_quantizer.
+        """
         if weights_scaling_factor_2 is None:
             weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input)
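For contrast, the dynamic path can be exercised directly on a tensor. The call below matches the signature used elsewhere in this diff; the printed shape is what the block layout implies, while exact values depend on the random weight:

```python
import torch
from modelopt.torch.quantization.qtensor import NVFP4QTensor

weight = torch.randn(2, 32)
per_block_scale, wsf_2 = NVFP4QTensor.get_weights_scaling_factor(weight, block_size=16)
print(per_block_scale.shape)  # torch.Size([2, 2]); one FP8 scale per 16-wide block
print(wsf_2)                  # per-tensor global scale derived from the weight's amax
```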
Review comment: what's the quantization_format for NVFP4StaticQuantizer? do we need is_nvfp4_static here?