diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index bbbe6ab9e..2064dc66e 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
 - Add support for image-text data calibration in PTQ for Nemotron VL models.
 - Add PTQ support for Nemotron Parse.
+- Add support for rotating the input before quantization for RHT.
 
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index e1b48ee60..ada417375 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -934,14 +934,20 @@ def validate_calibrator(cls, v, info: ValidationInfo):
         assert v in ["max", "histogram"]
         return v
 
-    rotate: bool = ModeloptField(
+    rotate: bool | dict[str, bool] = ModeloptField(
         default=False,
-        title="""If rotate the input before quantization.""",
-        description=""""If true, the input of the quantizer will be rotated with a hadamard matrix
+        title="""Configuration for rotating the input before quantization.""",
+        description="""Can be a boolean or a dictionary with the following keys:
+        - "enable": Boolean to enable/disable rotation (default: False)
+        - "rotate_fp32": Boolean to compute rotation in float32 precision (default: False)
+
+        If a boolean is provided, it is treated as the "enable" value with "rotate_fp32" defaulting to False.
+
+        When enabled, the input of the quantizer will be rotated with a hadamard matrix
         given by scipy.linalg.hadamard, i.e.
         ``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``.
-        This can be used for ratation based PTQ methods, e.g. QuaRot or SpinQuant.
+        This can be used for rotation based PTQ methods, e.g. QuaRot or SpinQuant.
         See https://arxiv.org/abs/2404.00456 for example.""",
     )
diff --git a/modelopt/torch/quantization/nn/functional.py b/modelopt/torch/quantization/nn/functional.py
index df8bcbbcd..0beb7c956 100644
--- a/modelopt/torch/quantization/nn/functional.py
+++ b/modelopt/torch/quantization/nn/functional.py
@@ -93,7 +93,7 @@ def backward(ctx, grad_outputs):
         return fast_hadamard_transform.hadamard_transform(grad_outputs)  # type: ignore[name-defined]
 
 
-def normalized_hadamard_transform(inputs):
+def normalized_hadamard_transform(inputs, rotate_fp32=False):
     """Normalized fast hadamard transform."""
     global fast_hadamard_transform
     try:
@@ -104,6 +104,10 @@ def normalized_hadamard_transform(inputs):
             "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`"
         )
 
-    return FastHadamardTransform.apply(inputs) / torch.sqrt(
+    dtype = inputs.dtype
+    if rotate_fp32:
+        inputs = inputs.to(torch.float32)
+    outputs = FastHadamardTransform.apply(inputs) / torch.sqrt(
         torch.tensor(inputs.shape[-1], dtype=torch.float32)
     )
+    return outputs.to(dtype) if rotate_fp32 else outputs
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 3852d1144..937714f52 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -525,6 +525,20 @@ def is_static_block_quant(self):
             and self._fake_quant
         )
 
+    @property
+    def rotate_is_enabled(self):
+        """Check if rotate is enabled in quant config."""
+        return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
+
+    @property
+    def rotate_is_fp32(self):
+        """Check if rotation needs to be computed in float32."""
+        return (
+            self._rotate.get("rotate_fp32", False)
+            if isinstance(self._rotate, dict) and self.rotate_is_enabled
+            else False
+        )
+
     def disable_calib(self):
         """Disable calibration."""
         self._if_calib = False
@@ -992,8 +1006,8 @@ def forward(self, inputs):
             inputs = inputs * self.pre_quant_scale
 
         # Rotating the input
-        if self._rotate:
-            inputs = normalized_hadamard_transform(inputs)
+        if self.rotate_is_enabled:
+            inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32)
 
         if self._disabled:
             # if quantizer is disabled, we still need to track the input dtype for saving the model
@@ -1105,7 +1119,8 @@ def extra_repr(self):
             if self.pre_quant_scale is not None
             else ""
         )
-        s += " rotated" if self._rotate else ""
+        s += " rotated" if self.rotate_is_enabled else ""
+        s += " (fp32)" if self.rotate_is_fp32 else ""
         s += (
             f" calibrator={self._calibrator.__class__.__name__}"
             if (self._calibrator is not None)
diff --git a/tests/gpu/torch/quantization/test_hadamard.py b/tests/gpu/torch/quantization/test_hadamard.py
index c768bc87e..64dd39e2c 100644
--- a/tests/gpu/torch/quantization/test_hadamard.py
+++ b/tests/gpu/torch/quantization/test_hadamard.py
@@ -41,9 +41,16 @@ def test_hadamard_transform(dim):
     xxt_h = x_h @ x_h.T
     # The numerical error can be large, especially for 16-bit floats.
     assert torch.allclose(xxt_h, xxt, atol=0.05)
+    x_h_fp32 = normalized_hadamard_transform(x, rotate_fp32=True)
+    xxt_h_fp32 = x_h_fp32 @ x_h_fp32.T
+    assert torch.allclose(xxt_h_fp32, xxt, atol=0.05)
 
 
-def test_kv_rotate():
+@pytest.mark.parametrize(
+    "rotate_fp32",
+    [True, False],
+)
+def test_kv_rotate(rotate_fp32):
     mtq.plugins.register_attention_for_kv_quant(SDPAAttention)
     model = nn.Sequential(SDPAAttention())
     mtq.replace_quant_module(model)
@@ -51,11 +58,15 @@
     set_quantizer_by_cfg(model, {"*": {"enable": False}})
     dummy_input = SDPAAttention.get_input(device="cuda")
     output_ref = model(dummy_input)
+    if rotate_fp32:
+        rotate = {"enable": True, "rotate_fp32": True}
+    else:
+        rotate = True
     with set_quantizer_by_cfg_context(
         model,
         {
             "*[qk]_bmm_quantizer": {
-                "rotate": True,
+                "rotate": rotate,
             },
         },
     ):
@@ -67,7 +78,7 @@
         model,
         {
             "*k_bmm_quantizer": {
-                "rotate": True,
+                "rotate": rotate,
             },
         },
     ):
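
Usage sketch (not part of the diff above): one way the new "rotate" field could be supplied through a standard modelopt quantization config dict, mirroring the dict form exercised in test_kv_rotate. The quantizer patterns, bit settings, and the model/forward_loop placeholders are illustrative assumptions, not anything mandated by this change; rotation still requires the fast-hadamard-transform package referenced in functional.py.

import modelopt.torch.quantization as mtq

# Illustrative quant config: wildcard patterns and bit settings are assumptions.
cfg = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {
            "num_bits": 8,
            "axis": None,
            # Boolean form (existing behavior): "rotate": True rotates the quantizer
            # input with a normalized Hadamard transform in the input's own dtype.
            # Dict form (added by this change): compute the transform in float32 and
            # cast back to the input dtype afterwards.
            "rotate": {"enable": True, "rotate_fp32": True},
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}

# model = mtq.quantize(model, cfg, forward_loop)  # hypothetical model / forward_loop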