From a35de590341a6e181972ffed27e4395ad05c540d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=99=E5=88=92?= Date: Tue, 10 Feb 2026 16:57:38 +0800 Subject: [PATCH 1/7] fix nvfp4 convert_and_update_tensor shape check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 乙划 --- transformer_engine/pytorch/csrc/quantizer.cpp | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..6a75c5559b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1260,6 +1260,22 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w return {std::move(out_cpp), std::move(out_py)}; } +std::vector compressShapeTo2D(const std::vector& data) { + // If 2 or fewer elements, return as-is + if (data.size() <= 2) { + return data; + } + // Multiply all elements except the last + size_t product = std::accumulate( + data.begin(), + data.end() - 1, + static_cast(1), + std::multiplies() + ); + // Return new vector of size 2: {product, last} + return std::vector{ product, data.back() }; +} + std::pair NVFP4Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); @@ -1289,8 +1305,10 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); if (rowwise_data) { auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, + auto expected_shape_2d = compressShapeTo2D(expected_shape); + NVTE_CHECK(shape == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d, ") and column-wise data (shape=", shape, ") do not match"); + shape = expected_shape; } } else { // Already checked columnwise_data_tensor == true shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); From c7eaa0571127800066b03b2660f333e8144ed216 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 09:07:39 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6a75c5559b..2a13a01dbb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1261,19 +1261,15 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w } std::vector compressShapeTo2D(const std::vector& data) { - // If 2 or fewer elements, return as-is - if (data.size() <= 2) { - return data; - } - // Multiply all elements except the last - size_t product = std::accumulate( - data.begin(), - data.end() - 1, - static_cast(1), - std::multiplies() - ); - // Return new vector of size 2: {product, last} - return std::vector{ product, data.back() }; + // If 2 or fewer elements, return as-is + if (data.size() <= 2) { + return data; + } + // Multiply all elements except the last + size_t product = std::accumulate(data.begin(), data.end() - 1, static_cast(1), + std::multiplies()); + // Return new vector of size 2: {product, last} + return std::vector{product, data.back()}; } std::pair NVFP4Quantizer::convert_and_update_tensor( From c0a8aa6c6c442f0a3a2c41aa2482a865ea7b20dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=99=E5=88=92?= Date: Wed, 11 Feb 2026 13:44:57 +0800 Subject: [PATCH 3/7] add headers and check 2D shapes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 乙划 --- transformer_engine/pytorch/csrc/quantizer.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6a75c5559b..e4ab10607c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include #include #include "common.h" @@ -1306,8 +1308,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( if (rowwise_data) { auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); auto expected_shape_2d = compressShapeTo2D(expected_shape); + auto shape_2d = compressShapeTo2D(shape); NVTE_CHECK(shape == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d, - ") and column-wise data (shape=", shape, ") do not match"); + ") and column-wise data (2D shape=", shape_2d, ") do not match"); shape = expected_shape; } } else { // Already checked columnwise_data_tensor == true From 489f060a0498ff5b2a8dc08b71b8cb571a408ecf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 05:51:11 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 00d978bd3b..30ae2e38c2 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -4,10 +4,11 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include #include +#include +#include + #include "common.h" #include "pybind.h" #include "torch/torch.h" From 11fc532e6e265342732beaeabc2b0c7866802863 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 11 Feb 2026 09:59:54 -0800 Subject: [PATCH 5/7] Fix Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 30ae2e38c2..d61564ddfc 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1306,7 +1306,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); auto expected_shape_2d = compressShapeTo2D(expected_shape); auto shape_2d = compressShapeTo2D(shape); - NVTE_CHECK(shape == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d, + NVTE_CHECK(shape_2d == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d, ") and column-wise data (2D shape=", shape_2d, ") do not match"); shape = expected_shape; } From 486da6681c9a63cc083f14e11dddde91bd8434fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=99=E5=88=92?= Date: Fri, 13 Feb 2026 16:51:02 +0800 Subject: [PATCH 6/7] add unittest and doctring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 乙划 --- .../nvfp4/test_nvfp4_quantize_exact.py | 40 +++++++++++++++++++ transformer_engine/pytorch/csrc/quantizer.cpp | 11 +++++ 2 files changed, 51 insertions(+) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 80ccb2f23d..910f9c36ca 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -489,3 +489,43 @@ def test_nvfp4_quantization_noncontiguous_inputs( torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] +) +def test_nvfp4_3d_shape_quantization( + M: int, + N: int, + with_2d_quantization: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + # Input + x = torch.randn((M, 4, N), dtype=torch.bfloat16, device=device) + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + ) + q_x = nvfp4_quantizer(x) + x *= 2 + nvfp4_quantizer.update_quantized(x, q_x) + assert q_x._rowwise_data is not None + assert len(q_x._rowwise_data.shape) == 3 + assert q_x._columnwise_data is not None + assert len(q_x._columnwise_data.shape) == 2 \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d61564ddfc..61e1066828 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1263,6 +1263,17 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w return {std::move(out_cpp), std::move(out_py)}; } +/** + * @brief Compress an N-D shape into a 2-D shape by flattening all but the last dimension. + * + * This utility is intended for comparing N-dimensional tensor shapes in a 2D space: + * it multiplies (flattens) every dimension except the final one into a single leading + * dimension, and keeps the last dimension unchanged. + * + * Example: [d0, d1, d2, ..., d{n-2}, d{n-1}] -> [d0*d1*...*d{n-2}, d{n-1}] + * + * If the input has 2 or fewer dimensions, it is returned unchanged. + */ std::vector compressShapeTo2D(const std::vector& data) { // If 2 or fewer elements, return as-is if (data.size() <= 2) { From 075ebf6f757b96e99115e0061b2b6be8da13d68e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 08:53:27 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 910f9c36ca..36f460172d 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -528,4 +528,4 @@ def test_nvfp4_3d_shape_quantization( assert q_x._rowwise_data is not None assert len(q_x._rowwise_data.shape) == 3 assert q_x._columnwise_data is not None - assert len(q_x._columnwise_data.shape) == 2 \ No newline at end of file + assert len(q_x._columnwise_data.shape) == 2