diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 80ccb2f23d..36f460172d 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -489,3 +489,43 @@ def test_nvfp4_quantization_noncontiguous_inputs( torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] +) +def test_nvfp4_3d_shape_quantization( + M: int, + N: int, + with_2d_quantization: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + # Input + x = torch.randn((M, 4, N), dtype=torch.bfloat16, device=device) + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + ) + q_x = nvfp4_quantizer(x) + x *= 2 + nvfp4_quantizer.update_quantized(x, q_x) + assert q_x._rowwise_data is not None + assert len(q_x._rowwise_data.shape) == 3 + assert q_x._columnwise_data is not None + assert len(q_x._columnwise_data.shape) == 2 diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..61e1066828 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -6,6 +6,9 @@ #include +#include +#include + #include "common.h" #include "pybind.h" #include "torch/torch.h" @@ -1260,6 +1263,29 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w return {std::move(out_cpp), std::move(out_py)}; } +/** + * @brief Compress an N-D shape into a 2-D shape by flattening all but the last dimension. + * + * This utility is intended for comparing N-dimensional tensor shapes in a 2D space: + * it multiplies (flattens) every dimension except the final one into a single leading + * dimension, and keeps the last dimension unchanged. + * + * Example: [d0, d1, d2, ..., d{n-2}, d{n-1}] -> [d0*d1*...*d{n-2}, d{n-1}] + * + * If the input has 2 or fewer dimensions, it is returned unchanged. + */ +std::vector compressShapeTo2D(const std::vector& data) { + // If 2 or fewer elements, return as-is + if (data.size() <= 2) { + return data; + } + // Multiply all elements except the last + size_t product = std::accumulate(data.begin(), data.end() - 1, static_cast(1), + std::multiplies()); + // Return new vector of size 2: {product, last} + return std::vector{product, data.back()}; +} + std::pair NVFP4Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); @@ -1289,8 +1315,11 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); if (rowwise_data) { auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, - ") and column-wise data (shape=", shape, ") do not match"); + auto expected_shape_2d = compressShapeTo2D(expected_shape); + auto shape_2d = compressShapeTo2D(shape); + NVTE_CHECK(shape_2d == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d, + ") and column-wise data (2D shape=", shape_2d, ") do not match"); + shape = expected_shape; } } else { // Already checked columnwise_data_tensor == true shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);