1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
- Add support for image-text data calibration in PTQ for Nemotron VL models.
- Add PTQ support for Nemotron Parse.
- Add support for rotating the input before quantization for RHT.
Reviewer comment: Is this a 0.42 feature, or should we create 0.43?


0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
14 changes: 10 additions & 4 deletions modelopt/torch/quantization/config.py
@@ -934,14 +934,20 @@ def validate_calibrator(cls, v, info: ValidationInfo):
assert v in ["max", "histogram"]
return v

rotate: bool = ModeloptField(
rotate: bool | dict[str, bool] = ModeloptField(
default=False,
title="""If rotate the input before quantization.""",
description=""""If true, the input of the quantizer will be rotated with a hadamard matrix
title="""Configuration for rotating the input before quantization.""",
description="""Can be a boolean or a dictionary with the following keys:
- "enable": Boolean to enable/disable rotation (default: False)
- "rotate_fp32": Boolean to compute rotation in float32 precision (default: False)

If a boolean is provided, it is treated as the "enable" value with "rotate_fp32" defaulting to False.

When enabled, the input of the quantizer will be rotated with a hadamard matrix
given by scipy.linalg.hadamard, i.e.
``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``.

This can be used for ratation based PTQ methods, e.g. QuaRot or SpinQuant.
This can be used for rotation based PTQ methods, e.g. QuaRot or SpinQuant.
See https://arxiv.org/abs/2404.00456 for example.""",
)
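
For context (not part of the diff), here is a minimal sketch of the two accepted forms of the new `rotate` field in a quantizer config; `FP8_DEFAULT_CFG` and the `"*input_quantizer"` pattern are used only as assumed examples of an existing config and wildcard:

```python
import copy

import modelopt.torch.quantization as mtq

# Start from an existing quantization config (FP8 chosen only for illustration).
cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)

# Boolean form: shorthand for {"enable": True, "rotate_fp32": False}.
cfg["quant_cfg"]["*input_quantizer"]["rotate"] = True

# Dict form: enable rotation and compute the Hadamard transform in float32.
cfg["quant_cfg"]["*input_quantizer"]["rotate"] = {"enable": True, "rotate_fp32": True}
```

The model would then be quantized as usual, e.g. with `mtq.quantize(model, cfg, forward_loop)`.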

8 changes: 6 additions & 2 deletions modelopt/torch/quantization/nn/functional.py
@@ -93,7 +93,7 @@ def backward(ctx, grad_outputs):
return fast_hadamard_transform.hadamard_transform(grad_outputs) # type: ignore[name-defined]


def normalized_hadamard_transform(inputs):
def normalized_hadamard_transform(inputs, rotate_fp32=False):
"""Normalized fast hadamard transform."""
global fast_hadamard_transform
try:
@@ -104,6 +104,10 @@ def normalized_hadamard_transform(inputs):
"`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`"
)

return FastHadamardTransform.apply(inputs) / torch.sqrt(
dtype = inputs.dtype

Note: When the input is FP16/BF16, the `sqrt` denominator is computed in `torch.float32`, but the division happens in the input dtype. This is existing behavior, but there's a subtle precision difference: in FP32 mode the division is in FP32; otherwise it's in the input dtype. Worth documenting if this is intentional.

if rotate_fp32:
inputs = inputs.to(torch.float32)
outputs = FastHadamardTransform.apply(inputs) / torch.sqrt(
torch.tensor(inputs.shape[-1], dtype=torch.float32)
)
return outputs.to(dtype) if rotate_fp32 else outputs
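
The dtype behavior described in the comment above can be reproduced with a small standalone snippet (illustrative only, not part of the change): dividing a dimensioned 16-bit tensor by a zero-dim FP32 tensor keeps the 16-bit dtype, so only the `rotate_fp32` path performs the division in FP32.

```python
import torch

x16 = torch.randn(4, 8, dtype=torch.float16)
scale = torch.sqrt(torch.tensor(x16.shape[-1], dtype=torch.float32))

# A dimensioned fp16 tensor divided by a zero-dim fp32 tensor stays fp16.
print((x16 / scale).dtype)                    # torch.float16
# Casting the input first (the rotate_fp32 path) makes the division fp32.
print((x16.to(torch.float32) / scale).dtype)  # torch.float32
```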
21 changes: 18 additions & 3 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -525,6 +525,20 @@ def is_static_block_quant(self):
and self._fake_quant
)

@property
def rotate_is_enabled(self):
"""Check if rotate is enabled in quant config."""
return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate

@property

Nit: The `rotate_is_enabled` check inside this property looks redundant — `rotate_fp32` only matters when rotation is actually applied, and `self._rotate.get("rotate_fp32", False)` will correctly return False anyway when the key is not set. Consider simplifying to just `self._rotate.get("rotate_fp32", False) if isinstance(self._rotate, dict) else False`.

def rotate_is_fp32(self):
"""Check if rotation needs to be computed in float32."""
return (
self._rotate.get("rotate_fp32", False)
if isinstance(self._rotate, dict) and self.rotate_is_enabled
else False
)
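
A sketch of the simplification suggested in the nit above, assuming the `rotate_is_enabled` guard can indeed be dropped because `rotate_fp32` only takes effect when rotation is applied in `forward`:

```python
@property
def rotate_is_fp32(self):
    """Check if rotation needs to be computed in float32."""
    return self._rotate.get("rotate_fp32", False) if isinstance(self._rotate, dict) else False
```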

def disable_calib(self):
"""Disable calibration."""
self._if_calib = False
@@ -992,8 +1006,8 @@ def forward(self, inputs):
inputs = inputs * self.pre_quant_scale

# Rotating the input
if self._rotate:
inputs = normalized_hadamard_transform(inputs)
if self.rotate_is_enabled:
inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32)

if self._disabled:
# if quantizer is disabled, we still need to track the input dtype for saving the model
@@ -1105,7 +1119,8 @@ def extra_repr(self):
if self.pre_quant_scale is not None
else ""
)
s += " rotated" if self._rotate else ""
s += " rotated" if self.rotate_is_enabled else ""
s += " (fp32)" if self.rotate_is_fp32 else ""
s += (
f" calibrator={self._calibrator.__class__.__name__}"
if (self._calibrator is not None)
17 changes: 14 additions & 3 deletions tests/gpu/torch/quantization/test_hadamard.py
@@ -41,21 +41,32 @@ def test_hadamard_transform(dim):
xxt_h = x_h @ x_h.T
# The numerical error can be large, especially for 16-bit floats.
assert torch.allclose(xxt_h, xxt, atol=0.05)
x_h_fp32 = normalized_hadamard_transform(x, rotate_fp32=True)

Consider whether a tighter `atol` should be used for FP32 mode, since it should have better numerical precision than 16-bit floats.

xxt_h_fp32 = x_h_fp32 @ x_h_fp32.T
assert torch.allclose(xxt_h_fp32, xxt, atol=0.05)
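
Along the lines of the review comment above, a hypothetical tighter check for the FP32 path could compare against an explicit normalized Hadamard matrix built from `scipy.linalg.hadamard` (reusing `x`, `dim`, and `x_h_fp32` from the test); the tolerance below is illustrative and would need validation:

```python
import scipy.linalg

# Exact normalized Hadamard matrix (dim must be a power of two).
h = torch.tensor(
    scipy.linalg.hadamard(dim) / dim**0.5, dtype=torch.float32, device=x.device
)
ref = (x.float() @ h).to(x.dtype)
# Tighter than the 0.05 used above for the 16-bit path (value is illustrative).
assert torch.allclose(x_h_fp32, ref, atol=1e-2)
```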


def test_kv_rotate():
@pytest.mark.parametrize(
"rotate_fp32",
[True, False],
)
def test_kv_rotate(rotate_fp32):
mtq.plugins.register_attention_for_kv_quant(SDPAAttention)
model = nn.Sequential(SDPAAttention())
mtq.replace_quant_module(model)

set_quantizer_by_cfg(model, {"*": {"enable": False}})
dummy_input = SDPAAttention.get_input(device="cuda")
output_ref = model(dummy_input)
if rotate_fp32:
rotate = {"enable": True, "rotate_fp32": True}
else:
rotate = True
with set_quantizer_by_cfg_context(
model,
{
"*[qk]_bmm_quantizer": {
"rotate": True,
"rotate": rotate,
},
},
):
@@ -67,7 +78,7 @@ def test_kv_rotate():
model,
{
"*k_bmm_quantizer": {
"rotate": True,
"rotate": rotate,
},
},
):