159 changes: 69 additions & 90 deletions modelopt/torch/quantization/calib/mse.py
@@ -24,7 +24,7 @@
from .. import utils as quant_utils
from .calibrator import _Calibrator

__all__ = ["MseCalibrator"]
__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]


class MseCalibrator(_Calibrator):
@@ -39,7 +39,6 @@ def __init__(
stop_multiplier: float = 4.0,
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
fp8_scale_sweep: bool = False,
):
"""Initialize MSE calibrator.

@@ -54,9 +53,6 @@ def __init__(
Should have signature: quant_func(x, amax) -> quantized_x.
error_func: Function to compute error between x and xq.
Default is F.mse_loss(x, xq, reduction='none').
fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
instead of using multipliers. This is specifically for NVFP4
per-block quantization where scales are stored in FP8 format.
"""
super().__init__(num_bits=None, axis=axis, unsigned=None)
self._initial_amax = amax
@@ -67,17 +63,21 @@

self._quant_func = quant_func
self._error_func = error_func
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
self._fp8_scale_sweep = fp8_scale_sweep
if fp8_scale_sweep:
# For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
# (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
self._num_steps = 126
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps

self._amax = None
self._losses_sum: list[torch.Tensor | None] | None = None
self._candidates: torch.Tensor | None = None
self._amax: torch.Tensor | None = None

def _generate_candidates(self, device: torch.device) -> torch.Tensor:
"""Generate candidate multipliers. Override in subclasses for different candidate sets."""
return torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)

def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
"""Compute amax from candidates. Override in subclasses for different amax computation."""
if candidates.ndim != 0: # Called during final compute amax
candidates = candidates.view_as(self._initial_amax)
return self._initial_amax * candidates

@torch.no_grad()
def collect(self, x: torch.Tensor):
@@ -87,39 +87,22 @@ def collect(self, x: torch.Tensor):
x: Input tensor.
"""
if self._quant_func is None:
raise RuntimeError(
"Quantization function not set. Msecalibrator requires a quant_func to be provided."
)
raise RuntimeError("Quantization function not set.")

x = x.detach().to(dtype=torch.float32)

device = x.device

if self._fp8_scale_sweep:
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)

# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()

# Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values_valid = fp8_values[valid_mask]
candidates = self._generate_candidates(device)
if self._candidates is None:
self._candidates = candidates
self._num_steps = len(candidates)
self._losses_sum = [None] * self._num_steps

candidates = fp8_values_valid / 448.0
else:
candidates = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)
# Get reduce axis for per-channel quantization
assert self._losses_sum is not None
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

for step, candidate in enumerate(candidates):
if self._fp8_scale_sweep:
candidate_amax = (global_amax * candidate) * torch.ones_like(self._initial_amax)
else:
candidate_amax = self._initial_amax * candidate
candidate_amax = self._compute_candidate_amax(candidate)
xq = self._quant_func(x, candidate_amax)

if self._error_func is not None:
@@ -129,28 +112,16 @@

loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)

if self._candidate_amaxs[step] is None:
self._candidate_amaxs[step] = candidate_amax

if self._losses_sum[step] is None:
self._losses_sum[step] = loss.clone()
else:
self._losses_sum[step] += loss

def reset(self):
"""Reset the stored losses and amax value."""
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
self._losses_sum = None
self._candidates = None
self._amax = None

def clear(self):
"""Clear all cached data to free GPU memory.

Call this after compute_amax() and load_calib_amax() are done.
"""
self._losses_sum = []
self._candidate_amaxs = []

if self._initial_amax is not None:
del self._initial_amax
self._initial_amax = None
@@ -162,49 +133,28 @@ def compute_amax(self, verbose: bool = False):
Args:
verbose: If True, print the ratio of best_amax to initial_amax.
"""
if not any(loss_sum is not None for loss_sum in self._losses_sum):
if self._losses_sum is None or not any(loss is not None for loss in self._losses_sum):
return None

# Check if this is per-tensor or per-channel based on the first loss
first_loss_sum = None
for loss_sum in self._losses_sum:
if loss_sum is not None:
first_loss_sum = loss_sum
break

if first_loss_sum is None:
first_loss = next((loss for loss in self._losses_sum if loss is not None), None)
if first_loss is None:
return None

# Collect losses for all steps
losses_per_step = []
# Stack losses: [num_steps] or [num_steps, num_channels]
losses = []
for step in range(self._num_steps):
if self._losses_sum[step] is not None:
losses_per_step.append(self._losses_sum[step])
# No data for this step, use inf
elif first_loss_sum.ndim == 0:
losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device))
losses.append(self._losses_sum[step])
elif first_loss.ndim == 0:
losses.append(torch.tensor(float("inf"), device=first_loss.device))
else:
losses_per_step.append(torch.full_like(first_loss_sum, float("inf")))

# Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel
losses_per_step = torch.stack(losses_per_step)
losses.append(torch.full_like(first_loss, float("inf")))

# Find best step(s): scalar for per-tensor, [num_channels] for per-channel
best_steps = torch.argmin(losses_per_step, dim=0)

# Stack candidate amaxs and select based on best_steps
candidate_amaxs = torch.stack(self._candidate_amaxs)

if first_loss_sum.ndim == 0:
# Per-tensor case: best_steps is a scalar
self._amax = self._candidate_amaxs[best_steps.item()]
else:
# Per-channel case: best_steps is a tensor
num_channels = best_steps.shape[0]
self._amax = candidate_amaxs[
best_steps, torch.arange(num_channels, device=best_steps.device)
]
self._amax = self._amax.reshape(self._initial_amax.shape)
losses = torch.stack(losses)
best_indices = torch.argmin(losses, dim=0)
assert self._candidates is not None
best_candidates = self._candidates[best_indices]
self._amax = self._compute_candidate_amax(best_candidates)

if verbose:
ratio = self._amax / self._initial_amax
@@ -219,3 +169,32 @@
)

return self._amax
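
A minimal usage sketch of the refactored base class (the quant_func below is illustrative, not the library's kernel; any callable with the signature quant_func(x, amax) -> xq works, and calibration_batches is a placeholder iterable of tensors):

    import torch
    from modelopt.torch.quantization.calib.mse import MseCalibrator

    def fake_quant(x, amax):
        # Hypothetical symmetric int8-style fake-quant, for illustration only.
        scale = amax / 127.0
        return (x / scale).round().clamp(-127, 127) * scale

    calib = MseCalibrator(amax=torch.tensor(1.0), quant_func=fake_quant)
    for batch in calibration_batches:
        calib.collect(batch)
    best_amax = calib.compute_amax(verbose=True)
    calib.clear()  # free cached losses once the amax has been consumed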


class NVFP4MSECalibrator(MseCalibrator):
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""

def __init__(
self,
amax: torch.Tensor, # per_block_amax shape [num_blocks]
global_amax: torch.Tensor, # scalar
axis: int | tuple | list | None = None,
quant_func: Callable | None = None,
error_func: Callable | None = None,
):
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
self._global_amax = global_amax

def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
if candidates.ndim != 0: # Called during final compute amax
candidates = candidates.view_as(self._initial_amax)
return torch.ones_like(self._initial_amax) * self._global_amax * candidates

def _generate_candidates(self, device: torch.device) -> torch.Tensor:
"""Generate 126 valid FP8 E4M3 scale candidates."""
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values = fp8_values[valid_mask]
return fp8_values / 448.0
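
The 126-candidate count in _generate_candidates follows from the FP8 E4M3FN encoding: bytes 0-127 cover every non-negative code point, byte 0 decodes to zero, byte 127 to NaN (E4M3FN has no infinities), and the largest finite value is 448. A standalone check of that enumeration:

    import torch

    codes = torch.arange(0, 128, dtype=torch.uint8)
    vals = codes.view(torch.float8_e4m3fn).float()
    valid = vals[torch.isfinite(vals) & (vals > 0)]
    assert valid.numel() == 126 and valid.max().item() == 448.0
    candidates = valid / 448.0  # multipliers in (0, 1], applied to the global amax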
16 changes: 16 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -1190,6 +1190,22 @@ class SVDQuantConfig(QuantizeAlgorithmConfig):
)


class ScaleAfterDequantConfig(QuantizeAlgorithmConfig):
"""Config for scale-after-dequant algorithm.

Runs MSE+FP8 scale sweep calibration, then converts NVFP4 quantizers to
learnable-amax mode for fine-tuning.
"""

method: Literal["scale_after_dequant"] = ModeloptField("scale_after_dequant")

scale_algorithm: dict | None = ModeloptField(
default=None,
title="Scale calibration algorithm to run first.",
description="Must be {'method': 'mse', 'fp8_scale_sweep': True}. Defaults to that if None.",
)


class GPTQLiteConfig(QuantizeAlgorithmConfig):
"""The config for GPTQ lite.

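A hedged sketch of how the new config is expected to be driven; the mtq.quantize entry point and the NVFP4 base config name are assumptions, not part of this diff:

    import modelopt.torch.quantization as mtq

    config = dict(mtq.NVFP4_DEFAULT_CFG)  # assumed NVFP4 base config
    config["algorithm"] = {"method": "scale_after_dequant"}
    # scale_algorithm is left unset, so it falls back to the MSE + FP8 scale sweep.
    model = mtq.quantize(model, config, forward_loop=run_calibration)  # run_calibration: user-supplied loop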
7 changes: 7 additions & 0 deletions modelopt/torch/quantization/conversion.py
@@ -35,6 +35,7 @@
_QuantizeExportConfig,
)
from .nn import (
NVFP4StaticQuantizer,
QuantModule,
QuantModuleRegistry,
SequentialQuantizer,
@@ -125,6 +126,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
for name, module in model.named_modules():
if isinstance(module, TensorQuantizer):
name = get_unwrapped_name(name, model)
state = quantizer_state_dict[name]
# TODO: Add a registry for TensorQuantizers and avoid this manual conversion.
if state.get("_is_nvfp4_static_quantizer") and not isinstance(
module, NVFP4StaticQuantizer
):
NVFP4StaticQuantizer.from_tensor_quantizer(module)
module.set_from_modelopt_state(quantizer_state_dict[name])

for name, module in model.named_modules():
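The call above discards the return value, which suggests from_tensor_quantizer converts the module in place. A sketch of that pattern, assuming (this diff does not show it) the classmethod simply re-types the instance:

    @classmethod
    def from_tensor_quantizer(cls, quantizer):
        # Re-type the existing module in place so the subsequent
        # set_from_modelopt_state call lands on an NVFP4StaticQuantizer;
        # existing attributes and buffers are preserved.
        quantizer.__class__ = cls
        return quantizer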
23 changes: 22 additions & 1 deletion modelopt/torch/quantization/mode.py
@@ -43,6 +43,7 @@
QuantizeAlgoCfgType,
QuantizeAlgorithmConfig,
QuantizeConfig,
ScaleAfterDequantConfig,
SmoothQuantCalibConfig,
SVDQuantConfig,
_QuantizeExportConfig,
@@ -56,7 +57,15 @@
restore_svdquant_model,
update_quantize_metadata,
)
from .model_calib import awq, gptq_lite, max_calibrate, mse_calibrate, smoothquant, svdquant
from .model_calib import (
awq,
gptq_lite,
max_calibrate,
mse_calibrate,
scale_after_dequant,
smoothquant,
svdquant,
)

__all__ = ["BaseCalibrateModeDescriptor"]

@@ -452,3 +461,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
return GPTQLiteConfig

_calib_func = gptq_lite


@CalibrateModeRegistry.register_mode
class ScaleAfterDequantModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for scale-after-dequant algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
"""Specifies the config class for the mode."""
return ScaleAfterDequantConfig

_calib_func = scale_after_dequant