diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index f4c10aa..d53c14f 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -14,7 +14,6 @@ """Implement FP8 linear module to be loaded via FMS.""" # Standard -from importlib.metadata import version from typing import Any, Mapping # Third Party @@ -30,7 +29,6 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 TORCH_VERSION = Version(torch.__version__.split("+")[0]) -SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: @@ -243,30 +241,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) - # Check if we need CPU fallback for per-channel quantization - is_cpu = qx.device.type == "cpu" - is_per_tensor = ( - self.linear_config["weights"]["strategy"] == "tensor" - and self.linear_config["input_activations"]["strategy"] == "tensor" - ) - - # Perform mock FP8xFP8 matmul - if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: - # Check torchao version without loading the full package - if Version("0.11") < Version(version("torchao")): - raise NotImplementedError( - "Fallback path for FP8 matmul on CPU is not supported " - "on torchao > 0.11." - ) - x_dequant = qx.dequantize() - w_dequant = qweight.dequantize() - out = torch.nn.functional.linear( - x_dequant.to(w_dequant.dtype), - w_dequant, - self.bias if self.has_bias else None, - ) - return out.to(x.dtype) - # Copied from torchao _linear_fp8_act_fp8_weight_impl # (with changes to support fp8 out) scaled_mm_config = Float8MMConfig(use_fast_accum=True) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 8de1395..f3c3602 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -50,9 +50,9 @@ def _scaled_mm_cpu_out( mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) if bias is not None: - ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) + ret = torch.addmm(bias.to(dtype=out_dtype), mat1, mat2) else: - ret = torch.mm(mat1, mat2).to(dtype=out_dtype) + ret = torch.mm(mat1, mat2) if out is not None: out.copy_(ret) @@ -84,9 +84,12 @@ def _scaled_mm_cpu( if torch.__version__ >= "2.8": - DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] - torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out - torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu + # In PyTorch 2.8+, use torch.library.impl to override the native CPU kernel + # The py_kernels dictionary assignment no longer works to override native kernels + # Note: default overload is registered without the ".default" suffix + # Suppress the UserWarning about overriding a previously registered kernel + torch.library.impl("aten::_scaled_mm.out", "CPU")(_scaled_mm_cpu_out) + torch.library.impl("aten::_scaled_mm", "CPU")(_scaled_mm_cpu) else: torch.library.register_kernel( torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 13ee1a9..c300dde 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -13,13 +13,23 @@ # limitations under the License. """Test suite for FMS addon introducing FP8 functionalities""" +# Standard +import warnings + # Third Party import pytest import torch # Local from fms_mo.prep import available_packages -import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import + +# Suppress the UserWarning about overriding kernel registration in PyTorch 2.8+ +# This warning is expected when we override the native CPU kernel for _scaled_mm +warnings.simplefilter("ignore", UserWarning) +# Local +import fms_mo.aiu_addons.fp8.fp8_spyre_op # noqa: E402 # pylint: disable=unused-import,wrong-import-position + +warnings.simplefilter("default", UserWarning) # Reset to default after import # ============================================================================ # Constants @@ -146,8 +156,6 @@ def test_fp8_op() -> None: "weight_strategy,activation_strategy", [ ("tensor", "tensor"), # Per-tensor W + per-tensor dynamic A - ("tensor", "token"), # Per-tensor W + per-token dynamic A - ("channel", "tensor"), # Per-channel W + per-tensor dynamic A ("channel", "token"), # Per-channel W + per-token dynamic A ], ) @@ -156,14 +164,11 @@ def test_fp8_linear_cpu_support( # pylint: disable=redefined-outer-name activation_strategy: str, fp8_test_dimensions: dict, ) -> None: - """Test FP8Linear on CPU with different quantization strategies. + """Test FP8Linear on CPU with supported quantization strategies. This test ensures that FP8Linear works correctly on CPU with: - - Per-tensor quantization (native support in PyTorch 2.10+) - - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+) - - Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel - and per-token quantization require a fallback to dequantize + regular matmul. + - Per-tensor quantization (weights and activations both per-tensor) + - Per-channel quantization (weights and activations both per-channel/per-token) Args: weight_strategy: "tensor" or "channel" weight quantization