From 96e328464214a10cb229fb1fca52be6d35fd5f1f Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 13:50:41 -0400 Subject: [PATCH 01/10] add fallback to mock fp8 matmul on cpu Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 866d6aa..33ca41a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -17,6 +17,7 @@ from typing import Any, Mapping # Third Party +from packaging.version import Version import torch # Local @@ -27,6 +28,9 @@ # torch.nn.functional.linear not recognized as callable # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 +TORCH_VERSION = Version(torch.__version__.split("+")[0]) +SUPPORTS_CPU_PER_CHANNEL_FP8 = TORCH_VERSION < Version("2.10") + # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: # Third Party @@ -213,7 +217,11 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor": def forward(self, x: torch.Tensor) -> torch.Tensor: """If input quantization is active, compute FP8xFP8 addmm leveraging torchao - functionalities. Otherwise compute non-quantized addmm.""" + functionalities. Otherwise compute non-quantized addmm. + + In Pytorch 2.10, torch._scale_mm only supports FP8 on CPU when + quantization is per-tensor. In this case, we perform a mock FP8xFP8 matmul. + """ # fp8 weight tensor for torchao qweight: AffineQuantizedTensor = self._construct_qweight_structure() @@ -234,6 +242,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + # Check if we need CPU fallback for per-channel quantization + is_cpu = qx.device.type == "cpu" + is_per_tensor = ( + self.linear_config["weights"]["strategy"] == "tensor" and + self.linear_config["input_activations"]["strategy"] == "tensor" + ) + + # Perform mock FP8xFP8 matmul + if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: + x_dequant = qx.dequantize() + w_dequant = qweight.dequantize() + out = torch.nn.functional.linear( + x_dequant.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None + ) + return out.to(x.dtype) + # Copied from torchao _linear_fp8_act_fp8_weight_impl # (with changes to support fp8 out) scaled_mm_config = Float8MMConfig(use_fast_accum=True) From 84d0ac04514cdb14eed691193e98c36260fe9cc3 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 14:04:49 -0400 Subject: [PATCH 02/10] formatting Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 33ca41a..e7f7f46 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -29,7 +29,7 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 TORCH_VERSION = Version(torch.__version__.split("+")[0]) -SUPPORTS_CPU_PER_CHANNEL_FP8 = TORCH_VERSION < Version("2.10") +SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: @@ -245,8 +245,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Check if we need CPU fallback for per-channel quantization is_cpu = qx.device.type == "cpu" is_per_tensor = ( - self.linear_config["weights"]["strategy"] == "tensor" and - self.linear_config["input_activations"]["strategy"] == "tensor" + self.linear_config["weights"]["strategy"] == "tensor" + and self.linear_config["input_activations"]["strategy"] == "tensor" ) # Perform mock FP8xFP8 matmul @@ -254,7 +254,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_dequant = qx.dequantize() w_dequant = qweight.dequantize() out = torch.nn.functional.linear( - x_dequant.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None + x_dequant.to(w_dequant.dtype), + w_dequant, + self.bias if self.has_bias else None, ) return out.to(x.dtype) From 23b10a3deea9f4bfb043aa9bf1cd8fa1d7b0848f Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 19:51:37 -0400 Subject: [PATCH 03/10] fix dtype of fp8 matmul with non-quantized activations Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index e7f7f46..c89b17e 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -302,10 +302,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).reshape(out_shape) # activations not quantized, dequant fp8 weight and do regular matmul + w_dequant = qweight.dequantize() out = torch.nn.functional.linear( - x, qweight.dequantize(), self.bias if self.has_bias else None + x.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None ) - return out + return out.to(x.dtype) def __repr__(self) -> str: return ( From 8a046a92fca806a4640fa68559559e1cf8aa9d58 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 20:17:39 -0400 Subject: [PATCH 04/10] add torchao version check in fallback Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index c89b17e..f4f6ac4 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -14,6 +14,7 @@ """Implement FP8 linear module to be loaded via FMS.""" # Standard +from importlib.metadata import version from typing import Any, Mapping # Third Party @@ -33,6 +34,8 @@ # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: + TORCHAO_VERSION = Version(version("torchao")) + # Third Party from fms.modules.linear import ( LinearModuleShardingInfo, @@ -251,6 +254,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Perform mock FP8xFP8 matmul if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: + if Version("0.11") < TORCHAO_VERSION: + raise NotImplementedError( + "Fallback path for FP8 matmul on CPU is not supported " + "on torchao > 0.11." + ) x_dequant = qx.dequantize() w_dequant = qweight.dequantize() out = torch.nn.functional.linear( From c0a617f4c1344ec2a8d44c1defd49942384aa38b Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 20:57:02 -0400 Subject: [PATCH 05/10] add unit tests for FP8 matmul on CPU Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 263 ++++++++++++++++++++++++++++- 1 file changed, 260 insertions(+), 3 deletions(-) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index a382c63..875b6f1 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -21,6 +21,117 @@ from fms_mo.prep import available_packages import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def initialize_fp8_weights( + fp8_linear, + weight_strategy: str, + in_features: int, + out_features: int, +) -> None: + """Initialize FP8Linear weights with proper absmax scaling. + + Args: + fp8_linear: FP8Linear module to initialize + weight_strategy: "tensor" or "channel" for weight quantization + in_features: Input feature dimension + out_features: Output feature dimension + """ + with torch.no_grad(): + # Create random float weights + float_weights = torch.randn(out_features, in_features) + + # Calculate FP8 E4M3 max value (448.0) + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + # Set appropriate scales based on strategy using absmax + if weight_strategy == "tensor": + # Per-tensor: single scale for entire weight matrix + absmax = float_weights.abs().max() + scale = absmax / fp8_max + # Ensure scale is not zero + scale = torch.clamp(scale, min=1e-12) + fp8_linear.weight_scale.fill_(scale.item()) + else: # channel (per-row for weight matrix) + # Per-channel: one scale per output channel (row) + absmax = float_weights.abs().amax(dim=1) + scale = absmax / fp8_max + # Ensure scales are not zero + scale = torch.clamp(scale, min=1e-12) + # Reshape to match weight_scale parameter shape (out_features, 1) + fp8_linear.weight_scale.copy_(scale.reshape(-1, 1)) + + # Quantize weights to FP8 + quantized_weights = (float_weights / fp8_linear.weight_scale).clamp( + -fp8_max, fp8_max + ) + fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn)) + + # Initialize bias if present + if fp8_linear.has_bias: + fp8_linear.bias.copy_(torch.randn(out_features)) + + +def initialize_fp8_input_scale( + fp8_linear, + activation_strategy: str, + batch_size: int, + seq_len: int, + in_features: int, +) -> None: + """Initialize static input scale for FP8Linear. + + Args: + fp8_linear: FP8Linear module to initialize + activation_strategy: "tensor" or "token" for activation quantization + batch_size: Batch size for sample input + seq_len: Sequence length for sample input + in_features: Input feature dimension + """ + with torch.no_grad(): + # For static quantization, use a representative input to calculate scales + sample_input = torch.randn(batch_size, seq_len, in_features) + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + if activation_strategy == "tensor": + # Per-tensor: single scale for entire activation + absmax = sample_input.abs().max() + scale = absmax / fp8_max + scale = torch.clamp(scale, min=1e-12) + fp8_linear.input_scale.fill_(scale.item()) + else: # token + # For per-token static quantization, use a calibrated scale + # based on representative input statistics + absmax = sample_input.abs().max() + scale = absmax / fp8_max + scale = torch.clamp(scale, min=1e-12) + # Fill all scales with the same representative value + fp8_linear.input_scale.fill_(scale.item()) + + +# ============================================================================ +# Pytest Fixtures +# ============================================================================ + + +@pytest.fixture +def fp8_test_dimensions(): + """Common test dimensions for FP8Linear tests.""" + return { + "batch_size": 2, + "seq_len": 4, + "in_features": 8, + "out_features": 16, + } + + +# ============================================================================ +# Tests +# ============================================================================ + def test_fp8_registration() -> None: """ @@ -44,9 +155,10 @@ def test_fp8_registration() -> None: reason="FP8 is only available on GPUs with device level 8.9 or higher", ) def test_fp8_op() -> None: - """Validate output shapes of GPTQ W4A16 tensors. - Note: this AIU-compatible operation only returns a zero tensor of the - expected shape, it does not perform a real W4A16 matmul operation. + """Validate output shapes of FP8 attention operation. + + Tests the FP8 attention compute operation to ensure it produces + outputs with the expected shape. """ # Local from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op @@ -57,3 +169,148 @@ def test_fp8_op() -> None: out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None) assert out.size() == query.size() + + +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) +@pytest.mark.parametrize( + "weight_strategy,activation_strategy,dynamic_activation", + [ + ("tensor", "tensor", True), # Per-tensor weights + per-tensor activations + ("tensor", "token", True), # Per-tensor weights + per-token activations + ("channel", "tensor", True), # Per-channel weights + per-tensor activations + ("channel", "token", True), # Per-channel weights + per-token activations + ], +) +def test_fp8_linear_cpu_support( + weight_strategy: str, + activation_strategy: str, + dynamic_activation: bool, + fp8_test_dimensions: dict, +) -> None: + """Test FP8Linear on CPU with different quantization strategies. + + This test ensures that FP8Linear works correctly on CPU, including: + - Per-tensor quantization (native support in PyTorch 2.10+) + - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+) + + Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel + and per-token quantization require a fallback to dequantize + regular matmul. + + Args: + weight_strategy: "tensor" or "channel" for weight quantization + activation_strategy: "tensor" or "token" for activation quantization + dynamic_activation: Whether to use dynamic activation quantization + fp8_test_dimensions: Test dimensions fixture + """ + # Local + from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear + + # Get test dimensions + batch_size = fp8_test_dimensions["batch_size"] + seq_len = fp8_test_dimensions["seq_len"] + in_features = fp8_test_dimensions["in_features"] + out_features = fp8_test_dimensions["out_features"] + + # Create FP8Linear configuration + linear_config = { + "weights": { + "strategy": weight_strategy, + "symmetric": True, + "dynamic": False, + }, + "input_activations": { + "strategy": activation_strategy, + "symmetric": True, + "dynamic": dynamic_activation, + }, + } + + # Create FP8Linear module + fp8_linear = FP8Linear( + in_features=in_features, + out_features=out_features, + bias=True, + linear_config=linear_config, + ) + + # Initialize weights using helper function + initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features) + + # Initialize input scale if static quantization + if not dynamic_activation: + initialize_fp8_input_scale( + fp8_linear, activation_strategy, batch_size, seq_len, in_features + ) + + # Create input tensor on CPU + x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) + + # Run forward pass - should not raise an error + output = fp8_linear(x) + + # Validate output shape + assert output.shape == (batch_size, seq_len, out_features) + + # Validate output is not NaN or Inf + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + # Validate output dtype matches input dtype + assert output.dtype == x.dtype + + +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) +def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None: + """Test FP8Linear on CPU with only weight quantization (no activation quantization). + + This tests the code path where activations are not quantized but weights are FP8. + + Args: + fp8_test_dimensions: Test dimensions fixture + """ + # Local + from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear + + # Get test dimensions + batch_size = fp8_test_dimensions["batch_size"] + seq_len = fp8_test_dimensions["seq_len"] + in_features = fp8_test_dimensions["in_features"] + out_features = fp8_test_dimensions["out_features"] + + # Create FP8Linear configuration with no activation quantization + linear_config = { + "weights": { + "strategy": "channel", + "symmetric": True, + "dynamic": False, + }, + "input_activations": None, # No activation quantization + } + + # Create FP8Linear module + fp8_linear = FP8Linear( + in_features=in_features, + out_features=out_features, + bias=True, + linear_config=linear_config, + ) + + # Initialize weights using helper function + initialize_fp8_weights(fp8_linear, "channel", in_features, out_features) + + # Create input tensor on CPU + x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) + + # Run forward pass + output = fp8_linear(x) + + # Validate output + assert output.shape == (batch_size, seq_len, out_features) + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() From c28329182374a1e789448217417efc1631e028a5 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:05:18 -0400 Subject: [PATCH 06/10] minor updates Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 4 ++-- tests/aiu_addons/test_fp8_addon.py | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index f4f6ac4..a9b3543 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -34,7 +34,6 @@ # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: - TORCHAO_VERSION = Version(version("torchao")) # Third Party from fms.modules.linear import ( @@ -254,7 +253,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Perform mock FP8xFP8 matmul if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: - if Version("0.11") < TORCHAO_VERSION: + # Check torchao version without loading the full package + if Version("0.11") < Version(version("torchao")): raise NotImplementedError( "Fallback path for FP8 matmul on CPU is not supported " "on torchao > 0.11." diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 875b6f1..f989e1d 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -21,6 +21,13 @@ from fms_mo.prep import available_packages import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import +# ============================================================================ +# Constants +# ============================================================================ + +# FP8 E4M3 maximum value +FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + # ============================================================================ # Helper Functions # ============================================================================ @@ -44,21 +51,18 @@ def initialize_fp8_weights( # Create random float weights float_weights = torch.randn(out_features, in_features) - # Calculate FP8 E4M3 max value (448.0) - fp8_max = torch.finfo(torch.float8_e4m3fn).max - # Set appropriate scales based on strategy using absmax if weight_strategy == "tensor": # Per-tensor: single scale for entire weight matrix absmax = float_weights.abs().max() - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX # Ensure scale is not zero scale = torch.clamp(scale, min=1e-12) fp8_linear.weight_scale.fill_(scale.item()) else: # channel (per-row for weight matrix) # Per-channel: one scale per output channel (row) absmax = float_weights.abs().amax(dim=1) - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX # Ensure scales are not zero scale = torch.clamp(scale, min=1e-12) # Reshape to match weight_scale parameter shape (out_features, 1) @@ -66,7 +70,7 @@ def initialize_fp8_weights( # Quantize weights to FP8 quantized_weights = (float_weights / fp8_linear.weight_scale).clamp( - -fp8_max, fp8_max + -FP8_E4M3_MAX, FP8_E4M3_MAX ) fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn)) @@ -94,19 +98,18 @@ def initialize_fp8_input_scale( with torch.no_grad(): # For static quantization, use a representative input to calculate scales sample_input = torch.randn(batch_size, seq_len, in_features) - fp8_max = torch.finfo(torch.float8_e4m3fn).max if activation_strategy == "tensor": # Per-tensor: single scale for entire activation absmax = sample_input.abs().max() - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX scale = torch.clamp(scale, min=1e-12) fp8_linear.input_scale.fill_(scale.item()) else: # token # For per-token static quantization, use a calibrated scale # based on representative input statistics absmax = sample_input.abs().max() - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX scale = torch.clamp(scale, min=1e-12) # Fill all scales with the same representative value fp8_linear.input_scale.fill_(scale.item()) From e158bd26042942332fd72401165763e260ac6861 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:05:46 -0400 Subject: [PATCH 07/10] minor updates Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index a9b3543..f4c10aa 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -34,7 +34,6 @@ # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: - # Third Party from fms.modules.linear import ( LinearModuleShardingInfo, From ef7357615304ccdec39434e66e6cc62c0d1f3a89 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:12:58 -0400 Subject: [PATCH 08/10] remove static activation test Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 62 +++++------------------------- 1 file changed, 9 insertions(+), 53 deletions(-) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index f989e1d..714bd0a 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -79,42 +79,6 @@ def initialize_fp8_weights( fp8_linear.bias.copy_(torch.randn(out_features)) -def initialize_fp8_input_scale( - fp8_linear, - activation_strategy: str, - batch_size: int, - seq_len: int, - in_features: int, -) -> None: - """Initialize static input scale for FP8Linear. - - Args: - fp8_linear: FP8Linear module to initialize - activation_strategy: "tensor" or "token" for activation quantization - batch_size: Batch size for sample input - seq_len: Sequence length for sample input - in_features: Input feature dimension - """ - with torch.no_grad(): - # For static quantization, use a representative input to calculate scales - sample_input = torch.randn(batch_size, seq_len, in_features) - - if activation_strategy == "tensor": - # Per-tensor: single scale for entire activation - absmax = sample_input.abs().max() - scale = absmax / FP8_E4M3_MAX - scale = torch.clamp(scale, min=1e-12) - fp8_linear.input_scale.fill_(scale.item()) - else: # token - # For per-token static quantization, use a calibrated scale - # based on representative input statistics - absmax = sample_input.abs().max() - scale = absmax / FP8_E4M3_MAX - scale = torch.clamp(scale, min=1e-12) - # Fill all scales with the same representative value - fp8_linear.input_scale.fill_(scale.item()) - - # ============================================================================ # Pytest Fixtures # ============================================================================ @@ -179,23 +143,22 @@ def test_fp8_op() -> None: reason="FMS and torchao required to run this test", ) @pytest.mark.parametrize( - "weight_strategy,activation_strategy,dynamic_activation", + "weight_strategy,activation_strategy", [ - ("tensor", "tensor", True), # Per-tensor weights + per-tensor activations - ("tensor", "token", True), # Per-tensor weights + per-token activations - ("channel", "tensor", True), # Per-channel weights + per-tensor activations - ("channel", "token", True), # Per-channel weights + per-token activations + ("tensor", "tensor"), # Per-tensor W + per-tensor dynamic A + ("tensor", "token"), # Per-tensor W + per-token dynamic A + ("channel", "tensor"), # Per-channel W + per-tensor dynamic A + ("channel", "token"), # Per-channel W + per-token dynamic A ], ) def test_fp8_linear_cpu_support( weight_strategy: str, activation_strategy: str, - dynamic_activation: bool, fp8_test_dimensions: dict, ) -> None: """Test FP8Linear on CPU with different quantization strategies. - This test ensures that FP8Linear works correctly on CPU, including: + This test ensures that FP8Linear works correctly on CPU with: - Per-tensor quantization (native support in PyTorch 2.10+) - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+) @@ -203,9 +166,8 @@ def test_fp8_linear_cpu_support( and per-token quantization require a fallback to dequantize + regular matmul. Args: - weight_strategy: "tensor" or "channel" for weight quantization - activation_strategy: "tensor" or "token" for activation quantization - dynamic_activation: Whether to use dynamic activation quantization + weight_strategy: "tensor" or "channel" weight quantization + activation_strategy: "tensor" or "token" dynamic activation quantization fp8_test_dimensions: Test dimensions fixture """ # Local @@ -227,7 +189,7 @@ def test_fp8_linear_cpu_support( "input_activations": { "strategy": activation_strategy, "symmetric": True, - "dynamic": dynamic_activation, + "dynamic": True, }, } @@ -242,12 +204,6 @@ def test_fp8_linear_cpu_support( # Initialize weights using helper function initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features) - # Initialize input scale if static quantization - if not dynamic_activation: - initialize_fp8_input_scale( - fp8_linear, activation_strategy, batch_size, seq_len, in_features - ) - # Create input tensor on CPU x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) From bc5de3ee238f8d54c9a83a1b3212ae5f8ed5410f Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:21:58 -0400 Subject: [PATCH 09/10] fix pylint false positive Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 714bd0a..13ee1a9 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -151,7 +151,7 @@ def test_fp8_op() -> None: ("channel", "token"), # Per-channel W + per-token dynamic A ], ) -def test_fp8_linear_cpu_support( +def test_fp8_linear_cpu_support( # pylint: disable=redefined-outer-name weight_strategy: str, activation_strategy: str, fp8_test_dimensions: dict, @@ -225,7 +225,7 @@ def test_fp8_linear_cpu_support( not available_packages["torchao"] or not available_packages["fms"], reason="FMS and torchao required to run this test", ) -def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None: +def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None: # pylint: disable=redefined-outer-name """Test FP8Linear on CPU with only weight quantization (no activation quantization). This tests the code path where activations are not quantized but weights are FP8. From 1558c3f99fb567600d7d490b25490dd4c5ddd04c Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Fri, 20 Mar 2026 14:55:53 -0400 Subject: [PATCH 10/10] add warning note to pyproject Signed-off-by: Andrea Fasoli --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b1cf6d8..8381b40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ [project.optional-dependencies] examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"] -fp8 = ["llmcompressor", "torchao==0.11"] +fp8 = ["llmcompressor", "torchao==0.11"] # FP8 matmul on CPU needs a fix before advancing torchao > 0.11 gptq = ["Cython", "gptqmodel>=1.7.3"] mx = ["microxcaling>=1.1"] opt = ["fms-model-optimizer[fp8, gptq, mx]"]