Added support to rotate in fp32 (optional) #885
```diff
@@ -93,7 +93,7 @@ def backward(ctx, grad_outputs):
         return fast_hadamard_transform.hadamard_transform(grad_outputs)  # type: ignore[name-defined]
 
 
-def normalized_hadamard_transform(inputs):
+def normalized_hadamard_transform(inputs, rotate_fp32=False):
     """Normalized fast hadamard transform."""
     global fast_hadamard_transform
     try:
@@ -104,6 +104,10 @@ def normalized_hadamard_transform(inputs):
             "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`"
         )
 
-    return FastHadamardTransform.apply(inputs) / torch.sqrt(
+    dtype = inputs.dtype
+    if rotate_fp32:
+        inputs = inputs.to(torch.float32)
+    outputs = FastHadamardTransform.apply(inputs) / torch.sqrt(
         torch.tensor(inputs.shape[-1], dtype=torch.float32)
     )
+    return outputs.to(dtype) if rotate_fp32 else outputs
```

Review comment: Note: when the input is FP16/BF16, the division happens in the input dtype. This is existing behavior, but there is a subtle precision difference: in FP32 mode the division is in FP32, otherwise it is in the input dtype. Worth documenting if this is intentional.
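To make the dtype handling concrete, here is a minimal pure-PyTorch sketch of the new code path, with a naive Sylvester-construction matrix standing in for the `fast_hadamard_transform` CUDA kernel. The `_hadamard_matrix` helper and the `normalized_hadamard_transform_ref` name are illustrative only, not part of the PR:

```python
import torch


def _hadamard_matrix(n: int, device, dtype) -> torch.Tensor:
    """Naive Sylvester construction of an n x n Hadamard matrix (n must be a power of two)."""
    h = torch.ones(1, 1, device=device, dtype=dtype)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h


def normalized_hadamard_transform_ref(inputs: torch.Tensor, rotate_fp32: bool = False) -> torch.Tensor:
    """Mirror of the patched function: optionally rotate in FP32, then cast back to the input dtype."""
    dtype = inputs.dtype
    if rotate_fp32:
        inputs = inputs.to(torch.float32)
    hadamard = _hadamard_matrix(inputs.shape[-1], inputs.device, inputs.dtype)
    outputs = (inputs @ hadamard) / torch.sqrt(
        torch.tensor(inputs.shape[-1], dtype=torch.float32)
    )
    return outputs.to(dtype) if rotate_fp32 else outputs


x = torch.randn(4, 64, dtype=torch.float32)
assert normalized_hadamard_transform_ref(x, rotate_fp32=True).dtype == x.dtype
```

Because the cast back happens only at the very end, the FP16/BF16 rounding is applied once to the final result rather than inside the transform and the division, which is the precision difference the review note above refers to.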
```diff
@@ -525,6 +525,20 @@ def is_static_block_quant(self):
             and self._fake_quant
         )
 
+    @property
+    def rotate_is_enabled(self):
+        """Check if rotate is enabled in quant config."""
+        return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
+
+    @property
+    def rotate_is_fp32(self):
+        """Check if rotation needs to be computed in float32."""
+        return (
+            self._rotate.get("rotate_fp32", False)
+            if isinstance(self._rotate, dict) and self.rotate_is_enabled
+            else False
+        )
+
     def disable_calib(self):
         """Disable calibration."""
         self._if_calib = False
```

Review comment (on `rotate_is_fp32`): Nit: the check inside this property is redundant; it will correctly return False anyway in that case. Consider simplifying the expression.
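The interaction between the legacy boolean `rotate` value and the new dict form can be summarized with a small standalone helper. The `rotate_flags` function below is illustrative only; the PR implements this logic as the two properties above:

```python
def rotate_flags(rotate_cfg) -> tuple[bool, bool]:
    """Return (rotate_is_enabled, rotate_is_fp32) for a given `rotate` config value."""
    if isinstance(rotate_cfg, dict):
        enabled = rotate_cfg.get("enable", False)
        fp32 = rotate_cfg.get("rotate_fp32", False) if enabled else False
    else:
        enabled, fp32 = bool(rotate_cfg), False
    return enabled, fp32


assert rotate_flags(True) == (True, False)                                  # legacy boolean form
assert rotate_flags({"enable": True, "rotate_fp32": True}) == (True, True)  # new FP32 form
assert rotate_flags({"enable": False, "rotate_fp32": True}) == (False, False)
```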
```diff
@@ -992,8 +1006,8 @@ def forward(self, inputs):
             inputs = inputs * self.pre_quant_scale
 
         # Rotating the input
-        if self._rotate:
-            inputs = normalized_hadamard_transform(inputs)
+        if self.rotate_is_enabled:
+            inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32)
 
         if self._disabled:
             # if quantizer is disabled, we still need to track the input dtype for saving the model
```
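For reference, the new flag is driven entirely through the quantizer config. A hedged sketch of the two config shapes exercised by the tests below; the wildcard key and dict layout mirror the test file rather than a documented public schema:

```python
# Rotate in the input dtype (previous behavior, still valid):
rotate_cfg = True

# Rotate in FP32 and cast the result back to the input dtype:
rotate_cfg = {"enable": True, "rotate_fp32": True}

quantizer_cfg = {
    "*[qk]_bmm_quantizer": {
        "rotate": rotate_cfg,
    },
}
```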
```diff
@@ -1105,7 +1119,8 @@ def extra_repr(self):
             if self.pre_quant_scale is not None
             else ""
         )
-        s += " rotated" if self._rotate else ""
+        s += " rotated" if self.rotate_is_enabled else ""
+        s += " (fp32)" if self.rotate_is_fp32 else ""
         s += (
             f" calibrator={self._calibrator.__class__.__name__}"
             if (self._calibrator is not None)
```
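The repr change is cosmetic but handy when inspecting configured quantizers. A quick standalone mirror of the two appended fragments; the `_rotate_repr` helper is illustrative only:

```python
def _rotate_repr(rotate_is_enabled: bool, rotate_is_fp32: bool) -> str:
    """Reproduce the rotation-related suffix added to the quantizer's extra_repr."""
    s = ""
    s += " rotated" if rotate_is_enabled else ""
    s += " (fp32)" if rotate_is_fp32 else ""
    return s


assert _rotate_repr(True, True) == " rotated (fp32)"
assert _rotate_repr(True, False) == " rotated"
assert _rotate_repr(False, False) == ""
```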
```diff
@@ -41,21 +41,32 @@ def test_hadamard_transform(dim):
     xxt_h = x_h @ x_h.T
     # The numerical error can be large, especially for 16-bit floats.
     assert torch.allclose(xxt_h, xxt, atol=0.05)
+    x_h_fp32 = normalized_hadamard_transform(x, rotate_fp32=True)
+    xxt_h_fp32 = x_h_fp32 @ x_h_fp32.T
+    assert torch.allclose(xxt_h_fp32, xxt, atol=0.05)
 
 
-def test_kv_rotate():
+@pytest.mark.parametrize(
+    "rotate_fp32",
+    [True, False],
+)
+def test_kv_rotate(rotate_fp32):
     mtq.plugins.register_attention_for_kv_quant(SDPAAttention)
     model = nn.Sequential(SDPAAttention())
     mtq.replace_quant_module(model)
 
     set_quantizer_by_cfg(model, {"*": {"enable": False}})
     dummy_input = SDPAAttention.get_input(device="cuda")
     output_ref = model(dummy_input)
+    if rotate_fp32:
+        rotate = {"enable": True, "rotate_fp32": True}
+    else:
+        rotate = True
     with set_quantizer_by_cfg_context(
         model,
         {
             "*[qk]_bmm_quantizer": {
-                "rotate": True,
+                "rotate": rotate,
             },
         },
     ):
@@ -67,7 +78,7 @@ def test_kv_rotate():
         model,
         {
             "*k_bmm_quantizer": {
-                "rotate": True,
+                "rotate": rotate,
             },
         },
     ):
```

Review comment (on the new FP32 assertion): Consider whether a tighter `atol` should be used for FP32 mode, since it should have better numerical precision than 16-bit floats.
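On the `atol` question above: the precision gap can be eyeballed without a GPU by emulating a 16-bit transform on the CPU. The `fwht` helper below is an illustrative stand-in for the real kernel (a naive iterative Walsh-Hadamard butterfly that rounds to FP16 after each stage when asked), not code from this repository:

```python
import torch

torch.manual_seed(0)


def fwht(x: torch.Tensor, emulate_fp16: bool = False) -> torch.Tensor:
    """Normalized Walsh-Hadamard transform over the last dim (size must be a power of two).

    With emulate_fp16=True, intermediate values are rounded to float16 after every
    butterfly stage to mimic a transform computed in 16-bit arithmetic.
    """
    lead, n = x.shape[:-1], x.shape[-1]
    y = x.float().reshape(-1, n).clone()
    h = 1
    while h < n:
        y = y.reshape(-1, n // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        if emulate_fp16:
            y = y.to(torch.float16).float()
        h *= 2
    return (y / torch.sqrt(torch.tensor(float(n)))).reshape(*lead, n)


n = 1024
x = torch.randn(128, n, dtype=torch.float16)
xxt = x.float() @ x.float().T

# Rotation computed in (emulated) FP16 vs. in FP32, both stored back in FP16 like the kernel output.
x_h16 = fwht(x, emulate_fp16=True).to(torch.float16)
x_h32 = fwht(x).to(torch.float16)

err16 = (x_h16.float() @ x_h16.float().T - xxt).abs().max().item()
err32 = (x_h32.float() @ x_h32.float().T - xxt).abs().max().item()
print(f"max |x_h @ x_h.T - x @ x.T|  fp16 rotation: {err16:.3f}  fp32 rotation: {err32:.3f}")
```

In this setup the FP16-rotation error should come out larger than the FP32-rotation error, which is the intuition behind allowing a tighter `atol` for the FP32 branch of the test.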
Review comment: Is this a 0.42 feature, or should we create 0.43?