From ca63c7e2663a56667f7177f408775b677f90fae9 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 26 Mar 2026 21:25:45 +0000 Subject: [PATCH 1/7] fix _scaled_mm cpu override syntax Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 24 ------------------------ fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 9 +++++---- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index f4c10aa..b808661 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -243,30 +243,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) - # Check if we need CPU fallback for per-channel quantization - is_cpu = qx.device.type == "cpu" - is_per_tensor = ( - self.linear_config["weights"]["strategy"] == "tensor" - and self.linear_config["input_activations"]["strategy"] == "tensor" - ) - - # Perform mock FP8xFP8 matmul - if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: - # Check torchao version without loading the full package - if Version("0.11") < Version(version("torchao")): - raise NotImplementedError( - "Fallback path for FP8 matmul on CPU is not supported " - "on torchao > 0.11." - ) - x_dequant = qx.dequantize() - w_dequant = qweight.dequantize() - out = torch.nn.functional.linear( - x_dequant.to(w_dequant.dtype), - w_dequant, - self.bias if self.has_bias else None, - ) - return out.to(x.dtype) - # Copied from torchao _linear_fp8_act_fp8_weight_impl # (with changes to support fp8 out) scaled_mm_config = Float8MMConfig(use_fast_accum=True) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 8de1395..b1b029d 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -82,11 +82,12 @@ def _scaled_mm_cpu( out=None, ) - if torch.__version__ >= "2.8": - DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] - torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out - torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu + # In PyTorch 2.8+, use torch.library.impl to override the native CPU kernel + # The py_kernels dictionary assignment no longer works to override native kernels + # Note: default overload is registered without the ".default" suffix + torch.library.impl("aten::_scaled_mm.out", "CPU")(_scaled_mm_cpu_out) + torch.library.impl("aten::_scaled_mm", "CPU")(_scaled_mm_cpu) else: torch.library.register_kernel( torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out From ce20be8a9c2b23f809e098a575820af8a5a462bd Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 26 Mar 2026 21:54:02 +0000 Subject: [PATCH 2/7] lint Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 1 - fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index b808661..d555994 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -14,7 +14,6 @@ """Implement FP8 linear module to be loaded via FMS.""" # Standard -from importlib.metadata import version from typing import Any, Mapping # Third Party diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index b1b029d..1858d09 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -82,6 +82,7 @@ def _scaled_mm_cpu( out=None, ) + if torch.__version__ >= "2.8": # In PyTorch 2.8+, use torch.library.impl to override the native CPU kernel # The py_kernels dictionary assignment no longer works to override native kernels From 01d0cfe43cf28a730c015da2f07fbce233a9c4da Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 26 Mar 2026 22:30:49 +0000 Subject: [PATCH 3/7] add warning suppression for op override Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 13ee1a9..89b34cd 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -13,13 +13,21 @@ # limitations under the License. """Test suite for FMS addon introducing FP8 functionalities""" +# Standard +import warnings + # Third Party import pytest import torch # Local from fms_mo.prep import available_packages + +# Suppress the UserWarning about overriding kernel registration in PyTorch 2.8+ +# This warning is expected when we override the native CPU kernel for _scaled_mm +warnings.simplefilter("ignore", UserWarning) import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import +warnings.simplefilter("default", UserWarning) # Reset to default after import # ============================================================================ # Constants From 7fa142bcdfa18cad2859cb01c0820c0cdecf868c Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Fri, 27 Mar 2026 00:03:26 +0000 Subject: [PATCH 4/7] remove tests for unsupported fp8 setup Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 6 ++++-- tests/aiu_addons/test_fp8_addon.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 1858d09..54b4a61 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -50,9 +50,10 @@ def _scaled_mm_cpu_out( mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) if bias is not None: - ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) + bias_converted = bias.to(dtype=out_dtype) + ret = torch.addmm(bias_converted, mat1, mat2) else: - ret = torch.mm(mat1, mat2).to(dtype=out_dtype) + ret = torch.mm(mat1, mat2) if out is not None: out.copy_(ret) @@ -87,6 +88,7 @@ def _scaled_mm_cpu( # In PyTorch 2.8+, use torch.library.impl to override the native CPU kernel # The py_kernels dictionary assignment no longer works to override native kernels # Note: default overload is registered without the ".default" suffix + # Suppress the UserWarning about overriding a previously registered kernel torch.library.impl("aten::_scaled_mm.out", "CPU")(_scaled_mm_cpu_out) torch.library.impl("aten::_scaled_mm", "CPU")(_scaled_mm_cpu) else: diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 89b34cd..81f641a 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -26,7 +26,9 @@ # Suppress the UserWarning about overriding kernel registration in PyTorch 2.8+ # This warning is expected when we override the native CPU kernel for _scaled_mm warnings.simplefilter("ignore", UserWarning) -import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import +# Local +import fms_mo.aiu_addons.fp8.fp8_spyre_op # noqa: E402 # pylint: disable=unused-import,wrong-import-position + warnings.simplefilter("default", UserWarning) # Reset to default after import # ============================================================================ @@ -154,8 +156,6 @@ def test_fp8_op() -> None: "weight_strategy,activation_strategy", [ ("tensor", "tensor"), # Per-tensor W + per-tensor dynamic A - ("tensor", "token"), # Per-tensor W + per-token dynamic A - ("channel", "tensor"), # Per-channel W + per-tensor dynamic A ("channel", "token"), # Per-channel W + per-token dynamic A ], ) @@ -164,14 +164,14 @@ def test_fp8_linear_cpu_support( # pylint: disable=redefined-outer-name activation_strategy: str, fp8_test_dimensions: dict, ) -> None: - """Test FP8Linear on CPU with different quantization strategies. + """Test FP8Linear on CPU with supported quantization strategies. This test ensures that FP8Linear works correctly on CPU with: - - Per-tensor quantization (native support in PyTorch 2.10+) - - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+) + - Per-tensor quantization (weights and activations both per-tensor) + - Per-channel quantization (weights and activations both per-channel/per-token) - Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel - and per-token quantization require a fallback to dequantize + regular matmul. + Note: Mixed granularity (e.g., per-tensor weights with per-token activations) + is not supported on the target custom hardware. Args: weight_strategy: "tensor" or "channel" weight quantization From 2fe51f124770d956868d9db126835bfdd74cc1b0 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Fri, 27 Mar 2026 00:31:01 +0000 Subject: [PATCH 5/7] trim docstring Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 81f641a..c300dde 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -170,9 +170,6 @@ def test_fp8_linear_cpu_support( # pylint: disable=redefined-outer-name - Per-tensor quantization (weights and activations both per-tensor) - Per-channel quantization (weights and activations both per-channel/per-token) - Note: Mixed granularity (e.g., per-tensor weights with per-token activations) - is not supported on the target custom hardware. - Args: weight_strategy: "tensor" or "channel" weight quantization activation_strategy: "tensor" or "token" dynamic activation quantization From ba635b139da38e1fd864a41a875584f505e15cb8 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Fri, 27 Mar 2026 00:37:41 +0000 Subject: [PATCH 6/7] minor formatting Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 54b4a61..f3c3602 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -50,8 +50,7 @@ def _scaled_mm_cpu_out( mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) if bias is not None: - bias_converted = bias.to(dtype=out_dtype) - ret = torch.addmm(bias_converted, mat1, mat2) + ret = torch.addmm(bias.to(dtype=out_dtype), mat1, mat2) else: ret = torch.mm(mat1, mat2) From 1af819a41d98504c825ea593371aae1b5f22ff2e Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Fri, 27 Mar 2026 01:27:58 +0000 Subject: [PATCH 7/7] removed unused constant Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index d555994..d53c14f 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -29,7 +29,6 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 TORCH_VERSION = Version(torch.__version__.split("+")[0]) -SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: