From 7b37680375be346d0f251add63b628dbbbda6e1c Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 6 Nov 2025 18:45:28 +0000 Subject: [PATCH 01/40] Support for NVFP4 primary weights Signed-off-by: qiyuw --- transformer_engine/pytorch/module/base.py | 37 +++++++++++ transformer_engine/pytorch/tensor/utils.py | 75 ++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 838ac5281c..c7ce993a98 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -46,6 +46,7 @@ from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -656,6 +657,8 @@ def __init__(self) -> None: self.sequence_parallel = False self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + # NVFP4 reuses FP8 GlobalStateManager + self.primary_weights_in_nvfp4 = FP8GlobalStateManager.with_fp8_parameters() self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() self.fsdp_wrapped = False self.fsdp_group = None @@ -1326,6 +1329,40 @@ def clear(self): setattr(self, name, param) + # NVFP4 primary weights path + if self.primary_weights_in_nvfp4 and fp8_meta_index is not None: + + high_precision_init_val = None + if self.preserve_high_precision_init_val: + high_precision_init_val = getattr(self, name).detach().cpu() + + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + if quantizer is None: + raise RuntimeError("Weight quantizer has not been initialized") + if not isinstance(quantizer, NVFP4Quantizer): + # Skip if this meta index is not NVFP4 (e.g., activations) + pass + else: + quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + quantizer.internal = False + nvfp4_param = quantizer(getattr(self, name)) + nvfp4_param = torch.nn.Parameter(nvfp4_param) + if high_precision_init_val is not None: + def get(self_): + if hasattr(self_, "_high_precision_init_val"): + return self_._high_precision_init_val + return None + + def clear(self_): + if hasattr(self_, "_high_precision_init_val"): + del self_._high_precision_init_val + + nvfp4_param._high_precision_init_val = high_precision_init_val + nvfp4_param.get_high_precision_init_val = MethodType(get, nvfp4_param) + nvfp4_param.clear_high_precision_init_val = MethodType(clear, nvfp4_param) + + setattr(self, name, nvfp4_param) + @abstractmethod def forward(self): """Needs override.""" diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index cc02494013..4363d558c6 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -12,6 +12,7 @@ from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply 
import multi_tensor_applier @@ -454,6 +455,80 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ) +def cast_master_weights_to_nvfp4( + model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None +): + r"""Helper function to cast master weights to NVFP4 primary weights using partial shards. + + Mirrors FP8 partial-cast: compute local amax per shard, all-reduce to global amax, + then cast only the shard slice into NVFP4 storages. Requires backend NVFP4 partial ops. + """ + + if fsdp_shard_model_weights is None: + use_fsdp_shard_model_weights = False + fsdp_shard_model_weights = [None] * len(model_weights) + else: + use_fsdp_shard_model_weights = True + + if len(model_weights) == 0: + return + + device = model_weights[0].device + packed_amaxes = torch.zeros(len(model_weights), dtype=torch.float32, device=device) + + # Step 1: local amax per shard + for i, (model_weight, master_weight, start_offset, _) in enumerate( + zip(model_weights, master_weights, start_offsets, fsdp_shard_model_weights) + ): + if not isinstance(model_weight, NVFP4Tensor): + continue + if master_weight is None: + continue + h, w = model_weight.shape + try: + tex.nvfp4_compute_partial_amax(master_weight, packed_amaxes[i : i + 1], h, w, start_offset) + except AttributeError as e: + raise NotImplementedError( + "Missing NVFP4 partial amax kernel: tex.nvfp4_compute_partial_amax" + ) from e + + # Step 2: all-reduce to global amax + torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + # Step 3: partial cast per shard + for i, (model_weight, master_weight, start_offset, model_weight_fragment) in enumerate( + zip(model_weights, master_weights, start_offsets, fsdp_shard_model_weights) + ): + if not isinstance(model_weight, NVFP4Tensor): + continue + if master_weight is None: + continue + + quantizer = model_weight._get_quantizer() + if not isinstance(quantizer, NVFP4Quantizer): + raise ValueError( + f"cast_master_weights_to_nvfp4 expects NVFP4Quantizer, got {type(quantizer)}" + ) + quantizer.set_usage(rowwise=True, columnwise=True) + if hasattr(model_weight, "_reset_caches"): + model_weight._reset_caches() + + dst_tensor = model_weight if not use_fsdp_shard_model_weights else model_weight_fragment + h, w = model_weight.shape + try: + tex.nvfp4_partial_cast( + master_weight, + dst_tensor, + packed_amaxes[i : i + 1], + h, + w, + start_offset, + ) + except AttributeError as e: + raise NotImplementedError( + "Missing NVFP4 partial cast kernel: tex.nvfp4_partial_cast" + ) from e + def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: """Check if an environment or object is using experimental Kitchen middleware. 
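Usage note: the helper added above follows the FP8 partial-cast pattern, local per-shard amax, amax all-reduce over the data-parallel group, then a shard-local cast into the NVFP4 storage. A minimal usage sketch follows (illustrative only; `optimizer_step`, `shard_offsets`, and `dp_group` are placeholder names, not part of this patch):

    # Minimal sketch: assumes NVFP4 primary weights and an in-place all-gather afterwards.
    from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4

    def optimizer_step(model_weights, master_weights, shard_offsets, dp_group):
        # FP32 master shards have already been updated by the optimizer at this point.
        # Each rank casts only the slice it owns; per-shard amaxes are all-reduced
        # across dp_group inside the helper before the partial cast.
        cast_master_weights_to_nvfp4(model_weights, master_weights, shard_offsets, dp_group)
        # The caller is expected to all-gather the NVFP4 rowwise buffers in place afterwards.
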
From 2c0750e3a6f12ac20fcc00cb8c7e6096bc23003a Mon Sep 17 00:00:00 2001 From: qiyuw Date: Tue, 18 Nov 2025 00:10:11 +0000 Subject: [PATCH 02/40] partial amax and cast --- .../include/transformer_engine/recipe.h | 10 + transformer_engine/common/recipe/nvfp4.cu | 232 ++++++++++++++++++ .../csrc/extensions/nvfp4_2d_partial_cast.cpp | 49 ++++ .../pytorch/csrc/extensions/pybind.cpp | 10 + .../pytorch/csrc/extensions/recipe.cpp | 2 + transformer_engine/pytorch/tensor/utils.py | 136 +++++----- 6 files changed, 378 insertions(+), 61 deletions(-) mode change 100644 => 100755 transformer_engine/common/include/transformer_engine/recipe.h mode change 100644 => 100755 transformer_engine/common/recipe/nvfp4.cu create mode 100755 transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/pybind.cpp mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/recipe.cpp mode change 100644 => 100755 transformer_engine/pytorch/tensor/utils.py diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h old mode 100644 new mode 100755 index 6e1e9dd7ac..0a3b0dd41d --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -126,6 +126,16 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, cudaStream_t stream); +// NVFP4 2D (16x16) partial-shard APIs +void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, + cudaStream_t stream); + +void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, + size_t h, size_t w, size_t scale_stride_h, size_t scale_stride_w, + size_t start_offset, size_t block_len, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu old mode 100644 new mode 100755 index 5ebc7ba4f3..9a94d84f62 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -7,6 +7,7 @@ #include #include +#include #include "../common.h" #include "../utils.cuh" @@ -16,6 +17,8 @@ namespace nvfp4_recipe { // constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); +constexpr int kTileDim = 16; +constexpr int kThreadsPerBlock = 256; // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, @@ -24,9 +27,234 @@ __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const floa *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; } +template +__global__ void __launch_bounds__(kThreadsPerBlock) + nvfp4_2d_compute_partial_amax_kernel(const IType *input, float *amax_ptr, + const size_t amax_stride_h, const size_t amax_stride_w, + const size_t h, const size_t w, const size_t start_offset, + const size_t len) { + constexpr int kThreadsPerWarp = 32; + constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + constexpr int kLoopsPerCol = kTileDim / kNumWarps; + + const int tile_col = blockIdx.x; + const int tile_row = 
blockIdx.y; + const size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + + __shared__ float smem[kNumWarps]; + float amax = 0.0f; + + for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) { + size_t r = tile_row * kTileDim + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp; + for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) { + size_t c = tile_col * kTileDim + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp); + size_t idx = r * w + c; + if (r < h && c < w && idx >= start_offset && idx < end_offset) { + float other_amax = fabs(static_cast(input_minus_offset[idx])); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + } + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { + float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + + if (threadIdx.x % kThreadsPerWarp == 0) { + smem[threadIdx.x / kThreadsPerWarp] = amax; + } + + __syncthreads(); + + if (threadIdx.x == 0) { + for (int i = 0; i < kNumWarps; ++i) { + float other_amax = smem[i]; + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax; + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + nvfp4_2d_partial_cast_kernel(const IType *input, OType *output, const float *scale_ptr, + const size_t scale_stride_h, const size_t scale_stride_w, + const size_t h, const size_t w, const size_t start_offset, + const size_t len) { + using transformer_engine::Vec; + + static_assert(sizeof(OType) == 1); + constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType); + constexpr int kThreadsPerWarp = 32; + constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + constexpr int kRowsPerWarp = kTileDim / kNumWarps; + + __shared__ OType smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; + + const int tile_w = blockIdx.x; + const int tile_h = blockIdx.y; + const size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + OType *output_minus_offset = output - start_offset; + + const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + + // Load input data into shared memory + bool skip_store = true; + for (int i = 0; i < kRowsPerWarp; ++i) { + for (int j = 0; j < kLoopsPerRow; ++j) { + const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j; + const int h_in_input = tile_h * kTileDim + h_in_smem; + const int w_in_input = tile_w * kTileDim + w_in_smem; + const size_t idx_in_input = static_cast(h_in_input) * w + w_in_input; + if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && + idx_in_input < end_offset) { + float inp = static_cast(input_minus_offset[idx_in_input]) * scale; + smem[h_in_smem][w_in_smem] = static_cast(inp); + skip_store = false; + } + } + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { + bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); + skip_store = skip_store && other_skip_store; + } + skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0); + if (skip_store) { + return; + } + + // Store the casted data into the output. 
+ // Note that this store operation might write "out-of-bounds", but it is intentional: + // 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region + // from start_offset to end_offset), not the boundary of the entire output memory. Therefore, + // this out-of-bounds write will not cause illegal memory access. + // 2. We assume that the subsequent all-gather operation happens in-place, so any parts that + // should not be updated here will be overwritten by the all-gather. + // This tricky approach allows us to avoid checking whether each output index falls within + // [start, end), resulting in a significant performance improvement. + Vec vec_output; + for (int i = 0; i < kRowsPerWarp; ++i) { + const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank; + for (int j = 0; j < kNumOutputElemsPerBank; ++j) { + vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j]; + } + const int row_in_output = tile_h * kTileDim + row_in_smem; + const int col_in_output = tile_w * kTileDim + col_in_smem; + const size_t idx_in_output = static_cast(row_in_output) * w + col_in_output; + if (row_in_output < h) { + if constexpr (kWidthAligned) { + vec_output.store_to(output_minus_offset + idx_in_output); + } else { + int num = min(static_cast(kNumOutputElemsPerBank), + static_cast(col_in_output < w ? w - col_in_output : 0)); + vec_output.store_to_elts(output_minus_offset, idx_in_output, num); + } + } + } +} + +void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, + cudaStream_t stream) { + NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported"); + + size_t len = inp.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + inp.dtype(), inp_dtype, + nvfp4_2d_compute_partial_amax_kernel + <<>>(reinterpret_cast(inp.data.dptr), + reinterpret_cast(amax.data.dptr), + amax_stride_h, amax_stride_w, h, w, start_offset, + len);) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h, + size_t w, size_t scale_stride_h, size_t scale_stride_w, + size_t start_offset, size_t block_len, const DType out_dtype, + cudaStream_t stream) { + NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported"); + + size_t len = inp.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + inp.dtype(), inp_dtype, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + out_dtype, fp8_type, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + w % kTileDim == 0, kWidthAligned, + nvfp4_2d_partial_cast_kernel + <<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), + reinterpret_cast(scale.data.dptr), scale_stride_h, 
scale_stride_w, + h, w, start_offset, len);))) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace nvfp4_recipe } // namespace transformer_engine +void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, + size_t w, size_t amax_stride_h, + size_t amax_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_2d_compute_partial_amax); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_2d_compute_partial_amax( + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(amax), h, w, amax_stride_h, + amax_stride_w, start_offset, block_len, stream); +} + +void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, + const NVTETensor scale, size_t h, size_t w, + size_t scale_stride_h, size_t scale_stride_w, + size_t start_offset, size_t block_len, + const NVTEDType out_dtype, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_2d_partial_cast); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_2d_partial_cast( + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), h, + w, scale_stride_h, scale_stride_w, start_offset, block_len, static_cast(out_dtype), + stream); +} + void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, @@ -52,3 +280,7 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); NVTE_CHECK_CUDA(cudaGetLastError()); } + + + + diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp new file mode 100755 index 0000000000..71d15d95f2 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp @@ -0,0 +1,49 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, size_t w, + size_t start_offset, size_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); + TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor"); + TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float || + tensor.scalar_type() == at::ScalarType::BFloat16, + "tensor must be a float or bfloat16 tensor"); + + const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor.contiguous()); + TensorWrapper amax_cu = makeTransformerEngineTensor(amax); + + nvte_nvfp4_2d_compute_partial_amax( + tensor_cu.data(), amax_cu.data(), h, w, amax.stride(0), amax.stride(1), start_offset, + block_len, at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, + size_t h, size_t w, size_t start_offset, size_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); + TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); + TORCH_CHECK(inp.scalar_type() == at::ScalarType::Float || + inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); + + const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); + const TensorWrapper out_cu = makeTransformerEngineTensor(out, py::none()); + const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); + + nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), h, w, scale.stride(0), + scale.stride(1), start_offset, block_len, + at::cuda::getCurrentCUDAStream()); +} + +} // namespace transformer_engine::pytorch + + diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp old mode 100644 new mode 100755 index 3b81393dbd..3b3cd1bf61 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -276,6 +276,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"), py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"), py::arg("out_dtype"), py::call_guard()); + // NVFP4 2D + m.def("nvfp4_2d_compute_partial_amax", + &transformer_engine::pytorch::nvfp4_2d_compute_partial_amax, + "Compute partial amax from master weights for NVFP4 2D", py::arg("tensor"), py::arg("amax"), + py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len") = 16, + py::call_guard()); + m.def("nvfp4_2d_partial_cast", &transformer_engine::pytorch::nvfp4_2d_partial_cast, + "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"), + py::arg("global_amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), + py::arg("block_len") = 16, py::call_guard()); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp 
b/transformer_engine/pytorch/csrc/extensions/recipe.cpp old mode 100644 new mode 100755 index 8d1d865604..278cfa205a --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -67,3 +67,5 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio } } // namespace transformer_engine::pytorch + +namespace transformer_engine::pytorch {} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py old mode 100644 new mode 100755 index 89ca974eac..3e9089c240 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -17,6 +17,7 @@ from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier from ..utils import is_non_tn_fp8_gemm_supported +from ..constants import NVFP4_BLOCK_SCALING_SIZE def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -437,80 +438,93 @@ def _cast_master_weights_to_fp8_blockwise_scaling( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype ) - -def cast_master_weights_to_nvfp4( - model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None +# revisit this later +def _cast_master_weights_to_nvfp4_2d( + params, group, use_fsdp_shard_model_weights=False ): - r"""Helper function to cast master weights to NVFP4 primary weights using partial shards. + r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling. - Mirrors FP8 partial-cast: compute local amax per shard, all-reduce to global amax, - then cast only the shard slice into NVFP4 storages. Requires backend NVFP4 partial ops. + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. """ - if fsdp_shard_model_weights is None: - use_fsdp_shard_model_weights = False - fsdp_shard_model_weights = [None] * len(model_weights) - else: - use_fsdp_shard_model_weights = True + device = params[0][0].device + block_len = NVFP4_BLOCK_SCALING_SIZE - if len(model_weights) == 0: - return + dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device) - device = model_weights[0].device - packed_amaxes = torch.zeros(len(model_weights), dtype=torch.float32, device=device) + cu_amax_sizes = [0] + scale_shapes: List[tuple[int, int]] = [] + for model_weight, _, _, _ in params: + quantizer = model_weight._get_quantizer() + assert isinstance(quantizer, NVFP4Quantizer) + assert quantizer.with_2d_quantization, "NVFP4 2D quantization must be enabled." 
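+        # One decode scale is produced per 16x16 tile of this weight; it is later
+        # packed as an E4M3 byte into `_rowwise_scale_inv`. The quantizer may pad
+        # the scale shape, so it is taken from get_scale_shape() rather than
+        # computed as ceil(h/16) x ceil(w/16) here.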
+ scale_shape = quantizer.get_scale_shape(model_weight.shape, columnwise=False) + scale_shapes.append(scale_shape) + num_amaxes = scale_shape[0] * scale_shape[1] + cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes) - # Step 1: local amax per shard - for i, (model_weight, master_weight, start_offset, _) in enumerate( - zip(model_weights, master_weights, start_offsets, fsdp_shard_model_weights) - ): - if not isinstance(model_weight, NVFP4Tensor): - continue - if master_weight is None: - continue - h, w = model_weight.shape - try: - tex.nvfp4_compute_partial_amax(master_weight, packed_amaxes[i : i + 1], h, w, start_offset) - except AttributeError as e: - raise NotImplementedError( - "Missing NVFP4 partial amax kernel: tex.nvfp4_compute_partial_amax" - ) from e + packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) - # Step 2: all-reduce to global amax - torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + amaxes: List[torch.Tensor] = [] + scales: List[torch.Tensor] = [] + scale_inv_tmp: List[torch.Tensor] = [] + fp8_scale_targets: List[torch.Tensor] = [] - # Step 3: partial cast per shard - for i, (model_weight, master_weight, start_offset, model_weight_fragment) in enumerate( - zip(model_weights, master_weights, start_offsets, fsdp_shard_model_weights) - ): - if not isinstance(model_weight, NVFP4Tensor): - continue - if master_weight is None: - continue + for i, (model_weight, master_weight, start_offset, _) in enumerate(params): + scale_shape = scale_shapes[i] + amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) + scale = torch.empty(scale_shape, dtype=torch.float32, device=device) + inv_tmp = torch.empty_like(scale) - quantizer = model_weight._get_quantizer() - if not isinstance(quantizer, NVFP4Quantizer): - raise ValueError( - f"cast_master_weights_to_nvfp4 expects NVFP4Quantizer, got {type(quantizer)}" + assert model_weight._rowwise_scale_inv is not None + + amaxes.append(amax) + scales.append(scale) + scale_inv_tmp.append(inv_tmp) + fp8_scale_targets.append(model_weight._rowwise_scale_inv) + + if master_weight is not None and master_weight.numel() > 0: + assert len(model_weight.shape) == 2 + h, w = model_weight.shape + tex.nvfp4_2d_compute_partial_amax( + master_weight, amax, h, w, start_offset, block_len ) - quantizer.set_usage(rowwise=True, columnwise=True) - if hasattr(model_weight, "_reset_caches"): - model_weight._reset_caches() - dst_tensor = model_weight if not use_fsdp_shard_model_weights else model_weight_fragment + if packed_amaxes.numel() > 0: + torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + if len(amaxes) > 0: + multi_tensor_applier( + multi_tensor_compute_scale_and_scale_inv, + dummy_overflow_buf, + [amaxes, scales, scale_inv_tmp], + 6.0, + False, + 0.0, + ) + + for inv_tmp, target in zip(scale_inv_tmp, fp8_scale_targets): + fp8_view = inv_tmp.to(dtype=torch.float8_e4m3fn).view(torch.uint8) + target.copy_(fp8_view) + + for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( + params, scales + ): + if master_weight is None or master_weight.numel() == 0: + continue + + end_offset = start_offset + master_weight.numel() + if not use_fsdp_shard_model_weights: + model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] + assert len(model_weight.shape) == 2 h, w = model_weight.shape - try: - tex.nvfp4_partial_cast( - master_weight, - dst_tensor, - packed_amaxes[i : i + 
1], - h, - w, - start_offset, - ) - except AttributeError as e: - raise NotImplementedError( - "Missing NVFP4 partial cast kernel: tex.nvfp4_partial_cast" - ) from e + tex.nvfp4_2d_partial_cast(master_weight, model_weight_fragment, scale, h, w, start_offset, block_len) def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): """ From c2acb0694383c3ab4f7e42da41f7b99c4dbf941d Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 20 Nov 2025 22:22:47 +0000 Subject: [PATCH 03/40] add unit tests --- .../run_cast_master_weights_to_fp8.py | 67 +++++- .../test_cast_master_weights_to_fp8.py | 11 +- transformer_engine/common/recipe/nvfp4.cu | 222 +++++++++++++----- transformer_engine/pytorch/csrc/extensions.h | 6 + transformer_engine/pytorch/tensor/__init__.py | 9 +- transformer_engine/pytorch/tensor/utils.py | 65 ++++- 6 files changed, 294 insertions(+), 86 deletions(-) mode change 100755 => 100644 transformer_engine/common/recipe/nvfp4.cu mode change 100755 => 100644 transformer_engine/pytorch/tensor/utils.py diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py index 2f11a24ee8..3c44718078 100644 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -18,6 +18,7 @@ Float8CurrentScaling, Float8BlockScaling, Format, + NVFP4BlockScaling, Recipe, ) import transformer_engine.pytorch as te @@ -25,8 +26,12 @@ QuantizedTensor, Float8Tensor, Float8BlockwiseQTensor, + NVFP4Tensor, +) +from transformer_engine.pytorch.tensor import ( + cast_master_weights_to_fp8, + cast_master_weights_to_nvfp4, ) -from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data @@ -44,6 +49,12 @@ def _get_raw_data(quantized_tensor): quantized_tensor._rowwise_data.dtype == torch.uint8 ), "Float8BlockwiseQTensor _rowwise_data must be uint8" return quantized_tensor._rowwise_data + elif isinstance(quantized_tensor, NVFP4Tensor): + assert hasattr(quantized_tensor, "_rowwise_data"), "NVFP4Tensor missing _rowwise_data" + assert ( + quantized_tensor._rowwise_data.dtype == torch.uint8 + ), "NVFP4Tensor _rowwise_data must be uint8" + return quantized_tensor._rowwise_data else: raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") @@ -58,6 +69,7 @@ def __init__(self, weights, lr, dp_group): self.weights = weights self.lr = lr self.dp_group = dp_group + self.quantized_format = self._detect_quantized_format() # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer self.offsets = [0] @@ -121,6 +133,15 @@ def __init__(self, weights, lr, dp_group): [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device ) self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] + self.quantized_format = self._detect_quantized_format() + + def _detect_quantized_format(self): + for weight in self.weights: + if isinstance(weight, NVFP4Tensor): + return "nvfp4" + if isinstance(weight, (Float8Tensor, Float8BlockwiseQTensor)): + return "fp8" + return None def step(self): # ----------------------------------------------------------------------------------------- @@ -160,12 +181,16 @@ def step(self): # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight # ----------------------------------------------------------------------------------------- if 
isinstance(self.weights[0], QuantizedTensor): - # FP8 weights case - for i in range(1, len(self.weights)): - assert isinstance(self.weights[i], QuantizedTensor) - cast_master_weights_to_fp8( - self.weights, self.master_weights, self.start_offsets, self.dp_group - ) + for weight in self.weights: + assert isinstance(weight, QuantizedTensor) + if self.quantized_format == "nvfp4": + cast_master_weights_to_nvfp4( + self.weights, self.master_weights, self.start_offsets, self.dp_group + ) + else: + cast_master_weights_to_fp8( + self.weights, self.master_weights, self.start_offsets, self.dp_group + ) else: # BF16 weights case for weight, master_weight, start_offset in zip( @@ -352,6 +377,14 @@ def zero_grad(self): weight.grad = None weight.main_grad.zero_() + def _detect_quantized_format(self): + for weight in self.weights: + if isinstance(weight, NVFP4Tensor): + return "nvfp4" + if isinstance(weight, (Float8Tensor, Float8BlockwiseQTensor)): + return "fp8" + return None + def step(self): """ Perform an optimization step for the distributed sharded model. @@ -394,10 +427,14 @@ def step(self): if local_weight is None: local_weights.append(None) continue - local_weights.append(local_weight) - cast_master_weights_to_fp8( + cast_fn = ( + cast_master_weights_to_nvfp4 + if self.quantized_format == "nvfp4" + else cast_master_weights_to_fp8 + ) + cast_fn( self.weights, self.master_weights, [idx[0] for idx in self.weight_indices], @@ -562,6 +599,8 @@ def quantization_recipe(quantization) -> Recipe: return Float8CurrentScaling(fp8_format=fp8_format) elif quantization == "fp8_block": return Float8BlockScaling(fp8_format=fp8_format) + elif quantization == "nvfp4": + return NVFP4BlockScaling() else: raise ValueError(f"Unsupported quantization: {quantization}") @@ -672,11 +711,19 @@ def main(argv=None, namespace=None): parser = argparse.ArgumentParser() parser.add_argument( - "--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"] + "--quantization", + type=str, + default=None, + choices=["fp8", "fp8_cs", "fp8_block", "nvfp4"], ) args = parser.parse_args(argv, namespace) dp_group = dist.new_group(backend="nccl") + if args.quantization == "nvfp4": + nvfp4_available, reason = te.is_nvfp4_available(return_reason=True) + if not nvfp4_available: + raise RuntimeError(f"NVFP4 not available: {reason}") + _test_mini_optimizer(dp_group) _test_cast_master_weights_to_fp8(args.quantization, dp_group) _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 5bf46b8d5f..7767859f63 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -8,7 +8,11 @@ import pytest import torch -from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available +from transformer_engine.pytorch import ( + is_fp8_available, + is_fp8_block_scaling_available, + is_nvfp4_available, +) if torch.cuda.device_count() < 2: @@ -18,6 +22,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( return_reason=True ) +nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(2, torch.cuda.device_count()) @@ -31,10 +36,12 @@ def _run_test(quantization): assert result.returncode == 0 -@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", 
"fp8_block"]) +@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block", "nvfp4"]) def test_cast_master_weights_to_fp8(quantization): if quantization in ("fp8", "fp8_cs") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "fp8_block" and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) _run_test(quantization) diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu old mode 100755 new mode 100644 index 9a94d84f62..aefd9e3553 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -11,10 +11,63 @@ #include "../common.h" #include "../utils.cuh" +#include "../util/ptx.cuh" namespace transformer_engine { namespace nvfp4_recipe { +/* + * --------------------------------------------------------------------------- + * NVFP4 2D PARTIAL-SHARD KERNEL DESIGN + * + * These kernels mirror the FP8 block-scaling helpers but operate on shard-local + * slices and nibble-packed FP4 rowwise buffers. One CUDA block covers a logical + * 16x16 tile (grid = ceil(W/16) x ceil(H/16), blockDim = 256 threads). + * + * 1) Partial Amax (`nvfp4_2d_compute_partial_amax_kernel`) + * - Warps sweep the tile using nested loops, accumulating local maxima only + * for elements in [start_offset, start_offset + len). + * - Shared memory reduces the 8 warp maxima; the block writes a float into + * `amax_ptr[tile_row * stride_h + tile_col * stride_w]`. + * + * Tile/warp mapping (each '#' = elements visited by that warp): + * + * +------------------+ + * |########..........| Warp 0 + * |########..........| Warp 1 + * | ... | + * |########..........| Warp 7 + * +------------------+ + * + * 2) Partial Cast (`nvfp4_2d_partial_cast_kernel`) + * - Stage the tile into shared memory (same pattern as FP8). + * - For each 4-value group, build float2 pairs and call + * `ptx::mul_cvt_fp32_to_fp4_4x`, producing packed FP4 nibbles. + * - Compute a shard-local byte index and update only the owned nibble(s) + * using read-modify-write: + * + * packed_bits = [mw3 | mw2 | mw1 | mw0] + * byte_idx = (ref_elem_idx - start_offset) >> 1 + * if elem_idx % 2 == 0: // low nibble + * byte = (byte & 0xF0) | nibble + * else: // high nibble + * byte = (byte & 0x0F) | (nibble << 4) + * + * Thread coverage inside a tile: + * + * rows: 16 columns: 16 + * Warp 0 -> rows 0-1 lanes sweep cols 0..3, 4..7, ... + * Warp 1 -> rows 2-3 (groups of 4 elements per thread) + * ... + * Warp 7 -> rows 14-15 + * + * The host helper `_cast_master_weights_to_nvfp4_2d` reduces per-tile amax + * values, packs the resulting FP32 scales into the uint8 `_rowwise_scale_inv`, + * and launches `tex.nvfp4_2d_partial_cast`. The resulting bytes match TE’s full + * NVFP4 quantizer, so downstream GEMMs/checkpoints remain unchanged. 
+ * --------------------------------------------------------------------------- + */ + // constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); constexpr int kTileDim = 16; @@ -34,9 +87,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const size_t h, const size_t w, const size_t start_offset, const size_t len) { constexpr int kThreadsPerWarp = 32; - constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp; + constexpr int kLoopsPerRow = (kTileDim + kThreadsPerWarp - 1) / kThreadsPerWarp; constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; - constexpr int kLoopsPerCol = kTileDim / kNumWarps; + constexpr int kLoopsPerCol = (kTileDim + kNumWarps - 1) / kNumWarps; const int tile_col = blockIdx.x; const int tile_row = blockIdx.y; @@ -84,94 +137,136 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } } -template +template __global__ void __launch_bounds__(kThreadsPerBlock) - nvfp4_2d_partial_cast_kernel(const IType *input, OType *output, const float *scale_ptr, - const size_t scale_stride_h, const size_t scale_stride_w, - const size_t h, const size_t w, const size_t start_offset, - const size_t len) { - using transformer_engine::Vec; - - static_assert(sizeof(OType) == 1); - constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType); + nvfp4_2d_partial_cast_kernel(const IType *input, uint8_t *output, const float *scale_ptr, + const size_t scale_stride_h, const size_t scale_stride_w, + const size_t h, const size_t w, const size_t start_offset, + const size_t len) { + constexpr int kNumOutputElemsPerBank = 4; constexpr int kThreadsPerWarp = 32; - constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp; + constexpr int kLoopsPerRow = (kTileDim + kThreadsPerWarp - 1) / kThreadsPerWarp; constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; - constexpr int kRowsPerWarp = kTileDim / kNumWarps; + constexpr int kRowsPerWarp = (kTileDim + kNumWarps - 1) / kNumWarps; - __shared__ OType smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; + __shared__ float smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; const int tile_w = blockIdx.x; const int tile_h = blockIdx.y; - const size_t end_offset = start_offset + len; + const size_t shard_end = start_offset + len; const IType *input_minus_offset = input - start_offset; - OType *output_minus_offset = output - start_offset; - const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + const float tile_scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + const float2 scale_vec = make_float2(tile_scale, tile_scale); - // Load input data into shared memory bool skip_store = true; for (int i = 0; i < kRowsPerWarp; ++i) { for (int j = 0; j < kLoopsPerRow; ++j) { const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j; + if (h_in_smem >= kTileDim || w_in_smem >= kTileDim) { + continue; + } const int h_in_input = tile_h * kTileDim + h_in_smem; const int w_in_input = tile_w * kTileDim + w_in_smem; const size_t idx_in_input = static_cast(h_in_input) * w + w_in_input; if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && - idx_in_input < end_offset) { - float inp = static_cast(input_minus_offset[idx_in_input]) * scale; - smem[h_in_smem][w_in_smem] = static_cast(inp); + idx_in_input < shard_end) { + smem[h_in_smem][w_in_smem] = static_cast(input_minus_offset[idx_in_input]); skip_store = false; } } } for (int delta = kThreadsPerWarp / 2; 
delta > 0; delta /= 2) { - bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); - skip_store = skip_store && other_skip_store; + bool other = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); + skip_store = skip_store && other; } skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0); if (skip_store) { return; } - // Store the casted data into the output. - // Note that this store operation might write "out-of-bounds", but it is intentional: - // 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region - // from start_offset to end_offset), not the boundary of the entire output memory. Therefore, - // this out-of-bounds write will not cause illegal memory access. - // 2. We assume that the subsequent all-gather operation happens in-place, so any parts that - // should not be updated here will be overwritten by the all-gather. - // This tricky approach allows us to avoid checking whether each output index falls within - // [start, end), resulting in a significant performance improvement. - Vec vec_output; for (int i = 0; i < kRowsPerWarp; ++i) { const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int row_in_output = tile_h * kTileDim + row_in_smem; + if (row_in_output >= h) { + continue; + } const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank; - for (int j = 0; j < kNumOutputElemsPerBank; ++j) { - vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j]; + if (col_in_smem >= kTileDim) { + continue; } - const int row_in_output = tile_h * kTileDim + row_in_smem; const int col_in_output = tile_w * kTileDim + col_in_smem; - const size_t idx_in_output = static_cast(row_in_output) * w + col_in_output; - if (row_in_output < h) { - if constexpr (kWidthAligned) { - vec_output.store_to(output_minus_offset + idx_in_output); - } else { - int num = min(static_cast(kNumOutputElemsPerBank), - static_cast(col_in_output < w ? w - col_in_output : 0)); - vec_output.store_to_elts(output_minus_offset, idx_in_output, num); + + float vals[kNumOutputElemsPerBank]; + bool mask[kNumOutputElemsPerBank]; + size_t elem_idx[kNumOutputElemsPerBank]; + bool any_valid = false; + + for (int j = 0; j < kNumOutputElemsPerBank; ++j) { + const int col = col_in_output + j; + const bool in_width = col < w; + const size_t idx = static_cast(row_in_output) * w + col; + elem_idx[j] = idx; + const bool in_shard = in_width && idx >= start_offset && idx < shard_end; + mask[j] = in_shard; + const bool in_tile = (col_in_smem + j) < kTileDim; + const float tile_val = + in_tile ? smem[row_in_smem][col_in_smem + j] : 0.0f; + vals[j] = in_shard ? tile_val : 0.0f; + any_valid |= in_shard; + } + + if (!any_valid) { + continue; + } + + const float2 in01 = make_float2(vals[0], vals[1]); + const float2 in23 = make_float2(vals[2], vals[3]); + const auto packed = + transformer_engine::ptx::mul_cvt_fp32_to_fp4_4x(in01, in23, scale_vec, 0); + const uint16_t packed_bits = reinterpret_cast(packed); + + for (int pair = 0; pair < 2; ++pair) { + const int first = pair * 2; + const int second = first + 1; + if (!mask[first] && !mask[second]) { + continue; + } + const size_t ref_idx = mask[first] ? 
elem_idx[first] : elem_idx[second]; + const size_t byte_idx = (ref_idx - start_offset) >> 1; + uint8_t byte = output[byte_idx]; + + if (mask[first]) { + const uint8_t nibble = + static_cast((packed_bits >> (4 * first)) & 0xF); + if ((elem_idx[first] & 1u) == 0) { + byte = static_cast((byte & 0xF0u) | nibble); + } else { + byte = static_cast((byte & 0x0Fu) | (nibble << 4)); + } } + + if (mask[second]) { + const uint8_t nibble = + static_cast((packed_bits >> (4 * second)) & 0xF); + if ((elem_idx[second] & 1u) == 0) { + byte = static_cast((byte & 0xF0u) | nibble); + } else { + byte = static_cast((byte & 0x0Fu) | (nibble << 4)); + } + } + + output[byte_idx] = byte; } } } void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w, - size_t amax_stride_h, size_t amax_stride_w, - size_t start_offset, size_t block_len, - cudaStream_t stream) { - NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported"); + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, cudaStream_t stream) { + NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); size_t len = inp.numel(); @@ -196,10 +291,11 @@ void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size } void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h, - size_t w, size_t scale_stride_h, size_t scale_stride_w, - size_t start_offset, size_t block_len, const DType out_dtype, - cudaStream_t stream) { - NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported"); + size_t w, size_t scale_stride_h, size_t scale_stride_w, + size_t start_offset, size_t block_len, const DType out_dtype, + cudaStream_t stream) { + NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); + NVTE_CHECK(out.dtype() == DType::kByte, "NVFP4 rowwise data must be uint8."); size_t len = inp.numel(); @@ -215,16 +311,14 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, siz TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( inp.dtype(), inp_dtype, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - out_dtype, fp8_type, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - w % kTileDim == 0, kWidthAligned, - nvfp4_2d_partial_cast_kernel + TRANSFORMER_ENGINE_SWITCH_CONDITION( + w % kTileDim == 0, kWidthAligned, + nvfp4_2d_partial_cast_kernel <<>>( reinterpret_cast(inp.data.dptr), - reinterpret_cast(out.data.dptr), + reinterpret_cast(out.data.dptr), reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, - h, w, start_offset, len);))) + h, w, start_offset, len);)) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -243,16 +337,14 @@ void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, s } void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, - const NVTETensor scale, size_t h, size_t w, - size_t scale_stride_h, size_t scale_stride_w, - size_t start_offset, size_t block_len, - const NVTEDType out_dtype, cudaStream_t stream) { + const NVTETensor scale, size_t h, size_t w, + size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream) { NVTE_API_CALL(nvte_nvfp4_2d_partial_cast); using namespace transformer_engine; nvfp4_recipe::nvfp4_2d_partial_cast( *convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), h, - w, scale_stride_h, scale_stride_w, start_offset, block_len, static_cast(out_dtype), - stream); + w, scale_stride_h, scale_stride_w, start_offset, block_len, 
stream); } void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 79fb798422..3576213a4b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -335,6 +335,12 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const size_t h, size_t w, size_t start_offset, size_t block_len, const DType out_dtype); +void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, size_t w, + size_t start_offset, size_t block_len); + +void nvfp4_2d_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, + size_t h, size_t w, size_t start_offset, size_t block_len); + /*************************************************************************************************** * Rotary positional embedding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index ada624a902..68863d0290 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -21,7 +21,11 @@ from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer -from .utils import cast_master_weights_to_fp8, replace_raw_data +from .utils import ( + cast_master_weights_to_fp8, + cast_master_weights_to_nvfp4, + replace_raw_data, +) __all__ = [ "Quantizer", @@ -42,6 +46,9 @@ "NVFP4Tensor", "prepare_for_saving", "restore_from_saved", + "cast_master_weights_to_fp8", + "cast_master_weights_to_nvfp4", + "replace_raw_data", ] diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py old mode 100755 new mode 100644 index 3e9089c240..2453f96d86 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. 
-"""Helper functions for using fp8 tensors as weights""" +"""Helper functions for using fp8/nvfp4 tensors as weights""" from typing import Optional, Union, List import torch @@ -43,6 +43,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): new_raw_data.detach().copy_(old_raw_data) tensor._rowwise_data = new_raw_data del old_raw_data + elif isinstance(tensor, NVFP4Tensor): + old_rowwise = tensor._rowwise_data + assert old_rowwise.dtype == new_raw_data.dtype, "The data types of raw data don't match" + new_rowwise_data.detach().copy_(old_rowwise) + tensor._rowwise_data = new_rowwise_data + del old_rowwise elif isinstance(tensor, MXFP8Tensor): raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet") else: @@ -145,6 +151,44 @@ def cast_master_weights_to_fp8( ) +def cast_master_weights_to_nvfp4( + model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None +): + """Helper to cast master weights to NVFP4 primary weights.""" + + nvfp4_params = [] + + if fsdp_shard_model_weights is None: + use_fsdp_shard_model_weights = False + fsdp_shard_model_weights = [None] * len(model_weights) + else: + use_fsdp_shard_model_weights = True + + for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( + model_weights, master_weights, start_offsets, fsdp_shard_model_weights + ): + if hasattr(model_weight, "clear_high_precision_init_val"): + model_weight.clear_high_precision_init_val() + + if master_weight is not None: + master_weight = master_weight.to(model_weight.dtype) + + quantizer = model_weight._get_quantizer() + if isinstance(quantizer, NVFP4Quantizer): + nvfp4_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + else: + raise ValueError( + f"cast_master_weights_to_nvfp4 only supports NVFP4 tensors, got {type(model_weight)}" + ) + + if len(nvfp4_params) > 0: + _cast_master_weights_to_nvfp4_2d( + nvfp4_params, group, use_fsdp_shard_model_weights=use_fsdp_shard_model_weights + ) + + def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False): r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. @@ -501,12 +545,12 @@ def _cast_master_weights_to_nvfp4_2d( if len(amaxes) > 0: multi_tensor_applier( - multi_tensor_compute_scale_and_scale_inv, - dummy_overflow_buf, - [amaxes, scales, scale_inv_tmp], - 6.0, - False, - 0.0, + multi_tensor_compute_scale_and_scale_inv, + dummy_overflow_buf, + [amaxes, scales, scale_inv_tmp], + 6.0, + False, + 0.0, ) for inv_tmp, target in zip(scale_inv_tmp, fp8_scale_targets): @@ -524,7 +568,9 @@ def _cast_master_weights_to_nvfp4_2d( model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] assert len(model_weight.shape) == 2 h, w = model_weight.shape - tex.nvfp4_2d_partial_cast(master_weight, model_weight_fragment, scale, h, w, start_offset, block_len) + tex.nvfp4_2d_partial_cast( + master_weight, model_weight_fragment, scale, h, w, start_offset, block_len + ) def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): """ @@ -544,6 +590,9 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten elif isinstance(model_weight, Float8BlockwiseQTensor): # Blockwise scaling: create column-wise storage. 
model_weight._create_columnwise() + elif isinstance(model_weight, NVFP4Tensor): + # TODO: Add create_columnwise for NVFP4Tensor + pass elif isinstance(model_weight, QuantizedTensor): raise ValueError(f"post_processing for {type(model_weight)} is not supported") From d6bd790d3b0eb9d301c044682bfc7c40fea74d64 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Fri, 21 Nov 2025 19:51:16 +0000 Subject: [PATCH 04/40] add global amax --- .../include/transformer_engine/recipe.h | 5 +- transformer_engine/common/recipe/nvfp4.cu | 41 ++++++++----- .../csrc/extensions/nvfp4_2d_partial_cast.cpp | 12 +++- .../pytorch/csrc/extensions/pybind.cpp | 4 +- transformer_engine/pytorch/tensor/utils.py | 58 +++++++++++++++---- 5 files changed, 87 insertions(+), 33 deletions(-) mode change 100755 => 100644 transformer_engine/common/include/transformer_engine/recipe.h mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/pybind.cpp diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h old mode 100755 new mode 100644 index 0a3b0dd41d..1fe9c81a1a --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -133,8 +133,9 @@ void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, s cudaStream_t stream); void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, - size_t h, size_t w, size_t scale_stride_h, size_t scale_stride_w, - size_t start_offset, size_t block_len, cudaStream_t stream); + const NVTETensor global_scale, size_t h, size_t w, + size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index aefd9e3553..6dea833aaf 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -139,10 +139,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) template __global__ void __launch_bounds__(kThreadsPerBlock) - nvfp4_2d_partial_cast_kernel(const IType *input, uint8_t *output, const float *scale_ptr, + nvfp4_2d_partial_cast_kernel(const IType *input, uint8_t *output, const float *decode_scale_ptr, const size_t scale_stride_h, const size_t scale_stride_w, - const size_t h, const size_t w, const size_t start_offset, - const size_t len) { + const float *global_scale_ptr, const size_t h, const size_t w, + const size_t start_offset, const size_t len) { constexpr int kNumOutputElemsPerBank = 4; constexpr int kThreadsPerWarp = 32; constexpr int kLoopsPerRow = (kTileDim + kThreadsPerWarp - 1) / kThreadsPerWarp; @@ -156,8 +156,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const size_t shard_end = start_offset + len; const IType *input_minus_offset = input - start_offset; - const float tile_scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; - const float2 scale_vec = make_float2(tile_scale, tile_scale); + float global_encode_scale = global_scale_ptr[0]; + if (global_encode_scale <= 0.f) { + global_encode_scale = 1.f; + } + const float global_decode_scale = 1.0f / global_encode_scale; + + const float tile_decode_scale = + decode_scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + float tile_encode_val = (tile_decode_scale > 0.f) + ? 
1.0f / (tile_decode_scale * global_decode_scale) + : TypeExtrema::max; + tile_encode_val = fminf(tile_encode_val, TypeExtrema::max); + const float2 scale_vec = make_float2(tile_encode_val, tile_encode_val); bool skip_store = true; for (int i = 0; i < kRowsPerWarp; ++i) { @@ -290,10 +301,10 @@ void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size NVTE_CHECK_CUDA(cudaGetLastError()); } -void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h, - size_t w, size_t scale_stride_h, size_t scale_stride_w, - size_t start_offset, size_t block_len, const DType out_dtype, - cudaStream_t stream) { +void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, + const Tensor global_scale, size_t h, size_t w, size_t scale_stride_h, + size_t scale_stride_w, size_t start_offset, size_t block_len, + const DType out_dtype, cudaStream_t stream) { NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); NVTE_CHECK(out.dtype() == DType::kByte, "NVFP4 rowwise data must be uint8."); @@ -318,7 +329,8 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, siz reinterpret_cast(inp.data.dptr), reinterpret_cast(out.data.dptr), reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, - h, w, start_offset, len);)) + reinterpret_cast(global_scale.data.dptr), h, w, start_offset, + len);)) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -336,15 +348,16 @@ void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, s amax_stride_w, start_offset, block_len, stream); } -void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, - const NVTETensor scale, size_t h, size_t w, +void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, + const NVTETensor global_scale, size_t h, size_t w, size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, size_t block_len, cudaStream_t stream) { NVTE_API_CALL(nvte_nvfp4_2d_partial_cast); using namespace transformer_engine; nvfp4_recipe::nvfp4_2d_partial_cast( - *convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), h, - w, scale_stride_h, scale_stride_w, start_offset, block_len, stream); + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), + *convertNVTETensorCheck(global_scale), h, w, scale_stride_h, scale_stride_w, start_offset, + block_len, stream); } void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp old mode 100755 new mode 100644 index 71d15d95f2..f41a6eed29 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp @@ -27,10 +27,14 @@ void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, si } void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, - size_t h, size_t w, size_t start_offset, size_t block_len) { + const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, + size_t block_len) { TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); + 
TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); + TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, + "global_scale must be a float tensor"); TORCH_CHECK(inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, "input must be a float or bfloat16 tensor"); @@ -38,9 +42,11 @@ void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tens const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); const TensorWrapper out_cu = makeTransformerEngineTensor(out, py::none()); const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); + const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); - nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), h, w, scale.stride(0), - scale.stride(1), start_offset, block_len, + nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), + global_scale_cu.data(), h, w, scale.stride(0), scale.stride(1), + start_offset, block_len, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp old mode 100755 new mode 100644 index 3b3cd1bf61..7b1d079649 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -284,8 +284,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("nvfp4_2d_partial_cast", &transformer_engine::pytorch::nvfp4_2d_partial_cast, "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"), - py::arg("global_amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), - py::arg("block_len") = 16, py::call_guard()); + py::arg("scale"), py::arg("global_scale"), py::arg("h"), py::arg("w"), + py::arg("start_offset"), py::arg("block_len") = 16, py::call_guard()); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 2453f96d86..5e22cc62ae 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -519,12 +519,17 @@ def _cast_master_weights_to_nvfp4_2d( scales: List[torch.Tensor] = [] scale_inv_tmp: List[torch.Tensor] = [] fp8_scale_targets: List[torch.Tensor] = [] + global_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device) + global_amax_views: List[torch.Tensor] = [ + global_amaxes[i : i + 1] for i in range(len(params)) + ] for i, (model_weight, master_weight, start_offset, _) in enumerate(params): scale_shape = scale_shapes[i] amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) scale = torch.empty(scale_shape, dtype=torch.float32, device=device) inv_tmp = torch.empty_like(scale) + global_amax_view = global_amax_views[i] assert model_weight._rowwise_scale_inv is not None @@ -539,27 +544,49 @@ def _cast_master_weights_to_nvfp4_2d( tex.nvfp4_2d_compute_partial_amax( master_weight, amax, h, w, start_offset, block_len ) + tex.compute_amax(master_weight, global_amax_view) if packed_amaxes.numel() > 0: torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + if global_amaxes.numel() > 0: + torch.distributed.all_reduce(global_amaxes, op=torch.distributed.ReduceOp.MAX, 
group=group) + + global_scale_tensor = global_amaxes.clone() if len(amaxes) > 0: - multi_tensor_applier( - multi_tensor_compute_scale_and_scale_inv, - dummy_overflow_buf, - [amaxes, scales, scale_inv_tmp], - 6.0, - False, - 0.0, + finfo = torch.finfo(torch.float32) + fp4_max = 6.0 + fp8_max = 448.0 + tiny = finfo.tiny + + safe_global_amax = torch.clamp(global_amaxes, min=tiny) + global_encode_scales = torch.clamp((fp8_max * fp4_max) / safe_global_amax, max=finfo.max) + global_encode_scales = torch.where( + global_amaxes > 0, global_encode_scales, torch.ones_like(global_encode_scales) ) - + global_scale_tensor.copy_(global_encode_scales) + global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] + + for amax_tensor, scale_tensor, inv_tmp_tensor, global_scale in zip( + amaxes, scales, scale_inv_tmp, global_scale_views + ): + per_block_decode_scale = torch.clamp( + (amax_tensor / fp4_max) * global_scale, max=finfo.max + ) + scale_tensor.copy_(per_block_decode_scale) + inv_tmp_tensor.copy_(per_block_decode_scale) + # Update _rowwise_scale_inv with reduced per-block decode scales. for inv_tmp, target in zip(scale_inv_tmp, fp8_scale_targets): fp8_view = inv_tmp.to(dtype=torch.float8_e4m3fn).view(torch.uint8) target.copy_(fp8_view) + else: + global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( - params, scales - ): + for ( + (model_weight, master_weight, start_offset, model_weight_fragment), + per_block_decode_scale, + global_scale, + ) in zip(params, scales, global_scale_views): if master_weight is None or master_weight.numel() == 0: continue @@ -569,7 +596,14 @@ def _cast_master_weights_to_nvfp4_2d( assert len(model_weight.shape) == 2 h, w = model_weight.shape tex.nvfp4_2d_partial_cast( - master_weight, model_weight_fragment, scale, h, w, start_offset, block_len + master_weight, + model_weight_fragment, + per_block_decode_scale, + global_scale, + h, + w, + start_offset, + block_len, ) def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): From d607fd3735244e56d4ce742e815074e612d79831 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Fri, 21 Nov 2025 22:54:44 +0000 Subject: [PATCH 05/40] unit test v2 --- .../test_cast_master_weights_to_fp8.py | 168 +++++++++++++++++- 1 file changed, 161 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 0ff98e6cb7..c746eeb852 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -18,6 +18,7 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, Format, Recipe, ) @@ -25,11 +26,17 @@ from transformer_engine.pytorch import ( is_fp8_available, is_fp8_block_scaling_available, + is_nvfp4_available, QuantizedTensor, Float8Tensor, Float8BlockwiseQTensor, + NVFP4Tensor, ) -from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 +from transformer_engine.pytorch.tensor import ( + cast_master_weights_to_fp8, + cast_master_weights_to_nvfp4, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data @@ -60,6 +67,12 @@ def _get_raw_data(quantized_tensor): quantized_tensor._rowwise_data.dtype == torch.uint8 ), "Float8BlockwiseQTensor _rowwise_data 
must be uint8" return quantized_tensor._rowwise_data + elif isinstance(quantized_tensor, NVFP4Tensor): + assert hasattr(quantized_tensor, "_rowwise_data"), "NVFP4Tensor missing _rowwise_data" + assert ( + quantized_tensor._rowwise_data.dtype == torch.uint8 + ), "NVFP4Tensor _rowwise_data must be uint8" + return quantized_tensor._rowwise_data else: raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") @@ -207,10 +220,19 @@ def step(self): # ----------------------------------------------------------------------------------------- # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight # ----------------------------------------------------------------------------------------- - if isinstance(self.weights[0], QuantizedTensor): - # FP8 weights case - for i in range(1, len(self.weights)): - assert isinstance(self.weights[i], QuantizedTensor) + first_weight = self.weights[0] + if isinstance(first_weight, NVFP4Tensor): + for weight in self.weights: + assert isinstance(weight, NVFP4Tensor) + cast_master_weights_to_nvfp4( + self.weights, + self.master_weights, + self.start_offsets, + self.dp_group, + ) + elif isinstance(first_weight, QuantizedTensor): + for weight in self.weights: + assert isinstance(weight, QuantizedTensor) cast_master_weights_to_fp8( self.weights, self.master_weights, @@ -412,8 +434,23 @@ def step(self): # Update the master weight using gradient descent master_weight -= grad * self.lr - # Step 3: Cast master weights to FP8 or BF16 precision - if isinstance(self.weights[0], QuantizedTensor): + # Step 3: Cast master weights to quantized or BF16 precision + first_weight = self.weights[0] + if isinstance(first_weight, NVFP4Tensor): + local_weights = [] + for local_weight in self.local_weights: + if local_weight is None: + local_weights.append(None) + continue + local_weights.append(local_weight) + cast_master_weights_to_nvfp4( + self.weights, + self.master_weights, + [idx[0] for idx in self.weight_indices], + self.dp_group, + local_weights, + ) + elif isinstance(first_weight, QuantizedTensor): local_weights = [] for local_weight in self.local_weights: if local_weight is None: @@ -670,6 +707,84 @@ def _test_fsdp_cast_master_weights_to_fp8( torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) +def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processing): + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + nvfp4_recipe = NVFP4BlockScaling() + + with te.quantized_model_init( + enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True + ): + model_nvfp4 = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + model = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + high_precision_init_val = w_nvfp4.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + for w_nvfp4, w in zip(model_nvfp4.parameters(), 
model.parameters()): + w_nvfp4.main_grad = torch.zeros_like(w_nvfp4, dtype=torch.float32, device="cuda") + w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") + + optimizer_nvfp4 = MiniZero_1( + [w for w in model_nvfp4.parameters()], 10.0, dp_group, manual_post_all_gather_processing + ) + optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + w_nvfp4.main_grad.zero_() + w.main_grad.zero_() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + x = inputs[rank] + + with te.autocast( + enabled=True, + recipe=nvfp4_recipe, + amax_reduction_group=mock_group, + ): + y_nvfp4 = model_nvfp4(x) + + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + target = targets[rank] + loss_nvfp4 = nn.MSELoss()(y_nvfp4, target) + loss = nn.MSELoss()(y, target) + + loss_nvfp4.backward() + loss.backward() + + optimizer_nvfp4.step() + optimizer.step() + + torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) + + def run_parallel_tests() -> None: """Run parallel tests""" @@ -708,6 +823,11 @@ def run_parallel_tests() -> None: _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + nvfp4_available, _ = is_nvfp4_available(return_reason=True) + if nvfp4_available: + for post_ag_processing in manual_post_all_gather_processings: + _test_cast_master_weights_to_nvfp4(dp_group, post_ag_processing) + dist.destroy_process_group() @@ -741,5 +861,39 @@ def main() -> None: run_parallel_tests() +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="NVFP4 partial-cast test requires CUDA." 
+) +def test_nvfp4_partial_cast_matches_full() -> None: + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + torch.manual_seed(1234) + device = torch.device("cuda") + shape = (64, 64) + master_weight = torch.randn(shape, dtype=torch.float32, device=device) + + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) + nvfp4_tensor = quantizer(master_weight.to(torch.bfloat16)) + assert isinstance(nvfp4_tensor, NVFP4Tensor) + + reference_data = nvfp4_tensor._rowwise_data.detach().clone() + reference_scale = nvfp4_tensor._rowwise_scale_inv.detach().clone() + + nvfp4_tensor._rowwise_data.zero_() + nvfp4_tensor._rowwise_scale_inv.zero_() + + cast_master_weights_to_nvfp4( + [nvfp4_tensor], + [master_weight.clone()], + [0], + None, + ) + + torch.testing.assert_close(nvfp4_tensor._rowwise_data, reference_data) + torch.testing.assert_close(nvfp4_tensor._rowwise_scale_inv, reference_scale) + + if __name__ == "__main__": main() From b4a0084d6f0f86cfd86e6d4c41b16ac45cae3add Mon Sep 17 00:00:00 2001 From: qiyuw Date: Sat, 22 Nov 2025 01:48:32 +0000 Subject: [PATCH 06/40] fix build error --- transformer_engine/common/recipe/nvfp4.cu | 7 ++++--- transformer_engine/pytorch/csrc/extensions.h | 3 ++- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 0 3 files changed, 6 insertions(+), 4 deletions(-) mode change 100644 => 100755 transformer_engine/common/recipe/nvfp4.cu mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions.h mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/pybind.cpp diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu old mode 100644 new mode 100755 index 6dea833aaf..dfd7edfc9d --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -164,10 +164,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const float tile_decode_scale = decode_scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + constexpr float kFp32Max = 3.402823466e+38F; float tile_encode_val = (tile_decode_scale > 0.f) ? 
1.0f / (tile_decode_scale * global_decode_scale) - : TypeExtrema::max; - tile_encode_val = fminf(tile_encode_val, TypeExtrema::max); + : kFp32Max; + tile_encode_val = fminf(tile_encode_val, kFp32Max); const float2 scale_vec = make_float2(tile_encode_val, tile_encode_val); bool skip_store = true; @@ -304,7 +305,7 @@ void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, const Tensor global_scale, size_t h, size_t w, size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, size_t block_len, - const DType out_dtype, cudaStream_t stream) { + cudaStream_t stream) { NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); NVTE_CHECK(out.dtype() == DType::kByte, "NVFP4 rowwise data must be uint8."); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h old mode 100644 new mode 100755 index 94f1c10761..f51a881e90 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -339,7 +339,8 @@ void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, si size_t start_offset, size_t block_len); void nvfp4_2d_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, - size_t h, size_t w, size_t start_offset, size_t block_len); + const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, + size_t block_len); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp old mode 100644 new mode 100755 From 3776ac4f72a82ee55f2747213ad38eeed1657ed5 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Tue, 25 Nov 2025 05:18:50 +0000 Subject: [PATCH 07/40] fix some issues --- .../test_cast_master_weights_to_fp8.py | 12 ++--- transformer_engine/pytorch/csrc/extensions.h | 2 +- transformer_engine/pytorch/module/base.py | 44 +++---------------- 3 files changed, 13 insertions(+), 45 deletions(-) mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions.h diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index c746eeb852..abea5e7c9c 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -810,23 +810,23 @@ def run_parallel_tests() -> None: quantizations = [] if is_fp8_available(): + print("fp8 available") quantizations.extend(["fp8", "fp8_cs"]) if is_fp8_block_scaling_available(): quantizations.append("fp8_block") - manual_post_all_gather_processings = [False, True] - + print("starting mini optimizer test") _test_mini_optimizer(dp_group) - + print("starting cast master weights to fp8 test") for quantization in quantizations: for post_ag_processing in manual_post_all_gather_processings: _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) - + print("starting cast master weights to nvfp4 test") nvfp4_available, _ = is_nvfp4_available(return_reason=True) if nvfp4_available: - for post_ag_processing in manual_post_all_gather_processings: - _test_cast_master_weights_to_nvfp4(dp_group, post_ag_processing) + #for post_ag_processing in manual_post_all_gather_processings: 
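# Editor's note, for context on the NVFP4 test invoked below: a minimal sketch of the
# partial-cast flow this series implements, assuming the tex.* bindings it adds and an
# initialized process group. The helper name is ours, and edge cases (zero amaxes,
# persisting _rowwise_scale_inv / _amax_rowwise) are omitted for brevity.
import torch
import transformer_engine_torch as tex

def nvfp4_partial_cast_sketch(model_weight, master_shard, start_offset, group, block_len=16):
    # model_weight: NVFP4Tensor primary weight; master_shard: this rank's flat FP32 shard.
    h, w = model_weight.shape
    tiles = ((h + block_len - 1) // block_len, (w + block_len - 1) // block_len)
    tile_amax = torch.zeros(tiles, dtype=torch.float32, device=master_shard.device)
    global_amax = torch.zeros(1, dtype=torch.float32, device=master_shard.device)
    # Local statistics over this rank's shard only.
    tex.nvfp4_2d_compute_partial_amax(master_shard, tile_amax, h, w, start_offset, block_len)
    tex.compute_amax(master_shard, global_amax)
    # Reduce before deriving any scale so every rank casts with identical values.
    torch.distributed.all_reduce(tile_amax, op=torch.distributed.ReduceOp.MAX, group=group)
    torch.distributed.all_reduce(global_amax, op=torch.distributed.ReduceOp.MAX, group=group)
    tiny = torch.finfo(torch.float32).tiny
    global_encode_scale = (448.0 * 6.0) / global_amax.clamp_min(tiny)
    per_block_decode_scale = (tile_amax / 6.0) * global_encode_scale
    # Cast only the bytes backing this shard: two FP4 values share one uint8.
    rowwise_bytes = model_weight._rowwise_data.view(-1)
    end_offset = start_offset + master_shard.numel()
    fragment = rowwise_bytes[start_offset // 2 : (end_offset + 1) // 2]
    tex.nvfp4_2d_partial_cast(master_shard, fragment, per_block_decode_scale,
                              global_encode_scale, h, w, start_offset, block_len)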
+ _test_cast_master_weights_to_nvfp4(dp_group, False) dist.destroy_process_group() diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h old mode 100755 new mode 100644 index f51a881e90..b7d6927b02 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -338,7 +338,7 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, size_t w, size_t start_offset, size_t block_len); -void nvfp4_2d_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, +void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, size_t block_len); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5035de14c2..92de9ff4f0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -621,8 +621,6 @@ def __init__(self) -> None: self.sequence_parallel = False self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() - # NVFP4 reuses FP8 GlobalStateManager - self.primary_weights_in_nvfp4 = FP8GlobalStateManager.with_fp8_parameters() self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() self.fsdp_wrapped = False self.fsdp_group = None @@ -1266,7 +1264,11 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] if quantizer is None: raise RuntimeError("Weight quantizer has not been initialized") - quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + columnwise_usage = torch.is_grad_enabled() + if isinstance(quantizer, NVFP4Quantizer) and quantizer.with_2d_quantization: + # NVFP4 2D stores only rowwise data/scale + columnwise_usage = False + quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) quantizer.internal = False if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): device_mesh = dtensor_param.device_mesh @@ -1279,7 +1281,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: quantizer.with_amax_reduction = True # Quantize parameter param = quantizer(param) - + # Redo parameter wrap in case we broke it above # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already @@ -1330,40 +1332,6 @@ def clear(self): else: setattr(self, name, dtensor_param) - # NVFP4 primary weights path - if self.primary_weights_in_nvfp4 and fp8_meta_index is not None: - - high_precision_init_val = None - if self.preserve_high_precision_init_val: - high_precision_init_val = getattr(self, name).detach().cpu() - - quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] - if quantizer is None: - raise RuntimeError("Weight quantizer has not been initialized") - if not isinstance(quantizer, NVFP4Quantizer): - # Skip if this meta index is not NVFP4 (e.g., activations) - pass - else: - quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) - quantizer.internal = False - nvfp4_param = quantizer(getattr(self, name)) - nvfp4_param = torch.nn.Parameter(nvfp4_param) - if high_precision_init_val is not None: - def get(self_): - if hasattr(self_, "_high_precision_init_val"): - return 
self_._high_precision_init_val - return None - - def clear(self_): - if hasattr(self_, "_high_precision_init_val"): - del self_._high_precision_init_val - - nvfp4_param._high_precision_init_val = high_precision_init_val - nvfp4_param.get_high_precision_init_val = MethodType(get, nvfp4_param) - nvfp4_param.clear_high_precision_init_val = MethodType(clear, nvfp4_param) - - setattr(self, name, nvfp4_param) - @abstractmethod def forward(self): """Needs override.""" From 7787e0fe1fabb89759f1f03f438c146c6773687a Mon Sep 17 00:00:00 2001 From: qiyuw Date: Sun, 30 Nov 2025 23:54:28 +0000 Subject: [PATCH 08/40] use larger problem for cublas gemm,fix some issues --- .../test_cast_master_weights_to_fp8.py | 67 +++++++++++++------ .../common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/pytorch/module/base.py | 6 +- 3 files changed, 47 insertions(+), 28 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index abea5e7c9c..fabbca5f5b 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -124,7 +124,7 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.offsets = [0] for weight in self.weights: self.offsets.append(self.offsets[-1] + weight.numel()) - + print(f"offsets: {self.offsets}") # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may # not be the end range of the last weight. if self.offsets[-1] % self.world_size != 0: @@ -139,7 +139,7 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals # The start and end of this rank's local buffer in the global buffer rank_start = self.offsets[-1] // self.world_size * self.rank rank_end = rank_start + self.offsets[-1] // self.world_size - + print(f"current rank: {self.rank}, rank_start: {rank_start}, rank_end: {rank_end}") for weight, offset in zip(self.weights, self.offsets[:-1]): if offset >= rank_end or (offset + weight.numel()) <= rank_start: # This weight is not in this rank's local buffer @@ -265,6 +265,19 @@ def step(self): weight = self.weights[i] weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] overlapping_start, overlapping_end = self.overlapping_areas[i] + buffer_len = overlapping_end - overlapping_start + slice_len = weight_slice.numel() + if buffer_len != slice_len: + print( + "[MiniZero_1] copy mismatch:", + f"idx={i}", + f"buffer_len={buffer_len}", + f"slice_len={slice_len}", + f"weight_shape={tuple(weight.shape)}", + f"start_offset={start_offset}", + f"master_numel={master_weight.numel()}", + f"overlap=({overlapping_start},{overlapping_end})", + ) self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) # ----------------------------------------------------------------------------------------- @@ -282,6 +295,16 @@ def step(self): end = offset + weight.numel() if isinstance(weight, QuantizedTensor): weight = _get_raw_data(weight) + buffer_len = end - start + slice_len = weight.view(-1).numel() + if slice_len != buffer_len: + print( + "[MiniZero_1] gather mismatch:", + f"buffer_len={buffer_len}", + f"slice_len={slice_len}", + f"weight_shape={tuple(weight.shape)}", + f"offset=({start},{end})", + ) weight.view(-1).data.copy_(self.weight_buffer[start:end]) if self.manual_post_all_gather_processing: @@ -728,14 +751,14 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi enabled=True, 
recipe=nvfp4_recipe, preserve_high_precision_init_val=True ): model_nvfp4 = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) model = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) @@ -758,7 +781,7 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi w.main_grad.zero_() inputs = [ - torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + torch.randn(128, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) ] x = inputs[rank] @@ -779,9 +802,9 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi loss_nvfp4.backward() loss.backward() - optimizer_nvfp4.step() optimizer.step() - + optimizer_nvfp4.step() + torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) @@ -809,19 +832,19 @@ def run_parallel_tests() -> None: dp_group = dist.new_group(backend="nccl") quantizations = [] - if is_fp8_available(): - print("fp8 available") - quantizations.extend(["fp8", "fp8_cs"]) - if is_fp8_block_scaling_available(): - quantizations.append("fp8_block") - manual_post_all_gather_processings = [False, True] - print("starting mini optimizer test") - _test_mini_optimizer(dp_group) - print("starting cast master weights to fp8 test") - for quantization in quantizations: - for post_ag_processing in manual_post_all_gather_processings: - _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) - _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + # if is_fp8_available(): + # print("fp8 available") + # quantizations.extend(["fp8", "fp8_cs"]) + # if is_fp8_block_scaling_available(): + # quantizations.append("fp8_block") + # manual_post_all_gather_processings = [False, True] + # print("starting mini optimizer test") + # _test_mini_optimizer(dp_group) + # print("starting cast master weights to fp8 test") + # for quantization in quantizations: + # for post_ag_processing in manual_post_all_gather_processings: + # _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + # _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) print("starting cast master weights to nvfp4 test") nvfp4_available, _ = is_nvfp4_available(return_reason=True) if nvfp4_available: diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 97e8ec9a3e..ed4be7acd9 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -737,7 +737,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(new_workspace_alignment % 256 == 0, "cuBLAS workspace pointer must be aligned to 256 bytes, got ", new_workspace_alignment); - + const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, &returnedResults); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 92de9ff4f0..db9e8bd0a6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1264,11 +1264,7 @@ def 
reset_parameters(self, defer_init: Optional[bool] = False) -> None: quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] if quantizer is None: raise RuntimeError("Weight quantizer has not been initialized") - columnwise_usage = torch.is_grad_enabled() - if isinstance(quantizer, NVFP4Quantizer) and quantizer.with_2d_quantization: - # NVFP4 2D stores only rowwise data/scale - columnwise_usage = False - quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): device_mesh = dtensor_param.device_mesh From 7fbf724a5c7fad5bf431951629d3585031aaf8d9 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Mon, 8 Dec 2025 06:12:15 +0000 Subject: [PATCH 09/40] debug code --- .../test_cast_master_weights_to_fp8.py | 173 ++++++++++++++++-- transformer_engine/pytorch/tensor/utils.py | 120 +++++++++--- 2 files changed, 254 insertions(+), 39 deletions(-) mode change 100644 => 100755 tests/pytorch/distributed/test_cast_master_weights_to_fp8.py mode change 100644 => 100755 transformer_engine/pytorch/tensor/utils.py diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py old mode 100644 new mode 100755 index fabbca5f5b..8b29f94ee2 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -130,16 +130,43 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals if self.offsets[-1] % self.world_size != 0: self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size + self.weights_are_nvfp4 = isinstance(self.weights[0], NVFP4Tensor) + + # Storage offsets operate on the packed representation (e.g., NVFP4 uint8 data). + self.storage_offsets = None + self.storage_sizes = None + self.storage_total = None + if self.weights_are_nvfp4: + self.storage_offsets = [0] + self.storage_sizes = [] + for weight in self.weights: + storage_size = _get_raw_data(weight).view(-1).numel() + self.storage_sizes.append(storage_size) + self.storage_offsets.append(self.storage_offsets[-1] + storage_size) + if self.storage_offsets[-1] % self.world_size != 0: + self.storage_offsets[-1] += ( + self.world_size - self.storage_offsets[-1] % self.world_size + ) + self.storage_total = self.storage_offsets[-1] + self.master_weights = [] # The start offset of the master weight in the weight self.start_offsets = [] # The overlapping area of the weight and this rank's local buffer self.overlapping_areas = [] + # Storage equivalents (only populated for NVFP4 tensors). 
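# Editor's note (aside to the storage-offset bookkeeping introduced in this hunk): the
# separate "storage" offsets exist because NVFP4 packs two FP4 values per uint8, so the
# all-gather buffer is indexed in packed bytes rather than elements. A tiny illustration,
# mirroring the byte_start/byte_end arithmetic in _cast_master_weights_to_nvfp4_2d; the
# helper name is ours, not part of the patch.
def element_range_to_byte_range(start_elem: int, num_elems: int) -> tuple:
    byte_start = start_elem // 2                   # two FP4 values per byte
    byte_end = (start_elem + num_elems + 1) // 2   # round up for an odd tail
    return byte_start, byte_end

assert element_range_to_byte_range(4096, 16384) == (2048, 10240)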
+ self.storage_start_offsets = [None] * len(self.weights) + self.storage_overlapping_areas = [None] * len(self.weights) # The start and end of this rank's local buffer in the global buffer rank_start = self.offsets[-1] // self.world_size * self.rank rank_end = rank_start + self.offsets[-1] // self.world_size print(f"current rank: {self.rank}, rank_start: {rank_start}, rank_end: {rank_end}") + storage_rank_start = None + storage_rank_end = None + if self.weights_are_nvfp4: + storage_rank_start = self.storage_total // self.world_size * self.rank + storage_rank_end = storage_rank_start + self.storage_total // self.world_size for weight, offset in zip(self.weights, self.offsets[:-1]): if offset >= rank_end or (offset + weight.numel()) <= rank_start: # This weight is not in this rank's local buffer @@ -167,6 +194,17 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.start_offsets.append(start_offset) self.overlapping_areas.append(overlapping_area) + if self.weights_are_nvfp4: + for idx, (weight, storage_offset, storage_size) in enumerate( + zip(self.weights, self.storage_offsets[:-1], self.storage_sizes) + ): + if storage_offset >= storage_rank_end or (storage_offset + storage_size) <= storage_rank_start: + continue + overlap_start = max(storage_rank_start, storage_offset) + overlap_end = min(storage_rank_end, storage_offset + storage_size) + self.storage_start_offsets[idx] = overlap_start - storage_offset + self.storage_overlapping_areas[idx] = (overlap_start, overlap_end) + # Create global buffer for grads reduce-scatter self.grad_buffer = torch.empty( [self.offsets[-1]], dtype=torch.float32, device=weights[0].device @@ -176,12 +214,23 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals # Create global buffer for weights all-gather if isinstance(self.weights[0], QuantizedTensor): weight_buffer_dtype = torch.uint8 + if self.weights_are_nvfp4: + weight_buffer_length = self.storage_total + buffer_rank_start = storage_rank_start + buffer_rank_end = storage_rank_end + else: + weight_buffer_length = self.offsets[-1] + buffer_rank_start = rank_start + buffer_rank_end = rank_end else: weight_buffer_dtype = weights[0].dtype + weight_buffer_length = self.offsets[-1] + buffer_rank_start = rank_start + buffer_rank_end = rank_end self.weight_buffer = torch.empty( - [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device + [weight_buffer_length], dtype=weight_buffer_dtype, device=weights[0].device ) - self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] + self.weight_buffer_slice = self.weight_buffer[buffer_rank_start:buffer_rank_end] def step(self): # ----------------------------------------------------------------------------------------- @@ -259,6 +308,30 @@ def step(self): if master_weight is None: continue start_offset = self.start_offsets[i] + if isinstance(self.weights[i], NVFP4Tensor): + storage_start = self.storage_start_offsets[i] + storage_overlap = self.storage_overlapping_areas[i] + if storage_start is None or storage_overlap is None: + continue + weight = _get_raw_data(self.weights[i]).view(-1) + storage_len = storage_overlap[1] - storage_overlap[0] + weight_slice = weight[storage_start : storage_start + storage_len] + overlapping_start, overlapping_end = storage_overlap + buffer_len = overlapping_end - overlapping_start + slice_len = weight_slice.numel() + if buffer_len != slice_len: + print( + "[MiniZero_1] copy mismatch:", + f"idx={i}", + f"buffer_len={buffer_len}", + 
f"slice_len={slice_len}", + f"weight_shape={tuple(weight.shape)}", + f"storage_start={storage_start}", + f"storage_len={storage_len}", + f"overlap=({overlapping_start},{overlapping_end})", + ) + self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + continue if isinstance(self.weights[i], QuantizedTensor): weight = _get_raw_data(self.weights[i]) else: @@ -290,22 +363,30 @@ def step(self): # ----------------------------------------------------------------------------------------- # Step 7: Copy the gathered weights from weight buffer to the actual weights # ----------------------------------------------------------------------------------------- - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - if isinstance(weight, QuantizedTensor): - weight = _get_raw_data(weight) + for idx, weight in enumerate(self.weights): + if isinstance(weight, NVFP4Tensor): + start = self.storage_offsets[idx] + end = start + self.storage_sizes[idx] + weight_data = _get_raw_data(weight) buffer_len = end - start - slice_len = weight.view(-1).numel() - if slice_len != buffer_len: + slice_len = weight_data.view(-1).numel() + if slice_len != (end - start): print( "[MiniZero_1] gather mismatch:", f"buffer_len={buffer_len}", f"slice_len={slice_len}", - f"weight_shape={tuple(weight.shape)}", + f"weight_shape={tuple(weight_data.shape)}", f"offset=({start},{end})", ) - weight.view(-1).data.copy_(self.weight_buffer[start:end]) + weight_data.view(-1).data.copy_(self.weight_buffer[start:end]) + continue + start = self.offsets[idx] + end = start + weight.numel() + if isinstance(weight, QuantizedTensor): + weight_data = _get_raw_data(weight) + else: + weight_data = weight + weight_data.view(-1).data.copy_(self.weight_buffer[start:end]) if self.manual_post_all_gather_processing: quantized_weights = [ @@ -888,35 +969,97 @@ def main() -> None: not torch.cuda.is_available(), reason="NVFP4 partial-cast test requires CUDA." 
) def test_nvfp4_partial_cast_matches_full() -> None: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + dp_group = dist.new_group(backend="nccl") available, reason = is_nvfp4_available(return_reason=True) if not available: pytest.skip(reason) torch.manual_seed(1234) device = torch.device("cuda") - shape = (64, 64) + shape = (2048, 64) master_weight = torch.randn(shape, dtype=torch.float32, device=device) quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) nvfp4_tensor = quantizer(master_weight.to(torch.bfloat16)) + print( + "reference nvfp4_tensor._rowwise_data and shape", + nvfp4_tensor._rowwise_data, + nvfp4_tensor._rowwise_data.shape, + ) + print( + "reference nvfp4_tensor._rowwise_scale_inv and shape", + nvfp4_tensor._rowwise_scale_inv, + nvfp4_tensor._rowwise_scale_inv.shape, + ) assert isinstance(nvfp4_tensor, NVFP4Tensor) reference_data = nvfp4_tensor._rowwise_data.detach().clone() reference_scale = nvfp4_tensor._rowwise_scale_inv.detach().clone() + reference_dequant = nvfp4_tensor.dequantize(dtype=torch.bfloat16).detach().clone() + print( + "reference_scale first rows", + reference_scale[:16, :4].tolist(), + ) + # Build a layout-matched reference by reusing the quantizer with an explicitly allocated tensor. 
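# Editor's note, before the layout-matched reference is built below: what the scale
# comparison in this test is checking, restated in plain torch. A sketch of the two-level
# NVFP4 scaling used by the partial cast, assuming FP4 max 6.0 and FP8 (E4M3) max 448.0;
# the clamping for zero amaxes is left out, and the function name is illustrative only.
import torch

def nvfp4_scales_sketch(block_amax: torch.Tensor, global_amax: torch.Tensor):
    fp4_max, fp8_max = 6.0, 448.0
    tiny = torch.finfo(torch.float32).tiny
    global_encode_scale = (fp8_max * fp4_max) / global_amax.clamp_min(tiny)
    per_block_decode_scale = (block_amax / fp4_max) * global_encode_scale
    # _rowwise_scale_inv stores the decode scales rounded to FP8 E4M3, so any
    # bit-exact comparison has to go through the same rounding.
    stored = per_block_decode_scale.to(torch.float8_e4m3fn)
    return global_encode_scale, stored.view(torch.uint8)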
+ layout_matched_reference = quantizer.make_empty( + shape, dtype=torch.bfloat16, device=device + ) + layout_matched_reference.zero_() + quantizer.update_quantized( + master_weight.to(torch.bfloat16).view(1, -1), + layout_matched_reference, + ) + layout_reference_data = layout_matched_reference._rowwise_data.detach().clone() + layout_reference_scale = layout_matched_reference._rowwise_scale_inv.detach().clone() + layout_reference_rowwise_amax = layout_matched_reference._amax_rowwise.detach().clone() + nvfp4_tensor._rowwise_data.zero_() nvfp4_tensor._rowwise_scale_inv.zero_() + if nvfp4_tensor._amax_rowwise is not None: + print("nvfp4_tensor._amax_rowwise", nvfp4_tensor._amax_rowwise) + nvfp4_tensor._amax_rowwise.zero_() cast_master_weights_to_nvfp4( [nvfp4_tensor], [master_weight.clone()], [0], - None, + dp_group, ) - torch.testing.assert_close(nvfp4_tensor._rowwise_data, reference_data) + print("partial cast nvfp4_tensor", nvfp4_tensor) + print("partial cast nvfp4_tensor._rowwise_data and shape", nvfp4_tensor._rowwise_data, nvfp4_tensor._rowwise_data.shape) + print("partial cast nvfp4_tensor._rowwise_scale_inv and shape", nvfp4_tensor._rowwise_scale_inv, nvfp4_tensor._rowwise_scale_inv.shape) + print("partial cast nvfp4_tensor.dequantize(dtype=torch.bfloat16)", nvfp4_tensor.dequantize(dtype=torch.bfloat16)) + print("reference_dequant", reference_dequant) + print("partial cast nvfp4_tensor._amax_rowwise", nvfp4_tensor._amax_rowwise) + torch.testing.assert_close(nvfp4_tensor._amax_rowwise, layout_reference_rowwise_amax) + # torch.testing.assert_close( + # nvfp4_tensor.dequantize(dtype=torch.bfloat16), reference_dequant + # ) + # torch.testing.assert_close(nvfp4_tensor._rowwise_data, layout_reference_data) torch.testing.assert_close(nvfp4_tensor._rowwise_scale_inv, reference_scale) + # torch.testing.assert_close(nvfp4_tensor._amax_rowwise, layout_reference_global_amax) if __name__ == "__main__": - main() + test_nvfp4_partial_cast_matches_full() + #main() diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py old mode 100644 new mode 100755 index 845b950e51..4b2a58b6eb --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -5,6 +5,7 @@ """Helper functions for using fp8/nvfp4 tensors as weights""" from typing import Optional, Union, List +import math import torch import transformer_engine_torch as tex @@ -162,7 +163,6 @@ def cast_master_weights_to_nvfp4( """Helper to cast master weights to NVFP4 primary weights.""" nvfp4_params = [] - if fsdp_shard_model_weights is None: use_fsdp_shard_model_weights = False fsdp_shard_model_weights = [None] * len(model_weights) @@ -187,7 +187,6 @@ def cast_master_weights_to_nvfp4( raise ValueError( f"cast_master_weights_to_nvfp4 only supports NVFP4 tensors, got {type(model_weight)}" ) - if len(nvfp4_params) > 0: _cast_master_weights_to_nvfp4_2d( nvfp4_params, group, use_fsdp_shard_model_weights=use_fsdp_shard_model_weights @@ -490,7 +489,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( params, scales ): if not manual_post_all_gather_processing: - # Clear columnwise data for all model weights. + # Reset transpose cache for all model weights. # We cannot create columnwise data here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated at this moment. @@ -525,47 +524,53 @@ def _cast_master_weights_to_nvfp4_2d( group. 
use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. """ - + device = params[0][0].device block_len = NVFP4_BLOCK_SCALING_SIZE dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device) cu_amax_sizes = [0] - scale_shapes: List[tuple[int, int]] = [] + tile_shapes: List[tuple[int, int]] = [] + row_sizes: List[int] = [] + tile_widths: List[int] = [] + scale_targets: List[torch.Tensor] = [] + amax_targets: List[Optional[torch.Tensor]] = [] for model_weight, _, _, _ in params: quantizer = model_weight._get_quantizer() assert isinstance(quantizer, NVFP4Quantizer) assert quantizer.with_2d_quantization, "NVFP4 2D quantization must be enabled." - scale_shape = quantizer.get_scale_shape(model_weight.shape, columnwise=False) - scale_shapes.append(scale_shape) - num_amaxes = scale_shape[0] * scale_shape[1] + assert len(model_weight.shape) == 2 + h, w = model_weight.shape + tile_h = (h + block_len - 1) // block_len + tile_w = (w + block_len - 1) // block_len + tile_shapes.append((tile_h, tile_w)) + row_sizes.append(h) + tile_widths.append(tile_w) + scale_targets.append(model_weight._rowwise_scale_inv) + amax_targets.append(model_weight._amax_rowwise) + num_amaxes = tile_h * tile_w cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes) packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) amaxes: List[torch.Tensor] = [] scales: List[torch.Tensor] = [] - scale_inv_tmp: List[torch.Tensor] = [] - fp8_scale_targets: List[torch.Tensor] = [] global_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device) global_amax_views: List[torch.Tensor] = [ global_amaxes[i : i + 1] for i in range(len(params)) ] for i, (model_weight, master_weight, start_offset, _) in enumerate(params): - scale_shape = scale_shapes[i] + scale_shape = tile_shapes[i] amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) scale = torch.empty(scale_shape, dtype=torch.float32, device=device) - inv_tmp = torch.empty_like(scale) global_amax_view = global_amax_views[i] assert model_weight._rowwise_scale_inv is not None amaxes.append(amax) scales.append(scale) - scale_inv_tmp.append(inv_tmp) - fp8_scale_targets.append(model_weight._rowwise_scale_inv) if master_weight is not None and master_weight.numel() > 0: assert len(model_weight.shape) == 2 @@ -580,6 +585,10 @@ def _cast_master_weights_to_nvfp4_2d( if global_amaxes.numel() > 0: torch.distributed.all_reduce(global_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + print( + "[NVFP4 partial cast] global_amaxes:", + [(idx, float(val)) for idx, val in enumerate(global_amaxes.tolist())], + ) global_scale_tensor = global_amaxes.clone() if len(amaxes) > 0: @@ -594,34 +603,51 @@ def _cast_master_weights_to_nvfp4_2d( global_amaxes > 0, global_encode_scales, torch.ones_like(global_encode_scales) ) global_scale_tensor.copy_(global_encode_scales) + print( + "[NVFP4 partial cast] global_encode_scales:", + [(idx, float(val)) for idx, val in enumerate(global_encode_scales.tolist())], + ) global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - for amax_tensor, scale_tensor, inv_tmp_tensor, global_scale in zip( - amaxes, scales, scale_inv_tmp, global_scale_views + for amax_tensor, scale_tensor, global_scale in zip( + amaxes, scales, global_scale_views ): per_block_decode_scale = torch.clamp( (amax_tensor / fp4_max) * global_scale, max=finfo.max ) scale_tensor.copy_(per_block_decode_scale) - inv_tmp_tensor.copy_(per_block_decode_scale) - # Update _rowwise_scale_inv 
with reduced per-block decode scales. - for inv_tmp, target in zip(scale_inv_tmp, fp8_scale_targets): - fp8_view = inv_tmp.to(dtype=torch.float8_e4m3fn).view(torch.uint8) - target.copy_(fp8_view) else: global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - for ( + zipped_meta = zip( + tile_shapes, + row_sizes, + tile_widths, + scale_targets, + amax_targets, + params, + scales, + global_scale_views, + ) + for idx, ( + tile_shape, + rows, + tile_col_cnt, + target_scale, + target_amax, (model_weight, master_weight, start_offset, model_weight_fragment), per_block_decode_scale, global_scale, - ) in zip(params, scales, global_scale_views): + ) in enumerate(zipped_meta): if master_weight is None or master_weight.numel() == 0: continue end_offset = start_offset + master_weight.numel() if not use_fsdp_shard_model_weights: - model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] + rowwise_bytes = model_weight._rowwise_data.view(-1) + byte_start = start_offset // 2 + byte_end = (end_offset + 1) // 2 + model_weight_fragment = rowwise_bytes[byte_start:byte_end] assert len(model_weight.shape) == 2 h, w = model_weight.shape tex.nvfp4_2d_partial_cast( @@ -634,6 +660,52 @@ def _cast_master_weights_to_nvfp4_2d( start_offset, block_len, ) + tile_rows = tile_shape[0] + expanded_scale = torch.zeros_like(target_scale, dtype=torch.float32) + chunk = block_len + for tile_row_idx in range(tile_rows): + base_row = tile_row_idx * chunk + row_end = min(base_row + chunk, rows) + if base_row >= target_scale.shape[0]: + break + expanded_scale[base_row:row_end, :tile_col_cnt] = per_block_decode_scale[ + tile_row_idx + ] + fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) + target_scale.copy_(fp8_view) + if target_amax is not None: + target_amax.copy_(global_amaxes[idx : idx + 1]) + + if ( + master_weight is not None + and not use_fsdp_shard_model_weights + and isinstance(model_weight, NVFP4Tensor) + ): + quantizer = model_weight._get_quantizer() + reference_tensor = quantizer( + master_weight.detach() + .view(model_weight.shape) + .to(dtype=model_weight.dtype) + ) + ref_data = reference_tensor._rowwise_data + ref_scale = reference_tensor._rowwise_scale_inv + data_diff = (model_weight._rowwise_data != ref_data).nonzero(as_tuple=False) + scale_diff = ( + model_weight._rowwise_scale_inv != ref_scale + ).nonzero(as_tuple=False) + # print("model_weight._rowwise_scale_inv", model_weight._rowwise_scale_inv) + # print("ref_scale", ref_scale) + # print("scale_diff", scale_diff) + # if data_diff.numel() > 0: + # print( + # f"[NVFP4 partial cast][debug] data mismatch idx {idx}, first entries:", + # data_diff[:5].tolist(), + # ) + # if scale_diff.numel() > 0: + # print( + # f"[NVFP4 partial cast][debug] scale mismatch idx {idx}, first entries:", + # scale_diff[:5].tolist(), + # ) def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): """ From a2465d17af4271cd5a28fe9563cbd4554ef90db2 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Mon, 8 Dec 2025 21:35:10 +0800 Subject: [PATCH 10/40] Fix partial amax kernel and Add clear transpose cache --- transformer_engine/common/recipe/nvfp4.cu | 24 ++++++++-------------- transformer_engine/pytorch/tensor/utils.py | 9 ++++++++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index dfd7edfc9d..5b3bbc76d0 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ 
b/transformer_engine/common/recipe/nvfp4.cu @@ -87,30 +87,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const size_t h, const size_t w, const size_t start_offset, const size_t len) { constexpr int kThreadsPerWarp = 32; - constexpr int kLoopsPerRow = (kTileDim + kThreadsPerWarp - 1) / kThreadsPerWarp; constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; - constexpr int kLoopsPerCol = (kTileDim + kNumWarps - 1) / kNumWarps; + static_assert(kTileDim * kTileDim == kThreadsPerBlock); - const int tile_col = blockIdx.x; - const int tile_row = blockIdx.y; + const size_t tile_col = blockIdx.x; + const size_t tile_row = blockIdx.y; const size_t end_offset = start_offset + len; const IType *input_minus_offset = input - start_offset; __shared__ float smem[kNumWarps]; float amax = 0.0f; - for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) { - size_t r = tile_row * kTileDim + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp; - for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) { - size_t c = tile_col * kTileDim + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp); - size_t idx = r * w + c; - if (r < h && c < w && idx >= start_offset && idx < end_offset) { - float other_amax = fabs(static_cast(input_minus_offset[idx])); - __builtin_assume(amax >= 0); - __builtin_assume(other_amax >= 0); - amax = fmaxf(amax, other_amax); - } - } + size_t r = tile_row * kTileDim + threadIdx.x / kTileDim; + size_t c = tile_col * kTileDim + threadIdx.x % kTileDim; + size_t idx = r * w + c; + if (r < h && c < w && idx >= start_offset && idx < end_offset) { + amax = fabs(static_cast(input_minus_offset[idx])); } for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 4b2a58b6eb..a5cb25bfcc 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -639,6 +639,15 @@ def _cast_master_weights_to_nvfp4_2d( per_block_decode_scale, global_scale, ) in enumerate(zipped_meta): + + # TODO: Add `manual_post_all_gather_processing` flag to control whether to reset transpose + # cache. + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + if master_weight is None or master_weight.numel() == 0: continue From c3c6aa7f10fb31089b3c94b5b9bcc1f9b9bc323e Mon Sep 17 00:00:00 2001 From: qiyuw Date: Mon, 8 Dec 2025 21:08:31 +0000 Subject: [PATCH 11/40] create colwise --- .../tensor/storage/nvfp4_tensor_storage.py | 29 +++++++++++++++++++ transformer_engine/pytorch/tensor/utils.py | 17 +++++------ 2 files changed, 37 insertions(+), 9 deletions(-) mode change 100755 => 100644 transformer_engine/pytorch/tensor/utils.py diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 04ab092ee2..5063244f9e 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -318,3 +318,32 @@ def update_usage( self._columnwise_data = None self._columnwise_scale_inv = None self._amax_columnwise = None + + def _create_columnwise(self): + """ + Update columnwise data and columnwise scale inv. 
Can only be used when using 2D scaling. + """ + rowwise_data = self._rowwise_data + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + self._columnwise_data = tex.fp8_transpose( + rowwise_data, self._fp8_dtype, out=self._columnwise_data + ) + if self._columnwise_scale_inv is None: + assert self._quantizer is not None, ( + "._quantizer of Float8BlockwiseQTensor cannot be None because all the blockwise " + "quantized tensors are supposed to be generated from the quantizer." + ) + columnwise_scale_inv_shape = self._quantizer.get_scale_shape(rowwise_data.shape, True) + self._columnwise_scale_inv = torch.empty( + columnwise_scale_inv_shape, + dtype=self._rowwise_scale_inv.dtype, + device=self._rowwise_scale_inv.device, + ) + assert len(self._rowwise_scale_inv.shape) == 2 + assert len(self._columnwise_scale_inv.shape) == 2 + rowwise_scale_inv = self._rowwise_scale_inv + columnwise_scale_inv = rowwise_scale_inv.transpose(-2, -1) + h = min(self._columnwise_scale_inv.shape[0], columnwise_scale_inv.shape[0]) + w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1]) + self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w]) \ No newline at end of file diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py old mode 100755 new mode 100644 index a5cb25bfcc..3f5232455b --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -640,13 +640,12 @@ def _cast_master_weights_to_nvfp4_2d( global_scale, ) in enumerate(zipped_meta): - # TODO: Add `manual_post_all_gather_processing` flag to control whether to reset transpose - # cache. - # Reset transpose cache for all model weights. - # We cannot create transpose cache here because users (like megatron) may want to - # overlap the all-gather of model weights and forward process, so the model weight is - # not updated currently. - model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + if not manual_post_all_gather_processing: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) if master_weight is None or master_weight.numel() == 0: continue @@ -735,8 +734,8 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten # Blockwise scaling: create column-wise storage. model_weight._create_columnwise() elif isinstance(model_weight, NVFP4Tensor): - # TODO: Add create_columnwise for NVFP4Tensor - pass + # NVFP4 scaling: create column-wise storage. 
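# Editor's note on the _create_columnwise() call that follows: a rough sketch of its
# scale handling, assuming 2D (16x16) block scaling so the columnwise scale grid is the
# (possibly padded) transpose of the rowwise grid; the packed FP4 payload itself is
# transposed by the tex.fp8_transpose call in the patch. The helper name is ours.
import torch

def transpose_scale_inv_sketch(rowwise_scale_inv, columnwise_shape):
    out = torch.zeros(columnwise_shape, dtype=rowwise_scale_inv.dtype,
                      device=rowwise_scale_inv.device)
    t = rowwise_scale_inv.transpose(-2, -1)
    h = min(out.shape[0], t.shape[0])
    w = min(out.shape[1], t.shape[1])
    out[:h, :w].copy_(t[:h, :w])
    return out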
+ model_weight._create_columnwise() elif isinstance(model_weight, QuantizedTensor): raise ValueError(f"post_processing for {type(model_weight)} is not supported") From 3222d48957c79c3404d815f3897c5c8802efb5be Mon Sep 17 00:00:00 2001 From: qiyuw Date: Mon, 8 Dec 2025 23:20:16 +0000 Subject: [PATCH 12/40] colwise usage fix --- .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 7 +++++++ transformer_engine/pytorch/tensor/utils.py | 8 +++++--- 2 files changed, 12 insertions(+), 3 deletions(-) mode change 100644 => 100755 transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py mode change 100644 => 100755 transformer_engine/pytorch/tensor/utils.py diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py old mode 100644 new mode 100755 index 5063244f9e..15e925c5ee --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -318,6 +318,13 @@ def update_usage( self._columnwise_data = None self._columnwise_scale_inv = None self._amax_columnwise = None + + if rowwise_usage and columnwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage because rowwise data is None." + if self._columnwise_data is None or self._columnwise_scale_inv is None: + self._create_columnwise() def _create_columnwise(self): """ diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py old mode 100644 new mode 100755 index 3f5232455b..775177c037 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -158,7 +158,8 @@ def cast_master_weights_to_fp8( def cast_master_weights_to_nvfp4( - model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None + model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None, + manual_post_all_gather_processing=False ): """Helper to cast master weights to NVFP4 primary weights.""" @@ -189,7 +190,8 @@ def cast_master_weights_to_nvfp4( ) if len(nvfp4_params) > 0: _cast_master_weights_to_nvfp4_2d( - nvfp4_params, group, use_fsdp_shard_model_weights=use_fsdp_shard_model_weights + nvfp4_params, group, use_fsdp_shard_model_weights=use_fsdp_shard_model_weights, + manual_post_all_gather_processing=manual_post_all_gather_processing ) def _cast_master_weights_to_fp8_delayed_scaling( @@ -512,7 +514,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # revisit this later def _cast_master_weights_to_nvfp4_2d( - params, group, use_fsdp_shard_model_weights=False + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False ): r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling. 
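
The `manual_post_all_gather_processing` flag added above lets a distributed optimizer defer the column-wise rebuild until the weight all-gather has finished. The following is a minimal Python sketch of the intended call sequence; the function name and the even-shard all-gather are assumptions made for illustration only, and a Megatron-style caller would use its own bucketed gather instead.

import torch
import torch.distributed as dist

from transformer_engine.pytorch.tensor.utils import (
    cast_master_weights_to_nvfp4,
    post_all_gather_processing,
)


def nvfp4_weight_update(model_weights, master_shards, shard_offsets, dp_group):
    # Cast each rank's master-weight shard into the NVFP4 primary weights.
    # With manual_post_all_gather_processing=True the column-wise (transpose)
    # storage is left untouched here, so the all-gather below can be
    # overlapped with other work by the caller.
    cast_master_weights_to_nvfp4(
        model_weights,
        master_shards,
        shard_offsets,
        dp_group,
        manual_post_all_gather_processing=True,
    )

    # All-gather the packed row-wise bytes so every rank holds the fully
    # updated weights. This assumes each rank owns an equal contiguous shard
    # of the flattened byte buffer (illustrative only).
    world_size = dist.get_world_size(dp_group)
    rank = dist.get_rank(dp_group)
    for weight in model_weights:
        flat = weight._rowwise_data.view(-1)
        shard = flat.chunk(world_size)[rank].clone()
        dist.all_gather_into_tensor(flat, shard, group=dp_group)

    # Only after the gather completes, rebuild the column-wise storage.
    post_all_gather_processing(model_weights)

Deferring the rebuild this way keeps the cast and the communication independent, which is the overlap scenario described in the comments above.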
From a6727cff37aec97b8f6a562576ea15979e685b0b Mon Sep 17 00:00:00 2001 From: kunlunl Date: Tue, 9 Dec 2025 21:08:25 +0800 Subject: [PATCH 13/40] Fix partial cast numerical bug --- .../pytorch/distributed/test_cast_master_weights_to_fp8.py | 7 ++++--- transformer_engine/common/recipe/nvfp4.cu | 3 ++- transformer_engine/pytorch/tensor/utils.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 8b29f94ee2..38b9da992f 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -1025,7 +1025,7 @@ def test_nvfp4_partial_cast_matches_full() -> None: ) layout_matched_reference.zero_() quantizer.update_quantized( - master_weight.to(torch.bfloat16).view(1, -1), + master_weight.to(torch.bfloat16), layout_matched_reference, ) layout_reference_data = layout_matched_reference._rowwise_data.detach().clone() @@ -1051,13 +1051,14 @@ def test_nvfp4_partial_cast_matches_full() -> None: print("partial cast nvfp4_tensor.dequantize(dtype=torch.bfloat16)", nvfp4_tensor.dequantize(dtype=torch.bfloat16)) print("reference_dequant", reference_dequant) print("partial cast nvfp4_tensor._amax_rowwise", nvfp4_tensor._amax_rowwise) - torch.testing.assert_close(nvfp4_tensor._amax_rowwise, layout_reference_rowwise_amax) + torch.testing.assert_close(nvfp4_tensor._amax_rowwise, layout_reference_rowwise_amax, atol=0, rtol=0) # torch.testing.assert_close( # nvfp4_tensor.dequantize(dtype=torch.bfloat16), reference_dequant # ) # torch.testing.assert_close(nvfp4_tensor._rowwise_data, layout_reference_data) - torch.testing.assert_close(nvfp4_tensor._rowwise_scale_inv, reference_scale) + torch.testing.assert_close(nvfp4_tensor._rowwise_scale_inv, reference_scale, atol=0, rtol=0) # torch.testing.assert_close(nvfp4_tensor._amax_rowwise, layout_reference_global_amax) + torch.testing.assert_close(nvfp4_tensor._rowwise_data, layout_reference_data, atol=0, rtol=0) if __name__ == "__main__": diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 5b3bbc76d0..0c06e4b18d 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -154,8 +154,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } const float global_decode_scale = 1.0f / global_encode_scale; - const float tile_decode_scale = + float tile_decode_scale = decode_scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + tile_decode_scale = static_cast(static_cast(tile_decode_scale)); constexpr float kFp32Max = 3.402823466e+38F; float tile_encode_val = (tile_decode_scale > 0.f) ? 
1.0f / (tile_decode_scale * global_decode_scale) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index a5cb25bfcc..4d7f311383 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -613,7 +613,7 @@ def _cast_master_weights_to_nvfp4_2d( amaxes, scales, global_scale_views ): per_block_decode_scale = torch.clamp( - (amax_tensor / fp4_max) * global_scale, max=finfo.max + (amax_tensor * (1.0 / fp4_max)) * global_scale, max=finfo.max ) scale_tensor.copy_(per_block_decode_scale) else: From 481730f2f0604393cf6d6fa800f39bb0532cde49 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Tue, 9 Dec 2025 18:39:58 +0000 Subject: [PATCH 14/40] nvfp4 transpose --- .../include/transformer_engine/recipe.h | 11 ++ transformer_engine/common/recipe/nvfp4.cu | 102 ++++++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/pybind.cpp | 3 + .../pytorch/csrc/extensions/transpose.cpp | 47 ++++++++ .../tensor/storage/nvfp4_tensor_storage.py | 5 +- 6 files changed, 167 insertions(+), 3 deletions(-) mode change 100644 => 100755 transformer_engine/common/include/transformer_engine/recipe.h mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions.h mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/transpose.cpp diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h old mode 100644 new mode 100755 index 1fe9c81a1a..49afce1610 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -137,6 +137,17 @@ void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTE size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, size_t block_len, cudaStream_t stream); +/*! \brief Transpose NVFP4 packed data. + * + * Unlike FP8, NVFP4 packs two 4-bit values per byte. This function correctly + * handles the nibble repacking during transpose. + * + * \param[in] input Input tensor with packed FP4 data. Shape: [M, K/2] bytes. + * \param[out] output Output tensor with transposed packed data. Shape: [K, M/2] bytes. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 0c06e4b18d..0801125ee2 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -328,9 +328,111 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, NVTE_CHECK_CUDA(cudaGetLastError()); } +/* + * --------------------------------------------------------------------------- + * NVFP4 TRANSPOSE KERNEL + * + * Unlike FP8, NVFP4 packs two 4-bit values into each byte. A simple byte-wise + * transpose doesn't work because the packing changes: + * - Before transpose: elements [m, 2c] and [m, 2c+1] share a byte + * - After transpose: elements [k, 2*m_packed] and [k, 2*m_packed+1] share a byte + * which were originally [2*m_packed, k] and [2*m_packed+1, k] + * + * So we need to read from two consecutive rows of the input, extract the same + * nibble position from each, and pack them into one output byte. 
+ * ---------------------------------------------------------------------------
+ */
+
+__global__ void nvfp4_transpose_kernel(const uint8_t *input, uint8_t *output, const size_t M,
+                                       const size_t K) {
+  // Input: [M, K] logical, stored as [M, K/2] bytes
+  // Output: [K, M] logical, stored as [K, M/2] bytes
+
+  const size_t K_packed = K / 2;  // input packed width
+  const size_t M_packed = M / 2;  // output packed width
+
+  const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y;  // k index [0, K)
+  const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x;  // packed m index [0, M/2)
+
+  if (out_row >= K || out_col >= M_packed) return;
+
+  // The two logical M positions this output byte covers
+  const size_t m0 = out_col * 2;      // for low nibble of output
+  const size_t m1 = out_col * 2 + 1;  // for high nibble of output
+  const size_t k = out_row;
+
+  // In the input, both elements are at the same packed column
+  const size_t in_col = k / 2;
+  const int nibble_idx = k & 1;  // 0 = low nibble, 1 = high nibble
+
+  // Read the two input bytes from consecutive rows
+  const uint8_t in_byte_0 = input[m0 * K_packed + in_col];
+  const uint8_t in_byte_1 = input[m1 * K_packed + in_col];
+
+  // Extract the appropriate nibbles
+  uint8_t val0, val1;
+  if (nibble_idx == 0) {
+    val0 = in_byte_0 & 0x0Fu;
+    val1 = in_byte_1 & 0x0Fu;
+  } else {
+    val0 = (in_byte_0 >> 4) & 0x0Fu;
+    val1 = (in_byte_1 >> 4) & 0x0Fu;
+  }
+
+  // Pack: val0 in low nibble, val1 in high nibble
+  const uint8_t out_byte = val0 | static_cast<uint8_t>(val1 << 4);
+
+  output[out_row * M_packed + out_col] = out_byte;
+}
+
+void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) {
+  // Input has logical shape [M, K], stored as [M, K/2] bytes
+  // Output has logical shape [K, M], stored as [K, M/2] bytes
+
+  NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 transpose input must be uint8.");
+  NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 transpose output must be uint8.");
+
+  // Get dimensions from packed storage
+  // input.shape = [M, K/2], so M = shape[0], K = shape[1] * 2
+  const auto &in_shape = input.shape;
+  NVTE_CHECK(in_shape.size() == 2, "NVFP4 transpose expects 2D input (packed), got ",
+             in_shape.size(), "D.");
+  const size_t M = in_shape[0];
+  const size_t K_packed = in_shape[1];
+  const size_t K = K_packed * 2;
+
+  // Output should be [K, M/2]
+  const size_t M_packed = M / 2;
+  NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even.");
+
+  const auto &out_shape = output.shape;
+  NVTE_CHECK(out_shape.size() == 2, "NVFP4 transpose expects 2D output.");
+  NVTE_CHECK(out_shape[0] == K && out_shape[1] == M_packed,
+             "NVFP4 transpose output shape mismatch. Expected [", K, ", ", M_packed,
+             "], got [", out_shape[0], ", ", out_shape[1], "].");
+
+  if (M == 0 || K == 0) return;
+
+  // Launch kernel
+  constexpr int kBlockDim = 16;
+  dim3 block(kBlockDim, kBlockDim);
+  dim3 grid((M_packed + kBlockDim - 1) / kBlockDim, (K + kBlockDim - 1) / kBlockDim);
+
+  nvfp4_transpose_kernel<<<grid, block, 0, stream>>>(
+      reinterpret_cast<const uint8_t *>(input.data.dptr),
+      reinterpret_cast<uint8_t *>(output.data.dptr), M, K);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 }  // namespace nvfp4_recipe
 }  // namespace transformer_engine
 
+void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_nvfp4_transpose);
+  using namespace transformer_engine;
+  nvfp4_recipe::nvfp4_transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
+                                stream);
+}
+
 void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w,
                                         size_t amax_stride_h, size_t amax_stride_w,
                                         size_t start_offset,
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
old mode 100644
new mode 100755
index b7d6927b02..4f5aa37e0c
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -156,6 +156,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
 at::Tensor fp8_transpose(at::Tensor input, DType otype,
                          std::optional<at::Tensor> output = std::nullopt);
 
+at::Tensor nvfp4_transpose(at::Tensor input, std::optional<at::Tensor> output = std::nullopt);
+
 at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = std::nullopt);
 
 /***************************************************************************************************
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 7b1d079649..3ffa7033b6 100755
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -254,6 +254,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
         py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
+  m.def("nvfp4_transpose", &transformer_engine::pytorch::nvfp4_transpose,
+        "Transpose NVFP4 packed data with nibble repacking", py::arg("input"), py::kw_only(),
+        py::arg("out"), py::call_guard<py::gil_scoped_release>());
   m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims,
         "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
old mode 100644
new mode 100755
index 7dfdf99547..ad6c50f065
--- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
@@ -9,6 +9,8 @@
 #include 
 #include 
 
+#include 
+
 #include "../extensions.h"
 #include "pybind.h"
 
@@ -52,6 +54,51 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional
+at::Tensor nvfp4_transpose(at::Tensor input, std::optional<at::Tensor> output) {
+  init_extension();
+
+  // Input is packed FP4: logical [M, K] stored as [M, K/2] bytes
+  // Output is packed FP4: logical [K, M] stored as [K, M/2] bytes
+  const auto shape = getTensorShape(input);
+  NVTE_CHECK(shape.size() == 2, "NVFP4 transpose expects 2D input (packed storage).");
+
+  const size_t M = shape[0];
+  const size_t K_packed = shape[1];
+  const size_t K = K_packed * 2;  // logical K
+  const size_t M_packed = M / 2;
+
+  NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even.");
+
+  // Output shape: [K, M/2]
+  std::vector<int64_t> output_shape = {static_cast<int64_t>(K), static_cast<int64_t>(M_packed)};
+
+  // Output tensor
+  at::Tensor out;
+  if (output.has_value()) {
+    out = *output;
+    NVTE_CHECK(static_cast<size_t>(out.size(0)) == K &&
+                   static_cast<size_t>(out.size(1)) == M_packed,
+               "Output shape mismatch for NVFP4 transpose.");
+  } else {
+    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+    out = at::empty(output_shape, opts);
+  }
+
+  // Return immediately if tensor is empty
+  if (M == 0 || K == 0) {
+    return out;
+  }
+
+  // Call the NVFP4 transpose kernel
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, K_packed},
+                                              DType::kByte);
+  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{K, M_packed},
+                                               DType::kByte);
+  nvte_nvfp4_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
+
+  return out;
+}
+
 at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
   init_extension();
 
diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
index 15e925c5ee..c6ef84c803 100755
--- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -333,9 +333,8 @@ def _create_columnwise(self):
         rowwise_data = self._rowwise_data
         if not rowwise_data.is_contiguous():
             rowwise_data = rowwise_data.contiguous()
-        self._columnwise_data = tex.fp8_transpose(
-            rowwise_data, self._fp8_dtype, out=self._columnwise_data
-        )
+        # NVFP4 requires a specialized transpose that handles nibble repacking
+        self._columnwise_data = tex.nvfp4_transpose(rowwise_data, out=self._columnwise_data)
         if self._columnwise_scale_inv is None:
             assert self._quantizer is not None, (
                 "._quantizer of Float8BlockwiseQTensor cannot be None because all the blockwise "

From 882debb7a7e311af1407ee9c824f86b039906cdb Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Tue, 9 Dec 2025 20:18:45 +0000
Subject: [PATCH 15/40] bitwise identical transpose and add multi-gpu test

---
 .../test_cast_master_weights_to_fp8.py        | 187 +++++++++++++-----
 transformer_engine/common/recipe/nvfp4.cu     |   6 +-
 .../tensor/storage/nvfp4_tensor_storage.py    |  16 +-
 3 files changed, 146 insertions(+), 63 deletions(-)

diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
index 38b9da992f..e767881aa9 100755
--- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
+++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
@@ -965,10 +965,65 @@ def main() -> None:
     run_parallel_tests()
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="NVFP4 transpose test requires CUDA."
+) +def test_nvfp4_transpose_kernel() -> None: + """Test that nvfp4_transpose kernel produces bitwise identical results to reference.""" + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + torch.manual_seed(1234) + device = torch.device("cuda") + shape = (2048, 64) + master_weight = torch.randn(shape, dtype=torch.float32, device=device) + + print("\n=== Testing NVFP4 transpose kernel ===") + + # Create reference with both rowwise and columnwise data + quantizer_with_colwise = NVFP4Quantizer( + rowwise=True, columnwise=True, with_2d_quantization=True + ) + reference_tensor = quantizer_with_colwise(master_weight.to(torch.bfloat16)) + assert reference_tensor._columnwise_data is not None, "Reference should have columnwise data" + reference_columnwise_data = reference_tensor._columnwise_data.detach().clone() + print( + "reference columnwise_data shape:", + reference_columnwise_data.shape, + ) + + # Create tensor with only rowwise data, then call _create_columnwise() + quantizer_rowwise_only = NVFP4Quantizer( + rowwise=True, columnwise=False, with_2d_quantization=True + ) + test_tensor = quantizer_rowwise_only(master_weight.to(torch.bfloat16)) + assert test_tensor._columnwise_data is None, "Test tensor should not have columnwise data yet" + + # Now call _create_columnwise() which uses our nvfp4_transpose kernel + test_tensor.update_usage(rowwise_usage=True, columnwise_usage=True) + assert test_tensor._columnwise_data is not None, "Test tensor should have columnwise data after _create_columnwise()" + print( + "test_tensor columnwise_data shape after transpose:", + test_tensor._columnwise_data.shape, + ) + + # Compare columnwise data - should be bitwise identical + torch.testing.assert_close( + test_tensor._columnwise_data, + reference_columnwise_data, + atol=0, + rtol=0, + msg="NVFP4 transpose kernel produced different columnwise data than reference!", + ) + print("NVFP4 transpose kernel test PASSED!") + + @pytest.mark.skipif( not torch.cuda.is_available(), reason="NVFP4 partial-cast test requires CUDA." 
) def test_nvfp4_partial_cast_matches_full() -> None: + """Test multi-GPU partial cast: split master weight, partial cast on each rank, all-gather, compare.""" WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -994,73 +1049,99 @@ def test_nvfp4_partial_cast_matches_full() -> None: torch.manual_seed(1234) device = torch.device("cuda") - shape = (2048, 64) - master_weight = torch.randn(shape, dtype=torch.float32, device=device) - - quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) - nvfp4_tensor = quantizer(master_weight.to(torch.bfloat16)) - print( - "reference nvfp4_tensor._rowwise_data and shape", - nvfp4_tensor._rowwise_data, - nvfp4_tensor._rowwise_data.shape, - ) - print( - "reference nvfp4_tensor._rowwise_scale_inv and shape", - nvfp4_tensor._rowwise_scale_inv, - nvfp4_tensor._rowwise_scale_inv.shape, - ) - assert isinstance(nvfp4_tensor, NVFP4Tensor) + # Shape must be divisible by WORLD_SIZE for even splitting + # Also ensure dimensions are multiples of 16 for NVFP4 tiles + shape = (2048, 512) + total_elements = shape[0] * shape[1] + assert total_elements % WORLD_SIZE == 0, "Total elements must be divisible by WORLD_SIZE" - reference_data = nvfp4_tensor._rowwise_data.detach().clone() - reference_scale = nvfp4_tensor._rowwise_scale_inv.detach().clone() - reference_dequant = nvfp4_tensor.dequantize(dtype=torch.bfloat16).detach().clone() - print( - "reference_scale first rows", - reference_scale[:16, :4].tolist(), - ) + # Full master weight (same on all ranks due to same seed) + full_master_weight = torch.randn(shape, dtype=torch.float32, device=device) - # Build a layout-matched reference by reusing the quantizer with an explicitly allocated tensor. 
- layout_matched_reference = quantizer.make_empty( - shape, dtype=torch.bfloat16, device=device - ) - layout_matched_reference.zero_() - quantizer.update_quantized( - master_weight.to(torch.bfloat16), - layout_matched_reference, - ) - layout_reference_data = layout_matched_reference._rowwise_data.detach().clone() - layout_reference_scale = layout_matched_reference._rowwise_scale_inv.detach().clone() - layout_reference_rowwise_amax = layout_matched_reference._amax_rowwise.detach().clone() - + # Create reference using full quantization + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) + reference_tensor = quantizer(full_master_weight.to(torch.bfloat16)) + reference_data = reference_tensor._rowwise_data.detach().clone() + reference_scale = reference_tensor._rowwise_scale_inv.detach().clone() + reference_amax = reference_tensor._amax_rowwise.detach().clone() + print(f"[Rank {WORLD_RANK}] reference_data shape: {reference_data.shape}") + print(f"[Rank {WORLD_RANK}] reference_scale shape: {reference_scale.shape}") + + # Split master weight evenly across ranks + shard_size = total_elements // WORLD_SIZE + start_offset = WORLD_RANK * shard_size + end_offset = start_offset + shard_size + master_weight_shard = full_master_weight.view(-1)[start_offset:end_offset].clone() + print(f"[Rank {WORLD_RANK}] shard: start_offset={start_offset}, end_offset={end_offset}, shard_size={shard_size}") + + # Create empty NVFP4 tensor for this rank (full shape, but we'll only fill our shard) + nvfp4_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) nvfp4_tensor._rowwise_data.zero_() nvfp4_tensor._rowwise_scale_inv.zero_() if nvfp4_tensor._amax_rowwise is not None: - print("nvfp4_tensor._amax_rowwise", nvfp4_tensor._amax_rowwise) nvfp4_tensor._amax_rowwise.zero_() + # Partial cast on each rank's shard cast_master_weights_to_nvfp4( [nvfp4_tensor], - [master_weight.clone()], - [0], + [master_weight_shard], + [start_offset], dp_group, ) - print("partial cast nvfp4_tensor", nvfp4_tensor) - print("partial cast nvfp4_tensor._rowwise_data and shape", nvfp4_tensor._rowwise_data, nvfp4_tensor._rowwise_data.shape) - print("partial cast nvfp4_tensor._rowwise_scale_inv and shape", nvfp4_tensor._rowwise_scale_inv, nvfp4_tensor._rowwise_scale_inv.shape) - print("partial cast nvfp4_tensor.dequantize(dtype=torch.bfloat16)", nvfp4_tensor.dequantize(dtype=torch.bfloat16)) - print("reference_dequant", reference_dequant) - print("partial cast nvfp4_tensor._amax_rowwise", nvfp4_tensor._amax_rowwise) - torch.testing.assert_close(nvfp4_tensor._amax_rowwise, layout_reference_rowwise_amax, atol=0, rtol=0) - # torch.testing.assert_close( - # nvfp4_tensor.dequantize(dtype=torch.bfloat16), reference_dequant - # ) - # torch.testing.assert_close(nvfp4_tensor._rowwise_data, layout_reference_data) - torch.testing.assert_close(nvfp4_tensor._rowwise_scale_inv, reference_scale, atol=0, rtol=0) - # torch.testing.assert_close(nvfp4_tensor._amax_rowwise, layout_reference_global_amax) - torch.testing.assert_close(nvfp4_tensor._rowwise_data, layout_reference_data, atol=0, rtol=0) + # All-gather the rowwise data (packed FP4 bytes) + # Each rank has the full tensor but only its shard is filled + # We need to all-gather the shards + rowwise_data_flat = nvfp4_tensor._rowwise_data.view(-1) + + # For NVFP4, 2 elements are packed per byte, so byte shard size is shard_size // 2 + byte_shard_size = shard_size // 2 + byte_start = WORLD_RANK * byte_shard_size + byte_end = byte_start + byte_shard_size + 
my_shard_bytes = rowwise_data_flat[byte_start:byte_end].contiguous() + + # Gather all shards + gathered_shards = [torch.empty_like(my_shard_bytes) for _ in range(WORLD_SIZE)] + dist.all_gather(gathered_shards, my_shard_bytes, group=dp_group) + + # Reconstruct the full rowwise data + gathered_data = torch.cat(gathered_shards, dim=0).view(reference_data.shape) + print(f"[Rank {WORLD_RANK}] gathered_data shape: {gathered_data.shape}") + + # Compare with reference + torch.testing.assert_close( + gathered_data, + reference_data, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Gathered rowwise data does not match reference!", + ) + print(f"[Rank {WORLD_RANK}] rowwise_data matches reference!") + + # Also verify scale matches (scale should be identical on all ranks after all-reduce) + torch.testing.assert_close( + nvfp4_tensor._rowwise_scale_inv, + reference_scale, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Scale does not match reference!", + ) + print(f"[Rank {WORLD_RANK}] scale matches reference!") + + # Verify amax matches + torch.testing.assert_close( + nvfp4_tensor._amax_rowwise, + reference_amax, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Amax does not match reference!", + ) + print(f"[Rank {WORLD_RANK}] amax matches reference!") + + print(f"[Rank {WORLD_RANK}] Multi-GPU NVFP4 partial cast test PASSED!") if __name__ == "__main__": + test_nvfp4_transpose_kernel() test_nvfp4_partial_cast_matches_full() #main() diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 0801125ee2..c028c3b13e 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -393,8 +393,8 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 transpose output must be uint8."); // Get dimensions from packed storage - // input.shape = [M, K/2], so M = shape[0], K = shape[1] * 2 - const auto &in_shape = input.shape; + // input.shape() = [M, K/2], so M = shape[0], K = shape[1] * 2 + const auto in_shape = input.shape(); NVTE_CHECK(in_shape.size() == 2, "NVFP4 transpose expects 2D input (packed), got ", in_shape.size(), "D."); const size_t M = in_shape[0]; const size_t K_packed = in_shape[1]; @@ -404,7 +404,7 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { const size_t M_packed = M / 2; NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even."); - const auto &out_shape = output.shape; + const auto out_shape = output.shape(); NVTE_CHECK(out_shape.size() == 2, "NVFP4 transpose expects 2D output."); NVTE_CHECK(out_shape[0] == K && out_shape[1] == M_packed, "NVFP4 transpose output shape mismatch. Expected [", K, ", ", M_packed, diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index c6ef84c803..c955218b9a 100755 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -278,6 +278,15 @@ def update_usage( if columnwise_usage is None: columnwise_usage = self._columnwise_data is not None + # If both rowwise and columnwise are requested, create columnwise from rowwise if needed + if rowwise_usage and columnwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage because rowwise data is None." 
+ if self._columnwise_data is None or self._columnwise_scale_inv is None: + self._create_columnwise() + return + # Update row-scaled data if rowwise_usage: if self._rowwise_data is None: @@ -318,13 +327,6 @@ def update_usage( self._columnwise_data = None self._columnwise_scale_inv = None self._amax_columnwise = None - - if rowwise_usage and columnwise_usage: - assert ( - self._rowwise_data is not None and self._rowwise_scale_inv is not None - ), "Cannot update to rowwise and columnwise usage because rowwise data is None." - if self._columnwise_data is None or self._columnwise_scale_inv is None: - self._create_columnwise() def _create_columnwise(self): """ From 21394cec611b7ac09cbe97344d2f06c8673b2f82 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Tue, 9 Dec 2025 22:51:04 +0000 Subject: [PATCH 16/40] clean up --- transformer_engine/pytorch/tensor/utils.py | 31 ---------------------- 1 file changed, 31 deletions(-) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 7e69dd8d51..ac375a4793 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -686,37 +686,6 @@ def _cast_master_weights_to_nvfp4_2d( if target_amax is not None: target_amax.copy_(global_amaxes[idx : idx + 1]) - if ( - master_weight is not None - and not use_fsdp_shard_model_weights - and isinstance(model_weight, NVFP4Tensor) - ): - quantizer = model_weight._get_quantizer() - reference_tensor = quantizer( - master_weight.detach() - .view(model_weight.shape) - .to(dtype=model_weight.dtype) - ) - ref_data = reference_tensor._rowwise_data - ref_scale = reference_tensor._rowwise_scale_inv - data_diff = (model_weight._rowwise_data != ref_data).nonzero(as_tuple=False) - scale_diff = ( - model_weight._rowwise_scale_inv != ref_scale - ).nonzero(as_tuple=False) - # print("model_weight._rowwise_scale_inv", model_weight._rowwise_scale_inv) - # print("ref_scale", ref_scale) - # print("scale_diff", scale_diff) - # if data_diff.numel() > 0: - # print( - # f"[NVFP4 partial cast][debug] data mismatch idx {idx}, first entries:", - # data_diff[:5].tolist(), - # ) - # if scale_diff.numel() > 0: - # print( - # f"[NVFP4 partial cast][debug] scale mismatch idx {idx}, first entries:", - # scale_diff[:5].tolist(), - # ) - def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): """ Post-processing after all-gather for weights in distributed optimizer. 
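
For reference, the nibble repacking that `nvfp4_transpose` performs can be reproduced with a few lines of PyTorch. This is only a sketch for sanity-checking the kernel output against `_columnwise_data`; the function name is illustrative, and the [M, K/2] / [K, M/2] uint8 layouts follow the kernel comments in the patches above.

import torch


def nvfp4_transpose_reference(packed: torch.Tensor, M: int, K: int) -> torch.Tensor:
    """Nibble-repacking transpose of packed NVFP4 data (reference sketch).

    `packed` is uint8 of shape [M, K // 2]; each byte holds two FP4 codes
    (low nibble = even logical column, high nibble = odd logical column).
    Returns uint8 of shape [K, M // 2] with the same packing convention.
    """
    assert M % 2 == 0 and packed.dtype == torch.uint8 and packed.shape == (M, K // 2)
    low = packed & 0x0F            # logical columns 0, 2, 4, ...
    high = (packed >> 4) & 0x0F    # logical columns 1, 3, 5, ...
    # One FP4 code per logical element, shape [M, K], then transpose logically.
    logical = torch.stack((low, high), dim=-1).reshape(M, K)
    logical_t = logical.transpose(0, 1).contiguous()   # [K, M]
    # Rows 2j and 2j+1 of the transposed view share one output byte.
    pairs = logical_t.reshape(K, M // 2, 2)
    return pairs[..., 0] | (pairs[..., 1] << 4)

A host-side check like this is slow but convenient when debugging layout mismatches between the row-wise and column-wise storages.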
From aad419c2d020cbb01eb2605c0e1ed818a48e7a6a Mon Sep 17 00:00:00 2001 From: qiyuw Date: Wed, 10 Dec 2025 06:05:22 +0000 Subject: [PATCH 17/40] fix create columnwise and add debugging and testing code --- .../test_cast_master_weights_to_fp8.py | 80 +++++++++++++++++-- .../tensor/storage/nvfp4_tensor_storage.py | 55 +++++++++++-- 2 files changed, 124 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index e767881aa9..146b7e71db 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -332,7 +332,7 @@ def step(self): ) self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) continue - if isinstance(self.weights[i], QuantizedTensor): + elif isinstance(self.weights[i], QuantizedTensor): weight = _get_raw_data(self.weights[i]) else: weight = self.weights[i] @@ -856,7 +856,7 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi ) optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) - for _ in range(100): + for i in range(100): for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): w_nvfp4.main_grad.zero_() w.main_grad.zero_() @@ -866,6 +866,19 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi ] x = inputs[rank] + # Debug: compare master weights before forward at iteration 8 + if i == 8 and rank == 0: + print(f"\n=== Debug iteration {i} ===") + for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): + # Compare master weights + master_nvfp4 = optimizer_nvfp4.master_weights[idx] + master_bf16 = optimizer.master_weights[idx] + master_match = torch.equal(master_nvfp4, master_bf16) + print(f"Layer {idx}: master weights match = {master_match}") + if not master_match: + diff = (master_nvfp4 - master_bf16).abs().max().item() + print(f" max diff = {diff}") + with te.autocast( enabled=True, recipe=nvfp4_recipe, @@ -873,7 +886,12 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi ): y_nvfp4 = model_nvfp4(x) - y = model(x) + with te.autocast( + enabled=True, + recipe=nvfp4_recipe, + amax_reduction_group=mock_group, + ): + y = model(x) targets = [torch.randn_like(y) for _ in range(world_size)] target = targets[rank] @@ -887,6 +905,7 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi optimizer_nvfp4.step() torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) + print("iter:", i, "loss matched") def run_parallel_tests() -> None: @@ -987,11 +1006,18 @@ def test_nvfp4_transpose_kernel() -> None: ) reference_tensor = quantizer_with_colwise(master_weight.to(torch.bfloat16)) assert reference_tensor._columnwise_data is not None, "Reference should have columnwise data" + assert reference_tensor._columnwise_scale_inv is not None, "Reference should have columnwise scale_inv" reference_columnwise_data = reference_tensor._columnwise_data.detach().clone() + reference_columnwise_scale_inv = reference_tensor._columnwise_scale_inv.detach().clone() + reference_columnwise_amax = reference_tensor._amax_columnwise.detach().clone() if reference_tensor._amax_columnwise is not None else None print( "reference columnwise_data shape:", reference_columnwise_data.shape, ) + print( + "reference columnwise_scale_inv shape:", + reference_columnwise_scale_inv.shape, + ) # Create tensor with only rowwise data, then call 
_create_columnwise() quantizer_rowwise_only = NVFP4Quantizer( @@ -1003,10 +1029,15 @@ def test_nvfp4_transpose_kernel() -> None: # Now call _create_columnwise() which uses our nvfp4_transpose kernel test_tensor.update_usage(rowwise_usage=True, columnwise_usage=True) assert test_tensor._columnwise_data is not None, "Test tensor should have columnwise data after _create_columnwise()" + assert test_tensor._columnwise_scale_inv is not None, "Test tensor should have columnwise scale_inv after _create_columnwise()" print( "test_tensor columnwise_data shape after transpose:", test_tensor._columnwise_data.shape, ) + print( + "test_tensor columnwise_scale_inv shape after transpose:", + test_tensor._columnwise_scale_inv.shape, + ) # Compare columnwise data - should be bitwise identical torch.testing.assert_close( @@ -1016,6 +1047,43 @@ def test_nvfp4_transpose_kernel() -> None: rtol=0, msg="NVFP4 transpose kernel produced different columnwise data than reference!", ) + print("columnwise_data matches!") + + # Compare columnwise scale_inv - should be bitwise identical + print("reference columnwise_scale_inv:\n", reference_columnwise_scale_inv) + print("test columnwise_scale_inv:\n", test_tensor._columnwise_scale_inv) + print("reference rowwise_scale_inv shape:", reference_tensor._rowwise_scale_inv.shape) + print("test rowwise_scale_inv shape:", test_tensor._rowwise_scale_inv.shape) + + # Check if they match + scale_match = torch.equal(test_tensor._columnwise_scale_inv, reference_columnwise_scale_inv) + if not scale_match: + diff_mask = test_tensor._columnwise_scale_inv != reference_columnwise_scale_inv + print("Number of mismatches:", diff_mask.sum().item()) + print("Mismatch locations:", diff_mask.nonzero()[:10]) + print("Test values at mismatch:", test_tensor._columnwise_scale_inv[diff_mask][:10]) + print("Reference values at mismatch:", reference_columnwise_scale_inv[diff_mask][:10]) + + torch.testing.assert_close( + test_tensor._columnwise_scale_inv, + reference_columnwise_scale_inv, + atol=0, + rtol=0, + msg="NVFP4 _create_columnwise produced different columnwise scale_inv than reference!", + ) + print("columnwise_scale_inv matches!") + + # Compare columnwise amax if available + if reference_columnwise_amax is not None: + torch.testing.assert_close( + test_tensor._amax_columnwise, + reference_columnwise_amax, + atol=0, + rtol=0, + msg="NVFP4 _create_columnwise produced different columnwise amax than reference!", + ) + print("columnwise_amax matches!") + print("NVFP4 transpose kernel test PASSED!") @@ -1142,6 +1210,6 @@ def test_nvfp4_partial_cast_matches_full() -> None: if __name__ == "__main__": - test_nvfp4_transpose_kernel() - test_nvfp4_partial_cast_matches_full() - #main() + #test_nvfp4_transpose_kernel() + #test_nvfp4_partial_cast_matches_full() + main() diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index c955218b9a..0fb24fb812 100755 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -342,7 +342,10 @@ def _create_columnwise(self): "._quantizer of Float8BlockwiseQTensor cannot be None because all the blockwise " "quantized tensors are supposed to be generated from the quantizer." 
) - columnwise_scale_inv_shape = self._quantizer.get_scale_shape(rowwise_data.shape, True) + # Use logical shape (self.size()), not packed byte shape (rowwise_data.shape) + # NVFP4 packs 2 elements per byte, so rowwise_data.shape[-1] is K/2 + logical_shape = self.size() + columnwise_scale_inv_shape = self._quantizer.get_scale_shape(logical_shape, True) self._columnwise_scale_inv = torch.empty( columnwise_scale_inv_shape, dtype=self._rowwise_scale_inv.dtype, @@ -350,8 +353,50 @@ def _create_columnwise(self): ) assert len(self._rowwise_scale_inv.shape) == 2 assert len(self._columnwise_scale_inv.shape) == 2 + + # rowwise_scale_inv has shape [M_padded, K_tiles] where each tile's scale + # is repeated 16 times (once per row in the 16x16 tile). + # We need to: + # 1. Extract tile-level scales (every 16th row) + # 2. Transpose + # 3. Expand to columnwise padded format (repeat 16 times per tile row) + + TILE_SIZE = 16 rowwise_scale_inv = self._rowwise_scale_inv - columnwise_scale_inv = rowwise_scale_inv.transpose(-2, -1) - h = min(self._columnwise_scale_inv.shape[0], columnwise_scale_inv.shape[0]) - w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1]) - self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w]) \ No newline at end of file + + # Get logical shape to compute tile counts + logical_shape = self.size() + M, K = logical_shape[0], logical_shape[-1] + M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE + K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE + + # Extract tile-level scales (take first row of each tile block) + # rowwise_scale_inv[0::16, :] gives us [M_tiles, K_tiles] (approximately) + tile_scales = rowwise_scale_inv[0:M_tiles * TILE_SIZE:TILE_SIZE, :K_tiles] # [M_tiles, K_tiles] + + # Transpose tile scales for columnwise layout + transposed_tile_scales = tile_scales.transpose(-2, -1) # [K_tiles, M_tiles] + + # Expand to columnwise padded format (repeat each tile row 16 times) + # columnwise_scale_inv has shape [K_padded, M_tiles_padded] + col_h = self._columnwise_scale_inv.shape[0] + col_w = self._columnwise_scale_inv.shape[1] + + # Zero out the columnwise scale first + self._columnwise_scale_inv.zero_() + + # Fill in the scale values, repeating each tile's scale 16 times + for tile_row in range(min(K_tiles, (col_h + TILE_SIZE - 1) // TILE_SIZE)): + row_start = tile_row * TILE_SIZE + row_end = min(row_start + TILE_SIZE, col_h) + cols_to_copy = min(M_tiles, col_w) + if row_start < col_h and cols_to_copy > 0: + # Repeat the tile row's scales for all 16 rows in this tile + self._columnwise_scale_inv[row_start:row_end, :cols_to_copy] = \ + transposed_tile_scales[tile_row, :cols_to_copy].unsqueeze(0).expand(row_end - row_start, -1) + + # Also set columnwise amax (same as rowwise since it's just transposed data) + if self._amax_columnwise is None and self._amax_rowwise is not None: + self._amax_columnwise = self._amax_rowwise.clone() + elif self._amax_rowwise is not None: + self._amax_columnwise.copy_(self._amax_rowwise) \ No newline at end of file From 9285232062fde54b8a0223f41c9884358e10ff18 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 11 Dec 2025 07:34:05 +0000 Subject: [PATCH 18/40] fix and debug --- .../test_cast_master_weights_to_fp8.py | 484 +++++++++++++++++- transformer_engine/pytorch/tensor/utils.py | 55 +- 2 files changed, 498 insertions(+), 41 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 146b7e71db..50c3b81f48 100755 --- 
a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -826,21 +826,39 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi mock_group = mock_groups[rank] linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} - nvfp4_recipe = NVFP4BlockScaling() - + # Disable stochastic rounding for deterministic gradients + nvfp4_recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) + + # Original shapes (commented out for debugging padding issues): + # with te.quantized_model_init( + # enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True + # ): + # model_nvfp4 = nn.Sequential( + # te.Linear(128, 256, **linear_kwargs), + # te.Linear(256, 256 * 3, **linear_kwargs), + # te.Linear(256 * 3, 128, **linear_kwargs), + # ) + # model = nn.Sequential( + # te.Linear(128, 256, **linear_kwargs), + # te.Linear(256, 256 * 3, **linear_kwargs), + # te.Linear(256 * 3, 128, **linear_kwargs), + # ) + + # Use 2048x2048 weights to avoid NVFP4 scale_inv padding issues with te.quantized_model_init( enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True ): model_nvfp4 = nn.Sequential( - te.Linear(128, 256, **linear_kwargs), - te.Linear(256, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), + te.Linear(2048, 2048, **linear_kwargs), + te.Linear(2048, 2048, **linear_kwargs), + te.Linear(2048, 2048, **linear_kwargs), ) + # BF16 model (created outside quantized_model_init) model = nn.Sequential( - te.Linear(128, 256, **linear_kwargs), - te.Linear(256, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), + te.Linear(2048, 2048, **linear_kwargs), + te.Linear(2048, 2048, **linear_kwargs), + te.Linear(2048, 2048, **linear_kwargs), ) for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): @@ -856,23 +874,43 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi ) optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + # Add hooks to capture intermediate activations + activations_nvfp4 = {} + activations_bf16 = {} + + def make_hook(storage, name): + def hook(module, input, output): + storage[name] = (input[0].clone(), output.clone()) + return hook + + hooks = [] + for idx, (layer_nvfp4, layer_bf16) in enumerate(zip(model_nvfp4, model)): + hooks.append(layer_nvfp4.register_forward_hook(make_hook(activations_nvfp4, f"layer_{idx}"))) + hooks.append(layer_bf16.register_forward_hook(make_hook(activations_bf16, f"layer_{idx}"))) + for i in range(100): for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): w_nvfp4.main_grad.zero_() w.main_grad.zero_() + # Original input shape: torch.randn(128, 128, ...) 
inputs = [ - torch.randn(128, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) ] x = inputs[rank] - # Debug: compare master weights before forward at iteration 8 - if i == 8 and rank == 0: + # Debug: compare master weights before forward + if i in [0, 1, 7, 8] and rank == 0: print(f"\n=== Debug iteration {i} ===") for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): # Compare master weights master_nvfp4 = optimizer_nvfp4.master_weights[idx] master_bf16 = optimizer.master_weights[idx] + if master_nvfp4 is None or master_bf16 is None: + print(f"Layer {idx}: master weights = None (nvfp4={master_nvfp4 is not None}, bf16={master_bf16 is not None})") + continue + print(f"Layer {idx}: master_nvfp4.dtype={master_nvfp4.dtype}, master_bf16.dtype={master_bf16.dtype}") + print(f"Layer {idx}: w_nvfp4.dtype={w_nvfp4.dtype}, w.dtype={w.dtype}") master_match = torch.equal(master_nvfp4, master_bf16) print(f"Layer {idx}: master weights match = {master_match}") if not master_match: @@ -893,16 +931,182 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi ): y = model(x) + # Debug: compare forward outputs and weight properties + if i == 0 and rank == 0: + print(f"\n=== Forward outputs iteration {i} ===") + print(f"y_nvfp4 shape: {y_nvfp4.shape}, y shape: {y.shape}") + y_match = torch.equal(y_nvfp4, y) + print(f"Forward outputs match: {y_match}") + if not y_match: + diff = (y_nvfp4 - y).abs().max().item() + print(f" max diff: {diff}") + + # Compare intermediate activations + print("\n=== Intermediate activations ===") + for layer_name in activations_nvfp4.keys(): + inp_nvfp4, out_nvfp4 = activations_nvfp4[layer_name] + inp_bf16, out_bf16 = activations_bf16[layer_name] + inp_match = torch.equal(inp_nvfp4, inp_bf16) + out_match = torch.equal(out_nvfp4, out_bf16) + print(f"{layer_name}: input match = {inp_match}, output match = {out_match}") + if not inp_match: + diff = (inp_nvfp4 - inp_bf16).abs().max().item() + print(f" input max diff: {diff}") + if not out_match: + diff = (out_nvfp4 - out_bf16).abs().max().item() + print(f" output max diff: {diff}") + + # Compare quantizer states + print("\n=== Quantizer comparison ===") + for idx, (layer_nvfp4, layer_bf16) in enumerate(zip(model_nvfp4, model)): + print(f"Layer {idx}:") + print(f" nvfp4.fp8: {layer_nvfp4.fp8}, bf16.fp8: {layer_bf16.fp8}") + print(f" nvfp4.fp8_initialized: {layer_nvfp4.fp8_initialized}, bf16.fp8_initialized: {layer_bf16.fp8_initialized}") + if hasattr(layer_nvfp4, 'quantizers') and hasattr(layer_bf16, 'quantizers'): + q_nvfp4 = layer_nvfp4.quantizers.get('scaling_fwd', []) + q_bf16 = layer_bf16.quantizers.get('scaling_fwd', []) + if q_nvfp4: + print(f" nvfp4 input quantizer type: {type(q_nvfp4[0])}") + if q_bf16: + print(f" bf16 input quantizer type: {type(q_bf16[0])}") + + # Compare NVFP4 tensor properties + print("\n=== NVFP4 weight properties ===") + for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): + print(f"Layer {idx}:") + print(f" w_nvfp4 type: {type(w_nvfp4).__name__}, w type: {type(w).__name__}") + if hasattr(w_nvfp4, '_amax_rowwise') and w_nvfp4._amax_rowwise is not None: + print(f" w_nvfp4._amax_rowwise: {w_nvfp4._amax_rowwise.item()}") + if hasattr(w_nvfp4, '_amax_columnwise') and w_nvfp4._amax_columnwise is not None: + print(f" w_nvfp4._amax_columnwise: {w_nvfp4._amax_columnwise.item()}") + # Compare dequantized values + if 
hasattr(w_nvfp4, 'dequantize'): + w_nvfp4_dequant = w_nvfp4.dequantize(dtype=torch.bfloat16) + dequant_match = torch.equal(w_nvfp4_dequant, w) + print(f" dequant(w_nvfp4) == w: {dequant_match}") + if not dequant_match: + diff = (w_nvfp4_dequant - w).abs().max().item() + print(f" max diff: {diff}") + targets = [torch.randn_like(y) for _ in range(world_size)] target = targets[rank] loss_nvfp4 = nn.MSELoss()(y_nvfp4, target) loss = nn.MSELoss()(y, target) + # Debug: check if losses are identical + if i == 0 and rank == 0: + print(f"\n=== Loss comparison iteration {i} ===") + print(f"loss_nvfp4: {loss_nvfp4.item()}, loss: {loss.item()}") + print(f"Losses bitwise equal: {torch.equal(loss_nvfp4, loss)}") + loss_nvfp4.backward() loss.backward() + # Debug: compare gradients before optimizer step + if i == 0 and rank == 0: + print(f"\n=== Gradients before step iteration {i} ===") + for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): + grad_nvfp4 = w_nvfp4.main_grad + grad_bf16 = w.main_grad + grad_match = torch.equal(grad_nvfp4, grad_bf16) + print(f"Layer {idx}: gradients match = {grad_match}") + if not grad_match: + diff = (grad_nvfp4 - grad_bf16).abs().max().item() + print(f" max diff: {diff}") + + # Test: run same model twice to check for non-determinism + print("\n=== Determinism test: run model_nvfp4 twice ===") + for w_nvfp4 in model_nvfp4.parameters(): + w_nvfp4.main_grad.zero_() + with te.autocast(enabled=True, recipe=nvfp4_recipe, amax_reduction_group=mock_group): + y_test1 = model_nvfp4(x) + loss_test1 = nn.MSELoss()(y_test1, target) + loss_test1.backward() + grads_run1 = [w.main_grad.clone() for w in model_nvfp4.parameters()] + + for w_nvfp4 in model_nvfp4.parameters(): + w_nvfp4.main_grad.zero_() + with te.autocast(enabled=True, recipe=nvfp4_recipe, amax_reduction_group=mock_group): + y_test2 = model_nvfp4(x) + loss_test2 = nn.MSELoss()(y_test2, target) + loss_test2.backward() + grads_run2 = [w.main_grad.clone() for w in model_nvfp4.parameters()] + + for idx, (g1, g2) in enumerate(zip(grads_run1, grads_run2)): + match = torch.equal(g1, g2) + print(f"Layer {idx}: same model, 2 runs match = {match}") + if not match: + diff = (g1 - g2).abs().max().item() + print(f" max diff: {diff}") + optimizer.step() optimizer_nvfp4.step() + + # Debug: compare weights after optimizer step (on all ranks) + if i == 0: + print(f"\n=== After optimizer step iteration {i} (rank {rank}) ===") + for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): + # Compare master weights + master_nvfp4 = optimizer_nvfp4.master_weights[idx] + master_bf16 = optimizer.master_weights[idx] + if master_nvfp4 is not None and master_bf16 is not None: + master_match = torch.equal(master_nvfp4, master_bf16) + print(f"Layer {idx}: master weights match = {master_match}") + if not master_match: + diff = (master_nvfp4 - master_bf16).abs().max().item() + print(f" max diff: {diff}") + else: + print(f"Layer {idx}: master weights = None") + + # Compare model weights: quantize BF16 and compare with NVFP4 + if hasattr(w_nvfp4, '_rowwise_data') and hasattr(w_nvfp4, '_quantizer'): + # Create a fresh quantizer with same config to avoid state issues + from transformer_engine.pytorch.tensor import NVFP4Quantizer + fresh_quantizer = NVFP4Quantizer(with_2d_quantization=True) + w_bf16_quantized = fresh_quantizer(w) + + # Debug: compare amax for layer 2 + if idx == 2 and i == 0: + print(f"[Rank {rank}] Layer {idx} DEBUG:") + # Compare BF16 values: NVFP4 model's master weight vs BF16 
model's weight + master_nvfp4 = optimizer_nvfp4.master_weights[idx] + master_bf16 = optimizer.master_weights[idx] + if master_nvfp4 is not None and master_bf16 is not None: + # Both ranks should have the master weight if they own this layer + bf16_from_master_nvfp4 = master_nvfp4.to(w_nvfp4.dtype).view(w.shape) + print(f"[Rank {rank}] BF16 weight w == master_bf16.to(bf16): {torch.equal(w, master_bf16.to(w.dtype).view(w.shape))}") + print(f"[Rank {rank}] BF16 weight w == master_nvfp4.to(bf16): {torch.equal(w, bf16_from_master_nvfp4)}") + else: + print(f"[Rank {rank}] master_nvfp4={master_nvfp4 is not None}, master_bf16={master_bf16 is not None}") + print(f"[Rank {rank}] w_nvfp4._amax_rowwise: {w_nvfp4._amax_rowwise}") + print(f"[Rank {rank}] w_bf16_quantized._amax_rowwise: {w_bf16_quantized._amax_rowwise}") + # Sample of scales + print(f"[Rank {rank}] w_nvfp4._rowwise_scale_inv[0,:8]: {w_nvfp4._rowwise_scale_inv[0,:8].tolist()}") + print(f"[Rank {rank}] w_bf16_quantized._rowwise_scale_inv[0,:8]: {w_bf16_quantized._rowwise_scale_inv[0,:8].tolist()}") + # Check where scales differ + scale_diff = (w_nvfp4._rowwise_scale_inv != w_bf16_quantized._rowwise_scale_inv) + if scale_diff.any(): + diff_indices = torch.nonzero(scale_diff, as_tuple=True) + print(f"[Rank {rank}] First 5 scale diff positions: {list(zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()))}") + for r, c in zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()): + print(f"[Rank {rank}] [{r},{c}]: nvfp4={w_nvfp4._rowwise_scale_inv[r,c].item()}, ref={w_bf16_quantized._rowwise_scale_inv[r,c].item()}") + + # Compare raw NVFP4 data + data_match = torch.equal(w_nvfp4._rowwise_data, w_bf16_quantized._rowwise_data) + print(f"Layer {idx}: _rowwise_data match = {data_match}") + if not data_match: + # Count mismatches + mismatches = (w_nvfp4._rowwise_data != w_bf16_quantized._rowwise_data).sum().item() + total = w_nvfp4._rowwise_data.numel() + print(f" mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") + + # Compare scales + scale_match = torch.equal(w_nvfp4._rowwise_scale_inv, w_bf16_quantized._rowwise_scale_inv) + print(f"Layer {idx}: _rowwise_scale_inv match = {scale_match}") + if not scale_match: + mismatches = (w_nvfp4._rowwise_scale_inv != w_bf16_quantized._rowwise_scale_inv).sum().item() + total = w_nvfp4._rowwise_scale_inv.numel() + print(f" mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) print("iter:", i, "loss matched") @@ -932,19 +1136,19 @@ def run_parallel_tests() -> None: dp_group = dist.new_group(backend="nccl") quantizations = [] - # if is_fp8_available(): - # print("fp8 available") - # quantizations.extend(["fp8", "fp8_cs"]) - # if is_fp8_block_scaling_available(): - # quantizations.append("fp8_block") - # manual_post_all_gather_processings = [False, True] - # print("starting mini optimizer test") - # _test_mini_optimizer(dp_group) - # print("starting cast master weights to fp8 test") - # for quantization in quantizations: - # for post_ag_processing in manual_post_all_gather_processings: - # _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) - # _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + if is_fp8_available(): + print("fp8 available") + quantizations.extend(["fp8", "fp8_cs"]) + if is_fp8_block_scaling_available(): + quantizations.append("fp8_block") + manual_post_all_gather_processings = [False, True] + print("starting mini optimizer test") + 
_test_mini_optimizer(dp_group) + print("starting cast master weights to fp8 test") + for quantization in quantizations: + for post_ag_processing in manual_post_all_gather_processings: + _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) print("starting cast master weights to nvfp4 test") nvfp4_available, _ = is_nvfp4_available(return_reason=True) if nvfp4_available: @@ -1119,7 +1323,7 @@ def test_nvfp4_partial_cast_matches_full() -> None: device = torch.device("cuda") # Shape must be divisible by WORLD_SIZE for even splitting # Also ensure dimensions are multiples of 16 for NVFP4 tiles - shape = (2048, 512) + shape = (4096, 2048) total_elements = shape[0] * shape[1] assert total_elements % WORLD_SIZE == 0, "Total elements must be divisible by WORLD_SIZE" @@ -1209,7 +1413,235 @@ def test_nvfp4_partial_cast_matches_full() -> None: print(f"[Rank {WORLD_RANK}] Multi-GPU NVFP4 partial cast test PASSED!") +def test_single_gpu_partial_cast_vs_full(): + """ + Single GPU test: compare cast_master_weights_to_nvfp4 (offset=0) vs quantizer(). + This isolates whether the issue is in our manual Python scale computation or elsewhere. + """ + import math + from transformer_engine.pytorch.tensor import NVFP4Quantizer + from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4 + import transformer_engine_torch as tex + + torch.manual_seed(12345) + device = torch.device("cuda") + + # Test with same shape as the optimizer test + shape = (2048, 64) + + # Create BF16 master weight + master_weight = torch.randn(shape, dtype=torch.bfloat16, device=device) + + # === Reference: Use NVFP4Quantizer directly === + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) + ref = quantizer(master_weight) + ref_data = ref._rowwise_data.clone() + ref_scale = ref._rowwise_scale_inv.clone() + ref_amax = ref._amax_rowwise.clone() + + print(f"Reference:") + print(f" data shape: {ref_data.shape}") + print(f" scale shape: {ref_scale.shape}") + print(f" amax: {ref_amax}") + + # === Test: Use cast_master_weights_to_nvfp4 with offset=0 (full tensor) === + # Create empty NVFP4 tensor + test_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) + test_tensor._rowwise_data.zero_() + test_tensor._rowwise_scale_inv.zero_() + if test_tensor._amax_rowwise is not None: + test_tensor._amax_rowwise.zero_() + + # Create a mock distributed group for single GPU + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", init_method="env://", rank=0, world_size=1) + mock_group = dist.new_group(ranks=[0]) + + # Call cast_master_weights_to_nvfp4 with full tensor (offset=0) + cast_master_weights_to_nvfp4( + [test_tensor], + [master_weight.view(-1)], # Flatten as expected + [0], # offset=0 means full tensor + mock_group, + ) + + print(f"\nTest (cast_master_weights_to_nvfp4 with offset=0):") + print(f" data shape: {test_tensor._rowwise_data.shape}") + print(f" scale shape: {test_tensor._rowwise_scale_inv.shape}") + print(f" amax: {test_tensor._amax_rowwise}") + + # === Compare === + print(f"\nComparison:") + + # Compare amax + amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax) + print(f" Amax match: {amax_match}") + if not amax_match: + print(f" test: {test_tensor._amax_rowwise}") + print(f" ref: {ref_amax}") + + # Compare scale + scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale) + print(f" Scale match: {scale_match}") + if not 
scale_match: + mismatches = (test_tensor._rowwise_scale_inv != ref_scale).sum().item() + total = ref_scale.numel() + print(f" Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") + + # Compare data + data_match = torch.equal(test_tensor._rowwise_data, ref_data) + print(f" Data match: {data_match}") + if not data_match: + mismatches = (test_tensor._rowwise_data != ref_data).sum().item() + total = ref_data.numel() + print(f" Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") + + if amax_match and scale_match and data_match: + print("\nSUCCESS: cast_master_weights_to_nvfp4 (offset=0) matches quantizer!") + else: + print("\nFAILURE: Results don't match!") + + +def test_scale_computation_matches_quantizer(): + """ + Test that our Python scale computation in utils.py matches what NVFP4Quantizer produces. + This isolates the scale computation issue outside of the optimizer. + """ + import math + from transformer_engine.pytorch.tensor import NVFP4Quantizer + from transformer_engine.common.recipe import NVFP4BlockScaling + import transformer_engine_torch as tex + + torch.manual_seed(12345) + device = torch.device("cuda") + + # Test with 2048x2048 like in the optimizer test + shape = (2048, 2048) + block_len = 16 + + # Create random BF16 tensor (simulating master weight converted to BF16) + master_weight = torch.randn(shape, dtype=torch.bfloat16, device=device) + + # === Reference: Use NVFP4Quantizer === + quantizer = NVFP4Quantizer(with_2d_quantization=True) + ref = quantizer(master_weight) + ref_scale = ref._rowwise_scale_inv.clone() + ref_data = ref._rowwise_data.clone() + ref_amax = ref._amax_rowwise.clone() + + print(f"Reference scale shape: {ref_scale.shape}") + print(f"Reference data shape: {ref_data.shape}") + print(f"Reference amax: {ref_amax}") + + # === Our implementation: Replicate utils.py logic === + h, w = shape + tile_h = math.ceil(h / block_len) + tile_w = math.ceil(w / block_len) + tile_shape = (tile_h, tile_w) + + print(f"Tile shape: {tile_shape}") + + # Step 1: Compute per-block amax using CUDA kernel + amax_tensor = torch.zeros(tile_shape, dtype=torch.float32, device=device) + global_amax = torch.zeros(1, dtype=torch.float32, device=device) + + tex.nvfp4_2d_compute_partial_amax( + master_weight.view(-1), amax_tensor, h, w, 0, block_len + ) + tex.compute_amax(master_weight.view(-1), global_amax) + + print(f"Computed global_amax: {global_amax.item()}") + print(f"Reference global_amax: {ref_amax.item()}") + print(f"Global amax match: {torch.equal(global_amax, ref_amax)}") + + # Step 2: Compute scales + fp4_max = 6.0 + fp8_max = 448.0 + finfo = torch.finfo(torch.float32) + tiny = finfo.tiny + + safe_global_amax = torch.clamp(global_amax, min=tiny) + global_encode_scale = torch.clamp((fp8_max * fp4_max) / safe_global_amax, max=finfo.max) + global_scale = global_encode_scale.item() + + print(f"global_encode_scale: {global_scale}") + + # per_block_decode_scale = amax / fp4_max * global_scale + # CUDA computes: amax / 6.0 * global_scale (division first, then multiply) + # Python was: amax * (1.0 / 6.0) * global_scale (multiply by reciprocal) + # Try matching CUDA order: + per_block_decode_scale_cuda_order = torch.clamp( + (amax_tensor / fp4_max) * global_scale, max=finfo.max + ) + per_block_decode_scale_python_order = torch.clamp( + (amax_tensor * (1.0 / fp4_max)) * global_scale, max=finfo.max + ) + print(f"global_scale (should be 2688/5 = 537.6): {global_scale}") + print(f"\nComparing CUDA vs Python order:") + print(f" CUDA order [0,:5]: 
{per_block_decode_scale_cuda_order[0,:5].tolist()}") + print(f" Python order [0,:5]: {per_block_decode_scale_python_order[0,:5].tolist()}") + print(f" Difference: {(per_block_decode_scale_cuda_order - per_block_decode_scale_python_order).abs().max().item()}") + + # Use CUDA order for the rest of the test + per_block_decode_scale = per_block_decode_scale_cuda_order + + print(f"per_block_decode_scale shape: {per_block_decode_scale.shape}") + + # Step 3: Expand to target_scale shape (replicate utils.py expansion) + # Get the expected scale shape from quantizer (rowwise, not columnwise) + target_scale_shape = quantizer.get_scale_shape(shape, columnwise=False) + print(f"Expected target_scale shape: {target_scale_shape}") + + target_scale = torch.zeros(target_scale_shape, dtype=torch.uint8, device=device) + expanded_scale = torch.zeros(target_scale_shape, dtype=torch.float32, device=device) + + tile_rows = tile_h + tile_col_cnt = tile_w + rows = h + chunk = block_len + + for tile_row_idx in range(tile_rows): + base_row = tile_row_idx * chunk + row_end = min(base_row + chunk, rows) + if base_row >= target_scale.shape[0]: + break + expanded_scale[base_row:row_end, :tile_col_cnt] = per_block_decode_scale[tile_row_idx] + + # Convert to FP8 and view as uint8 (this is the suspect operation) + fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) + target_scale.copy_(fp8_view) + + # === Compare === + print(f"\nComparing scales:") + print(f"target_scale shape: {target_scale.shape}") + print(f"ref_scale shape: {ref_scale.shape}") + + # Convert ref_scale back to FP32 to see what values it contains + ref_scale_fp32 = ref_scale.view(torch.float8_e4m3fn).to(torch.float32) + + # Check if our amax matches what the quantizer computed internally + # Reverse-engineer the reference amax from ref_scale + # ref_scale (FP8) → ref_scale_fp32 → amax = ref_scale_fp32 * 6.0 / global_scale + ref_amax_reverse = ref_scale_fp32 * fp4_max / global_scale + print(f"\nReverse-engineered ref amax [0,:5]: {ref_amax_reverse[0,:5].tolist()}") + print(f"Our computed amax [0,:5]: {amax_tensor[0,:5].tolist()}") + + # Check if amax values match + amax_match = torch.allclose(amax_tensor, ref_amax_reverse[:tile_h, :tile_w], rtol=0.01) + print(f"Amax roughly matches: {amax_match}") + + scale_match = torch.equal(target_scale, ref_scale) + print(f"\nScales match exactly: {scale_match}") + + if not scale_match: + mismatches = (target_scale != ref_scale).sum().item() + total = target_scale.numel() + print(f"Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") + + if __name__ == "__main__": #test_nvfp4_transpose_kernel() #test_nvfp4_partial_cast_matches_full() - main() + #test_scale_computation_matches_quantizer() + test_single_gpu_partial_cast_vs_full() + #main() diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index ac375a4793..17f216c573 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -577,6 +577,7 @@ def _cast_master_weights_to_nvfp4_2d( if master_weight is not None and master_weight.numel() > 0: assert len(model_weight.shape) == 2 h, w = model_weight.shape + # master_weight is already converted to model_weight.dtype (BF16) in the caller tex.nvfp4_2d_compute_partial_amax( master_weight, amax, h, w, start_offset, block_len ) @@ -649,9 +650,47 @@ def _cast_master_weights_to_nvfp4_2d( # not updated currently. 
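# Illustrative sketch (not part of the patch): the scale-expansion block added
# below takes the (tiles_h, tiles_w) grid of per-block decode scales and repeats
# each tile row over block_len (16) consecutive rows of the rowwise scale_inv
# buffer before the FP8 cast. Shapes here are arbitrary and any padding the real
# scale buffer may carry is ignored.
import torch

block_len = 16
per_block = torch.rand(4, 8, dtype=torch.float32)          # (tiles_h, tiles_w) decode scales
expanded = per_block.repeat_interleave(block_len, dim=0)   # one scale row per weight row
scale_inv_u8 = expanded.to(torch.float8_e4m3fn).view(torch.uint8)
print(per_block.shape, expanded.shape, scale_inv_u8.dtype)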
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + # Always write scales and amax for ALL layers (computed from all-reduced amax). + # This ensures scales are correct even for layers not owned by this rank. + tile_rows = tile_shape[0] + expanded_scale = torch.zeros_like(target_scale, dtype=torch.float32) + chunk = block_len + for tile_row_idx in range(tile_rows): + base_row = tile_row_idx * chunk + row_end = min(base_row + chunk, rows) + if base_row >= target_scale.shape[0]: + break + expanded_scale[base_row:row_end, :tile_col_cnt] = per_block_decode_scale[ + tile_row_idx + ] + fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) + target_scale.copy_(fp8_view) + if target_amax is not None: + target_amax.copy_(global_amaxes[idx : idx + 1]) + + # Only cast data for layers owned by this rank if master_weight is None or master_weight.numel() == 0: continue + # Debug: compare scales with fresh quantizer for full cast (start_offset=0) + if start_offset == 0 and idx == 2: + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + print(f"[Rank {rank}] Layer {idx} SCALE DEBUG (start_offset=0):") + print(f"[Rank {rank}] per_block_decode_scale shape: {per_block_decode_scale.shape}") + print(f"[Rank {rank}] target_scale shape: {target_scale.shape}") + # Compare with fresh quantizer - use inline import to avoid scoping issues + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer as FreshNVFP4Quantizer + fresh_q = FreshNVFP4Quantizer(with_2d_quantization=True) + ref = fresh_q(master_weight.reshape(model_weight.shape)) + print(f"[Rank {rank}] ref._rowwise_scale_inv shape: {ref._rowwise_scale_inv.shape}") + # Check if scales match + scale_match = torch.equal(target_scale, ref._rowwise_scale_inv) + print(f"[Rank {rank}] target_scale == ref._rowwise_scale_inv: {scale_match}") + if not scale_match: + mismatches = (target_scale != ref._rowwise_scale_inv).sum().item() + total = target_scale.numel() + print(f"[Rank {rank}] mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") + end_offset = start_offset + master_weight.numel() if not use_fsdp_shard_model_weights: rowwise_bytes = model_weight._rowwise_data.view(-1) @@ -660,6 +699,7 @@ def _cast_master_weights_to_nvfp4_2d( model_weight_fragment = rowwise_bytes[byte_start:byte_end] assert len(model_weight.shape) == 2 h, w = model_weight.shape + # master_weight is already converted to model_weight.dtype (BF16) in the caller tex.nvfp4_2d_partial_cast( master_weight, model_weight_fragment, @@ -670,21 +710,6 @@ def _cast_master_weights_to_nvfp4_2d( start_offset, block_len, ) - tile_rows = tile_shape[0] - expanded_scale = torch.zeros_like(target_scale, dtype=torch.float32) - chunk = block_len - for tile_row_idx in range(tile_rows): - base_row = tile_row_idx * chunk - row_end = min(base_row + chunk, rows) - if base_row >= target_scale.shape[0]: - break - expanded_scale[base_row:row_end, :tile_col_cnt] = per_block_decode_scale[ - tile_row_idx - ] - fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) - target_scale.copy_(fp8_view) - if target_amax is not None: - target_amax.copy_(global_amaxes[idx : idx + 1]) def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): """ From 632f54b0c13ff9e3993a16ba4bb77f33dc08d446 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 11 Dec 2025 21:17:22 +0000 Subject: [PATCH 19/40] fix loss mismatch, add debugging code --- .../test_cast_master_weights_to_fp8.py | 321 ++++++++++++------ 
.../common/cast/nvfp4/core_nvfp4.cuh | 39 ++- .../common/cast/nvfp4/quantize_nvfp4.cuh | 10 +- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 30 +- transformer_engine/pytorch/tensor/utils.py | 89 +++-- 5 files changed, 350 insertions(+), 139 deletions(-) mode change 100644 => 100755 transformer_engine/common/cast/nvfp4/core_nvfp4.cuh mode change 100644 => 100755 transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh mode change 100644 => 100755 transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 50c3b81f48..5812d77f0f 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -819,8 +819,8 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi rank = dist.get_rank(dp_group) world_size = dist.get_world_size(dp_group) - torch.manual_seed(12345) - torch.cuda.manual_seed(12345) + torch.manual_seed(71784) + torch.cuda.manual_seed(71784) mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] mock_group = mock_groups[rank] @@ -830,37 +830,37 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi nvfp4_recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) # Original shapes (commented out for debugging padding issues): - # with te.quantized_model_init( - # enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True - # ): - # model_nvfp4 = nn.Sequential( - # te.Linear(128, 256, **linear_kwargs), - # te.Linear(256, 256 * 3, **linear_kwargs), - # te.Linear(256 * 3, 128, **linear_kwargs), - # ) - # model = nn.Sequential( - # te.Linear(128, 256, **linear_kwargs), - # te.Linear(256, 256 * 3, **linear_kwargs), - # te.Linear(256 * 3, 128, **linear_kwargs), - # ) - - # Use 2048x2048 weights to avoid NVFP4 scale_inv padding issues with te.quantized_model_init( enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True ): model_nvfp4 = nn.Sequential( - te.Linear(2048, 2048, **linear_kwargs), - te.Linear(2048, 2048, **linear_kwargs), - te.Linear(2048, 2048, **linear_kwargs), + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), ) - - # BF16 model (created outside quantized_model_init) model = nn.Sequential( - te.Linear(2048, 2048, **linear_kwargs), - te.Linear(2048, 2048, **linear_kwargs), - te.Linear(2048, 2048, **linear_kwargs), + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), ) + # Use 2048x2048 weights to avoid NVFP4 scale_inv padding issues + # with te.quantized_model_init( + # enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True + # ): + # model_nvfp4 = nn.Sequential( + # te.Linear(2048, 2048, **linear_kwargs), + # te.Linear(2048, 2048, **linear_kwargs), + # te.Linear(2048, 2048, **linear_kwargs), + # ) + + # # BF16 model (created outside quantized_model_init) + # model = nn.Sequential( + # te.Linear(2048, 2048, **linear_kwargs), + # te.Linear(2048, 2048, **linear_kwargs), + # te.Linear(2048, 2048, **linear_kwargs), + # ) + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): high_precision_init_val = w_nvfp4.get_high_precision_init_val() w.data.copy_(high_precision_init_val) @@ -895,7 +895,7 @@ def hook(module, input, output): # Original input shape: torch.randn(128, 128, 
...) inputs = [ - torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + torch.randn(2048, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) ] x = inputs[rank] @@ -1065,31 +1065,31 @@ def hook(module, input, output): fresh_quantizer = NVFP4Quantizer(with_2d_quantization=True) w_bf16_quantized = fresh_quantizer(w) - # Debug: compare amax for layer 2 - if idx == 2 and i == 0: - print(f"[Rank {rank}] Layer {idx} DEBUG:") - # Compare BF16 values: NVFP4 model's master weight vs BF16 model's weight - master_nvfp4 = optimizer_nvfp4.master_weights[idx] - master_bf16 = optimizer.master_weights[idx] - if master_nvfp4 is not None and master_bf16 is not None: - # Both ranks should have the master weight if they own this layer - bf16_from_master_nvfp4 = master_nvfp4.to(w_nvfp4.dtype).view(w.shape) - print(f"[Rank {rank}] BF16 weight w == master_bf16.to(bf16): {torch.equal(w, master_bf16.to(w.dtype).view(w.shape))}") - print(f"[Rank {rank}] BF16 weight w == master_nvfp4.to(bf16): {torch.equal(w, bf16_from_master_nvfp4)}") - else: - print(f"[Rank {rank}] master_nvfp4={master_nvfp4 is not None}, master_bf16={master_bf16 is not None}") - print(f"[Rank {rank}] w_nvfp4._amax_rowwise: {w_nvfp4._amax_rowwise}") - print(f"[Rank {rank}] w_bf16_quantized._amax_rowwise: {w_bf16_quantized._amax_rowwise}") - # Sample of scales - print(f"[Rank {rank}] w_nvfp4._rowwise_scale_inv[0,:8]: {w_nvfp4._rowwise_scale_inv[0,:8].tolist()}") - print(f"[Rank {rank}] w_bf16_quantized._rowwise_scale_inv[0,:8]: {w_bf16_quantized._rowwise_scale_inv[0,:8].tolist()}") - # Check where scales differ - scale_diff = (w_nvfp4._rowwise_scale_inv != w_bf16_quantized._rowwise_scale_inv) - if scale_diff.any(): - diff_indices = torch.nonzero(scale_diff, as_tuple=True) - print(f"[Rank {rank}] First 5 scale diff positions: {list(zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()))}") - for r, c in zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()): - print(f"[Rank {rank}] [{r},{c}]: nvfp4={w_nvfp4._rowwise_scale_inv[r,c].item()}, ref={w_bf16_quantized._rowwise_scale_inv[r,c].item()}") + # # Debug: compare amax for layer 2 + # if idx == 2 and i == 0: + # print(f"[Rank {rank}] Layer {idx} DEBUG:") + # # Compare BF16 values: NVFP4 model's master weight vs BF16 model's weight + # master_nvfp4 = optimizer_nvfp4.master_weights[idx] + # master_bf16 = optimizer.master_weights[idx] + # if master_nvfp4 is not None and master_bf16 is not None: + # # Both ranks should have the master weight if they own this layer + # bf16_from_master_nvfp4 = master_nvfp4.to(w_nvfp4.dtype).view(w.shape) + # print(f"[Rank {rank}] BF16 weight w == master_bf16.to(bf16): {torch.equal(w, master_bf16.to(w.dtype).view(w.shape))}") + # print(f"[Rank {rank}] BF16 weight w == master_nvfp4.to(bf16): {torch.equal(w, bf16_from_master_nvfp4)}") + # else: + # print(f"[Rank {rank}] master_nvfp4={master_nvfp4 is not None}, master_bf16={master_bf16 is not None}") + # print(f"[Rank {rank}] w_nvfp4._amax_rowwise: {w_nvfp4._amax_rowwise}") + # print(f"[Rank {rank}] w_bf16_quantized._amax_rowwise: {w_bf16_quantized._amax_rowwise}") + # # Sample of scales + # print(f"[Rank {rank}] w_nvfp4._rowwise_scale_inv[0,:8]: {w_nvfp4._rowwise_scale_inv[0,:8].tolist()}") + # print(f"[Rank {rank}] w_bf16_quantized._rowwise_scale_inv[0,:8]: {w_bf16_quantized._rowwise_scale_inv[0,:8].tolist()}") + # # Check where scales differ + # scale_diff = (w_nvfp4._rowwise_scale_inv != w_bf16_quantized._rowwise_scale_inv) + # if scale_diff.any(): + 
# diff_indices = torch.nonzero(scale_diff, as_tuple=True) + # print(f"[Rank {rank}] First 5 scale diff positions: {list(zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()))}") + # for r, c in zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()): + # print(f"[Rank {rank}] [{r},{c}]: nvfp4={w_nvfp4._rowwise_scale_inv[r,c].item()}, ref={w_bf16_quantized._rowwise_scale_inv[r,c].item()}") # Compare raw NVFP4 data data_match = torch.equal(w_nvfp4._rowwise_data, w_bf16_quantized._rowwise_data) @@ -1319,11 +1319,11 @@ def test_nvfp4_partial_cast_matches_full() -> None: if not available: pytest.skip(reason) - torch.manual_seed(1234) + torch.manual_seed(77777) device = torch.device("cuda") # Shape must be divisible by WORLD_SIZE for even splitting # Also ensure dimensions are multiples of 16 for NVFP4 tiles - shape = (4096, 2048) + shape = (4096, 4096) total_elements = shape[0] * shape[1] assert total_elements % WORLD_SIZE == 0, "Total elements must be divisible by WORLD_SIZE" @@ -1413,21 +1413,47 @@ def test_nvfp4_partial_cast_matches_full() -> None: print(f"[Rank {WORLD_RANK}] Multi-GPU NVFP4 partial cast test PASSED!") +def test_fp8_rounding_behavior(): + """Test PyTorch's FP8 rounding behavior for tie-breaking cases.""" + device = torch.device("cuda") + + # 336.0 is exactly halfway between 320.0 and 352.0 in FP8 E4M3 + test_val = torch.tensor([336.0000228881836], dtype=torch.float32, device=device) + fp8_val = test_val.to(torch.float8_e4m3fn) + back_to_fp32 = fp8_val.to(torch.float32) + + print(f"FP8 rounding test:") + print(f" Input FP32: {test_val.item()}") + print(f" FP8 as uint8: {fp8_val.view(torch.uint8).item()}") + print(f" Back to FP32: {back_to_fp32.item()}") + print(f" Expected round-to-even: 320.0 (uint8=122)") + print(f" PyTorch rounds to: {'even (320)' if back_to_fp32.item() == 320.0 else 'odd (352)'}") + + # Test a few more boundary cases + test_vals = [320.0, 336.0, 352.0, 304.0, 368.0] + for v in test_vals: + t = torch.tensor([v], dtype=torch.float32, device=device) + fp8 = t.to(torch.float8_e4m3fn) + back = fp8.to(torch.float32) + print(f" {v} -> FP8({fp8.view(torch.uint8).item()}) -> {back.item()}") + + def test_single_gpu_partial_cast_vs_full(): """ Single GPU test: compare cast_master_weights_to_nvfp4 (offset=0) vs quantizer(). This isolates whether the issue is in our manual Python scale computation or elsewhere. 
""" import math + import os from transformer_engine.pytorch.tensor import NVFP4Quantizer from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4 import transformer_engine_torch as tex - torch.manual_seed(12345) + torch.manual_seed(77777) device = torch.device("cuda") # Test with same shape as the optimizer test - shape = (2048, 64) + shape = (2048, 2048) # Create BF16 master weight master_weight = torch.randn(shape, dtype=torch.bfloat16, device=device) @@ -1457,7 +1483,7 @@ def test_single_gpu_partial_cast_vs_full(): dist.init_process_group(backend="nccl", init_method="env://", rank=0, world_size=1) mock_group = dist.new_group(ranks=[0]) - # Call cast_master_weights_to_nvfp4 with full tensor (offset=0) + # First pass: find mismatch location cast_master_weights_to_nvfp4( [test_tensor], [master_weight.view(-1)], # Flatten as expected @@ -1465,6 +1491,37 @@ def test_single_gpu_partial_cast_vs_full(): mock_group, ) + # Check for mismatches and set debug env vars + scale_diff = (test_tensor._rowwise_scale_inv != ref_scale) + if scale_diff.any(): + diff_idx = torch.nonzero(scale_diff, as_tuple=True) + r, c = diff_idx[0][0].item(), diff_idx[1][0].item() + tile_row = r // 16 + tile_col = c + print(f"\n=== Found mismatch at scale[{r},{c}], tile[{tile_row},{tile_col}] ===") + print(f" Running second pass with debug enabled...") + + # Set env vars for debug + os.environ["NVFP4_DEBUG_TILE_ROW"] = str(tile_row) + os.environ["NVFP4_DEBUG_TILE_COL"] = str(tile_col) + + # Reset tensor and run again with debug + test_tensor._rowwise_data.zero_() + test_tensor._rowwise_scale_inv.zero_() + if test_tensor._amax_rowwise is not None: + test_tensor._amax_rowwise.zero_() + + cast_master_weights_to_nvfp4( + [test_tensor], + [master_weight.view(-1)], + [0], + mock_group, + ) + + # Clear env vars + del os.environ["NVFP4_DEBUG_TILE_ROW"] + del os.environ["NVFP4_DEBUG_TILE_COL"] + print(f"\nTest (cast_master_weights_to_nvfp4 with offset=0):") print(f" data shape: {test_tensor._rowwise_data.shape}") print(f" scale shape: {test_tensor._rowwise_scale_inv.shape}") @@ -1487,6 +1544,31 @@ def test_single_gpu_partial_cast_vs_full(): mismatches = (test_tensor._rowwise_scale_inv != ref_scale).sum().item() total = ref_scale.numel() print(f" Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") + + # Find first mismatch location + diff = (test_tensor._rowwise_scale_inv != ref_scale) + if diff.any(): + diff_idx = torch.nonzero(diff, as_tuple=True) + r, c = diff_idx[0][0].item(), diff_idx[1][0].item() + print(f"\n === First mismatch at [{r},{c}] ===") + print(f" test scale (uint8): {test_tensor._rowwise_scale_inv[r,c].item()}") + print(f" ref scale (uint8): {ref_scale[r,c].item()}") + + # Convert to FP32 to see the actual values + test_fp32 = test_tensor._rowwise_scale_inv[r,c].view(torch.float8_e4m3fn).to(torch.float32).item() + ref_fp32 = ref_scale[r,c].view(torch.float8_e4m3fn).to(torch.float32).item() + print(f" test scale (FP32): {test_fp32}") + print(f" ref scale (FP32): {ref_fp32}") + + # Compute which tile this belongs to + tile_row = r // 16 + tile_col = c + print(f" Tile position: [{tile_row}, {tile_col}]") + + # Store this for utils.py to print debug info + import os + os.environ["NVFP4_DEBUG_TILE_ROW"] = str(tile_row) + os.environ["NVFP4_DEBUG_TILE_COL"] = str(tile_col) # Compare data data_match = torch.equal(test_tensor._rowwise_data, ref_data) @@ -1512,7 +1594,7 @@ def test_scale_computation_matches_quantizer(): from transformer_engine.common.recipe import NVFP4BlockScaling import 
transformer_engine_torch as tex - torch.manual_seed(12345) + torch.manual_seed(77777) device = torch.device("cuda") # Test with 2048x2048 like in the optimizer test @@ -1533,6 +1615,17 @@ def test_scale_computation_matches_quantizer(): print(f"Reference data shape: {ref_data.shape}") print(f"Reference amax: {ref_amax}") + # Print reference scale for specific tile [1, 83] + debug_tile_row, debug_tile_col = 1, 83 + # Scale buffer row = tile_row * 16 (first row of the tile's 16-row block) + ref_scale_buffer_row = debug_tile_row * 16 + ref_scale_uint8 = ref_scale[ref_scale_buffer_row, debug_tile_col].item() + ref_scale_fp32 = ref_scale[ref_scale_buffer_row, debug_tile_col].view(torch.float8_e4m3fn).to(torch.float32).item() + print(f"\n=== Reference (CUDA kernel output) for tile [{debug_tile_row},{debug_tile_col}] ===") + print(f" Buffer position: ref_scale[{ref_scale_buffer_row},{debug_tile_col}]") + print(f" uint8={ref_scale_uint8}, FP32={ref_scale_fp32}") + print(f" (CUDA debug should print buffer_row={ref_scale_buffer_row}, tile_col={debug_tile_col})") + # === Our implementation: Replicate utils.py logic === h, w = shape tile_h = math.ceil(h / block_len) @@ -1557,35 +1650,67 @@ def test_scale_computation_matches_quantizer(): # Step 2: Compute scales fp4_max = 6.0 fp8_max = 448.0 - finfo = torch.finfo(torch.float32) + finfo = torch.finfo(torch.float32) tiny = finfo.tiny - safe_global_amax = torch.clamp(global_amax, min=tiny) - global_encode_scale = torch.clamp((fp8_max * fp4_max) / safe_global_amax, max=finfo.max) - global_scale = global_encode_scale.item() + # Test: use FP32 tensors for constants instead of Python floats + fp4_max_t = torch.tensor(6.0, dtype=torch.float32, device=device) + fp8_max_t = torch.tensor(448.0, dtype=torch.float32, device=device) - print(f"global_encode_scale: {global_scale}") + # Try PyTorch with FP32 tensor constants + global_encode_scale_pytorch = (fp8_max_t * fp4_max_t) / global_amax + print(f"\n=== Testing PyTorch with FP32 tensor constants ===") + print(f" fp8_max_t * fp4_max_t = {(fp8_max_t * fp4_max_t).item()}") + print(f" global_amax = {global_amax.item()}") + print(f" PyTorch FP32 tensors: (448*6)/amax = {global_encode_scale_pytorch.item():.10f}") + print(f" CUDA kernel expects: 537.5999755859") + print(f" Match: {abs(global_encode_scale_pytorch.item() - 537.5999755859) < 1e-6}") - # per_block_decode_scale = amax / fp4_max * global_scale - # CUDA computes: amax / 6.0 * global_scale (division first, then multiply) - # Python was: amax * (1.0 / 6.0) * global_scale (multiply by reciprocal) - # Try matching CUDA order: - per_block_decode_scale_cuda_order = torch.clamp( - (amax_tensor / fp4_max) * global_scale, max=finfo.max - ) - per_block_decode_scale_python_order = torch.clamp( - (amax_tensor * (1.0 / fp4_max)) * global_scale, max=finfo.max + # === USE PURE PYTORCH FP32 TENSORS === + # Key insight: Python float literals are float64, which causes precision differences. + # Using torch.float32 tensors for constants ensures FP32 computation throughout. 
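# Illustrative sketch (not part of the patch): quantifies why the operation order
# and the FP32-only arithmetic above matter before the FP8 cast. The input range
# and the S_enc stand-in below are arbitrary, not taken from the test. Over many
# float32 inputs, x / 6.0 * s and x * (1.0 / 6.0) * s occasionally differ by one
# ULP, and a one-ULP difference can flip the E4M3 code when the value sits on a
# rounding boundary.
import torch

x = torch.linspace(0.01, 6.0, 200_001, dtype=torch.float32)   # stand-in block amaxes
s = torch.tensor(300.0, dtype=torch.float32)                  # stand-in for S_enc
divide_first = (x / 6.0) * s                                  # order used by the CUDA kernel
reciprocal = (x * (1.0 / 6.0)) * s                            # order used by the old Python path
ulp_diffs = (divide_first != reciprocal).sum().item()
fp8_diffs = (divide_first.to(torch.float8_e4m3fn).view(torch.uint8)
             != reciprocal.to(torch.float8_e4m3fn).view(torch.uint8)).sum().item()
print(f"float32 mismatches: {ulp_diffs}, E4M3 code mismatches: {fp8_diffs}")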
+ + # Compute global_encode_scale using PyTorch FP32 tensors + safe_global_amax_t = torch.clamp(global_amax, min=tiny) + global_encode_scale_t = torch.clamp((fp8_max_t * fp4_max_t) / safe_global_amax_t, max=finfo.max) + global_scale_t = global_encode_scale_t.item() + + print(f"\n=== Comparing global_encode_scale computation (PyTorch FP32 tensors) ===") + print(f" global_amax: {global_amax.item()}") + print(f" PyTorch FP32: (448*6)/amax = {global_scale_t}") + print(f" CUDA kernel reports S_enc = 537.5999755859") + print(f" Match: {abs(global_scale_t - 537.5999755859) < 1e-6}") + + # Compute per_block_decode_scale using PyTorch FP32 tensors + # Note: amax_tensor is already FP32, and fp4_max_t is FP32 tensor + per_block_decode_scale_t = torch.clamp( + (amax_tensor / fp4_max_t) * global_encode_scale_t, max=finfo.max ) - print(f"global_scale (should be 2688/5 = 537.6): {global_scale}") - print(f"\nComparing CUDA vs Python order:") - print(f" CUDA order [0,:5]: {per_block_decode_scale_cuda_order[0,:5].tolist()}") - print(f" Python order [0,:5]: {per_block_decode_scale_python_order[0,:5].tolist()}") - print(f" Difference: {(per_block_decode_scale_cuda_order - per_block_decode_scale_python_order).abs().max().item()}") - # Use CUDA order for the rest of the test - per_block_decode_scale = per_block_decode_scale_cuda_order + # Print Python-side values for specific tile to compare with CUDA + debug_tile_row, debug_tile_col = 1, 83 + block_amax_for_tile = amax_tensor[debug_tile_row, debug_tile_col] + print(f"\n amax_tensor shape: {amax_tensor.shape}") + print(f" amax_tensor[{debug_tile_row},{debug_tile_col}] = {block_amax_for_tile.item()}") + + # Compute S_dec_b using PyTorch FP32 tensors exactly like CUDA + s_dec_b_t = (block_amax_for_tile / fp4_max_t) * global_encode_scale_t + print(f"\n=== Python-side values for tile [{debug_tile_row},{debug_tile_col}] (using PyTorch FP32 tensors) ===") + print(f" block_amax = {block_amax_for_tile.item()}") + print(f" global_scale_t (S_enc) = {global_encode_scale_t.item()}") + print(f" fp4_max_t = {fp4_max_t.item()}") + print(f" S_dec_b = amax/fp4_max*S_enc = {s_dec_b_t.item()}") - print(f"per_block_decode_scale shape: {per_block_decode_scale.shape}") + # Convert to FP8 + s_dec_b_fp8 = s_dec_b_t.to(torch.float8_e4m3fn) + print(f" FP8 result: FP32={s_dec_b_fp8.to(torch.float32).item()}, uint8={s_dec_b_fp8.view(torch.uint8).item()}") + print(f" CUDA reported: S_dec_b=336.0, FP32=320.0, uint8=122") + print(f" MATCH: {s_dec_b_fp8.view(torch.uint8).item() == 122}") + + # Use PyTorch FP32 computed per_block_decode_scale + per_block_decode_scale = per_block_decode_scale_t + print(f"\nper_block_decode_scale shape: {per_block_decode_scale.shape}") + print(f"per_block_decode_scale[{debug_tile_row},{debug_tile_col}] = {per_block_decode_scale[debug_tile_row, debug_tile_col].item()}") # Step 3: Expand to target_scale shape (replicate utils.py expansion) # Get the expected scale shape from quantizer (rowwise, not columnwise) @@ -1607,28 +1732,20 @@ def test_scale_computation_matches_quantizer(): break expanded_scale[base_row:row_end, :tile_col_cnt] = per_block_decode_scale[tile_row_idx] - # Convert to FP8 and view as uint8 (this is the suspect operation) + # Convert to FP8 and view as uint8 fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) target_scale.copy_(fp8_view) # === Compare === - print(f"\nComparing scales:") + print(f"\n=== Final comparison (using numpy FP32 computation) ===") print(f"target_scale shape: {target_scale.shape}") print(f"ref_scale 
shape: {ref_scale.shape}") - # Convert ref_scale back to FP32 to see what values it contains - ref_scale_fp32 = ref_scale.view(torch.float8_e4m3fn).to(torch.float32) - - # Check if our amax matches what the quantizer computed internally - # Reverse-engineer the reference amax from ref_scale - # ref_scale (FP8) → ref_scale_fp32 → amax = ref_scale_fp32 * 6.0 / global_scale - ref_amax_reverse = ref_scale_fp32 * fp4_max / global_scale - print(f"\nReverse-engineered ref amax [0,:5]: {ref_amax_reverse[0,:5].tolist()}") - print(f"Our computed amax [0,:5]: {amax_tensor[0,:5].tolist()}") - - # Check if amax values match - amax_match = torch.allclose(amax_tensor, ref_amax_reverse[:tile_h, :tile_w], rtol=0.01) - print(f"Amax roughly matches: {amax_match}") + # Check specific tile + print(f"\nTile [{debug_tile_row},{debug_tile_col}] comparison:") + buffer_row = debug_tile_row * 16 + print(f" target_scale[{buffer_row},{debug_tile_col}] = {target_scale[buffer_row, debug_tile_col].item()}") + print(f" ref_scale[{buffer_row},{debug_tile_col}] = {ref_scale[buffer_row, debug_tile_col].item()}") scale_match = torch.equal(target_scale, ref_scale) print(f"\nScales match exactly: {scale_match}") @@ -1637,11 +1754,23 @@ def test_scale_computation_matches_quantizer(): mismatches = (target_scale != ref_scale).sum().item() total = target_scale.numel() print(f"Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") + + # Find first mismatch + diff = target_scale != ref_scale + if diff.any(): + diff_idx = torch.nonzero(diff, as_tuple=True) + r, c = diff_idx[0][0].item(), diff_idx[1][0].item() + print(f"\nFirst mismatch at [{r},{c}]:") + print(f" target: {target_scale[r,c].item()}") + print(f" ref: {ref_scale[r,c].item()}") + else: + print("\n*** SUCCESS: Numpy FP32 computation matches CUDA kernel exactly! 
***") if __name__ == "__main__": #test_nvfp4_transpose_kernel() #test_nvfp4_partial_cast_matches_full() #test_scale_computation_matches_quantizer() - test_single_gpu_partial_cast_vs_full() - #main() + #test_fp8_rounding_behavior() + #test_single_gpu_partial_cast_vs_full() # Test the PyTorch FP32 fix + main() diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh old mode 100644 new mode 100755 index cff8464903..96dc075169 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -38,7 +38,13 @@ namespace quantization_and_transposition_SF { // Used in transpose variant // Compute per-block E4M3 encoding/decoding scaling factor __device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, - const float S_enc) { + const float S_enc, + int tile_row = -1, + int tile_col = -1) { + // Print kernel name once (only for first tile, first thread) + if (tile_row == 0 && tile_col == 0) { + printf("[CUDA] Using quantization_and_transposition_SF (transpose kernel): block_amax=%.10f, S_enc=%.10f\n", block_amax, S_enc); + } // constexpr float rcp_6f = 1.0f / 6.0f; // const float S_dec_b = block_amax * rcp_6f; // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); @@ -48,7 +54,16 @@ __device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const f using namespace detail; constexpr float fp4_max = TypeExtrema::max; // 6.0f; const float S_dec_b = block_amax / fp4_max * S_enc; - return static_cast(fminf(S_dec_b, TypeExtrema::max)); + nvfp4_scale_t result = static_cast(fminf(S_dec_b, TypeExtrema::max)); + // Debug: print for specific tile + // tile_row is actually buffer row (0-2047), not tile index (0-127) + // For tile [1,83], buffer row is 16 (first row of tile 1) + if (tile_row == 16 && tile_col == 83) { + printf("[CUDA transpose] buffer_row=%d, tile_col=%d (tile_row=%d): block_amax=%.10f, S_enc=%.10f, fp4_max=%.10f, S_dec_b=%.10f, result_fp32=%.10f, result_uint8=%u\n", + tile_row, tile_col, tile_row / 16, block_amax, S_enc, fp4_max, S_dec_b, + static_cast(result), static_cast(*reinterpret_cast(&result))); + } + return result; } #endif // FP4_TYPE_SUPPORTED } // namespace quantization_and_transposition_SF @@ -58,12 +73,28 @@ namespace quantization_SF { // Used in non-transpose variant // Compute per-block E4M3 encoding/decoding scaling factor __device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, - const float S_enc) { + const float S_enc, + int tile_row = -1, + int tile_col = -1) { + // Print kernel name once (only for first tile, first thread) + if (tile_row == 0 && tile_col == 0) { + printf("[CUDA] Using quantization_SF (non-transpose kernel): block_amax=%.10f, S_enc=%.10f\n", block_amax, S_enc); + } constexpr float rcp_6f = 1.0f / 6.0f; // const float S_dec_b = block_amax * rcp_6f; // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); // return S_dec_b_fp8; - return static_cast(block_amax * rcp_6f * S_enc); + float S_dec_b = block_amax * rcp_6f * S_enc; + fp8e4m3 result = static_cast(S_dec_b); + // Debug: print for specific tile + // tile_row is actually buffer row (0-2047), not tile index (0-127) + // For tile [1,83], buffer row is 16 (first row of tile 1) + if (tile_row == 16 && tile_col == 83) { + printf("[CUDA quantize] buffer_row=%d, tile_col=%d (tile_row=%d): block_amax=%.10f, S_enc=%.10f, rcp_6f=%.10f, S_dec_b=%.10f, result_fp32=%.10f, result_uint8=%u\n", + tile_row, 
tile_col, tile_row / 16, block_amax, S_enc, rcp_6f, S_dec_b, + static_cast(result), static_cast(*reinterpret_cast(&result))); + } + return result; } #endif // FP4_TYPE_SUPPORTED } // namespace quantization_SF diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh old mode 100644 new mode 100755 index 83ad8fd40b..8eda56b33a --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -413,14 +413,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } // 2. Compute E4M3 scaling factor - const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + // Compute tile indices for debug + const int debug_tile_row = scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int debug_tile_col = scales_offset_X_rowwise; + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc, debug_tile_row, debug_tile_col); #if DIRECT_SCALING_FACTORS_STORE // Check boundaries if (rowwise_scale_is_within_bounds) { - const int scales_offset_Y = - scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = scales_offset_X_rowwise; + const int scales_offset_Y = debug_tile_row; + const int scales_offset_X = debug_tile_col; const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; } diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh old mode 100644 new mode 100755 index 7322bf2655..5e801e13c4 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -335,8 +335,11 @@ __global__ void __launch_bounds__(THREADS_NUM) } } // 2. Compute E4M3 scaling factor + // Compute tile indices for colwise (transpose direction) + const size_t colwise_tile_row = scales_offset_Y_t + stage * ITERATIONS_TRANSPOSE + it; + const size_t colwise_tile_col = scales_offset_X_t; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise); + compute_decoding_scaling_factor(block_amax, S_enc_colwise, static_cast(colwise_tile_row), static_cast(colwise_tile_col)); // Store scaling factors through SHMEM const size_t scale_idx_sh = @@ -508,13 +511,14 @@ __global__ void __launch_bounds__(THREADS_NUM) } // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - - // Check boundaries + // Compute tile indices first for debug const size_t scales_offset_Y = scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; const size_t scales_offset_X = scales_offset_X_rowwise; + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise, static_cast(scales_offset_Y), static_cast(scales_offset_X)); + + // Check boundaries const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; @@ -916,8 +920,11 @@ __global__ void __launch_bounds__(THREADS_NUM) } // 2. 
Compute E4M3 scaling factor + // Compute tile indices for colwise (transpose direction) in 2D kernel + const size_t colwise_tile_row_2d = scales_offset_Y_t + stage * ITERATIONS_TRANSPOSE + it; + const size_t colwise_tile_col_2d = scales_offset_X_t; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise); + compute_decoding_scaling_factor(block_amax, S_enc_colwise, static_cast(colwise_tile_row_2d), static_cast(colwise_tile_col_2d)); // // Store scaling factors through SHMEM const size_t scale_idx_sh = @@ -1041,14 +1048,15 @@ __global__ void __launch_bounds__(THREADS_NUM) } // 2. Compute E4M3 scaling factor + // Compute tile indices first for debug + const size_t scales_offset_Y_2d = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X_2d = scales_offset_X_rowwise; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + compute_decoding_scaling_factor(block_amax, S_enc_rowwise, static_cast(scales_offset_Y_2d), static_cast(scales_offset_X_2d)); // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + const size_t scale_idx_global = scales_offset_Y_2d * scale_stride + scales_offset_X_2d; // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; const bool rowwise_scale_is_within_bounds_Y = diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 17f216c573..241b09b5db 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -6,6 +6,7 @@ from typing import Optional, Union, List import math +import os import torch import transformer_engine_torch as tex @@ -596,12 +597,16 @@ def _cast_master_weights_to_nvfp4_2d( global_scale_tensor = global_amaxes.clone() if len(amaxes) > 0: finfo = torch.finfo(torch.float32) - fp4_max = 6.0 - fp8_max = 448.0 tiny = finfo.tiny + # Use FP32 tensors for constants to match CUDA kernel's FP32 computation. + # Python float literals are float64, which causes precision differences. + # Example: 2688/5 = 537.6000366211 (PyTorch with float64) vs 537.5999755859 (CUDA FP32) + fp4_max_t = torch.tensor(6.0, dtype=torch.float32, device=device) + fp8_max_t = torch.tensor(448.0, dtype=torch.float32, device=device) + safe_global_amax = torch.clamp(global_amaxes, min=tiny) - global_encode_scales = torch.clamp((fp8_max * fp4_max) / safe_global_amax, max=finfo.max) + global_encode_scales = torch.clamp((fp8_max_t * fp4_max_t) / safe_global_amax, max=finfo.max) global_encode_scales = torch.where( global_amaxes > 0, global_encode_scales, torch.ones_like(global_encode_scales) ) @@ -612,13 +617,31 @@ def _cast_master_weights_to_nvfp4_2d( ) global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - for amax_tensor, scale_tensor, global_scale in zip( + for layer_idx, (amax_tensor, scale_tensor, global_scale) in enumerate(zip( amaxes, scales, global_scale_views - ): + )): + # Use FP32 tensor division to match CUDA kernel exactly. 
+ # CUDA computes: float S_dec_b = block_amax / fp4_max * S_enc; per_block_decode_scale = torch.clamp( - (amax_tensor * (1.0 / fp4_max)) * global_scale, max=finfo.max + (amax_tensor / fp4_max_t) * global_scale, max=finfo.max ) scale_tensor.copy_(per_block_decode_scale) + import os + # Debug: print specific tile values if env var is set + debug_tile_row = os.environ.get("NVFP4_DEBUG_TILE_ROW") + debug_tile_col = os.environ.get("NVFP4_DEBUG_TILE_COL") + if debug_tile_row is not None and debug_tile_col is not None and layer_idx == 0: + tr, tc = int(debug_tile_row), int(debug_tile_col) + if tr < amax_tensor.shape[0] and tc < amax_tensor.shape[1]: + print(f"\n[utils.py DEBUG] Tile [{tr},{tc}] computation (using PyTorch FP32 tensors):") + print(f" amax_tensor[{tr},{tc}] = {amax_tensor[tr,tc].item()}") + print(f" global_scale = {global_scale.item()}") + print(f" fp4_max_t = {fp4_max_t.item()}") + intermediate = amax_tensor[tr,tc] / fp4_max_t + print(f" amax / fp4_max = {intermediate.item()}") + result = intermediate * global_scale + print(f" (amax / fp4_max) * global_scale = {result.item()}") + print(f" per_block_decode_scale[{tr},{tc}] = {per_block_decode_scale[tr,tc].item()}") else: global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] @@ -665,6 +688,24 @@ def _cast_master_weights_to_nvfp4_2d( ] fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) target_scale.copy_(fp8_view) + + # Debug: print specific tile values after FP8 conversion + import os + debug_tile_row = os.environ.get("NVFP4_DEBUG_TILE_ROW") + debug_tile_col = os.environ.get("NVFP4_DEBUG_TILE_COL") + if debug_tile_row is not None and debug_tile_col is not None and idx == 0: + tr, tc = int(debug_tile_row), int(debug_tile_col) + # The scale buffer row is tile_row * 16 + offset within tile + scale_row = tr * block_len + if scale_row < target_scale.shape[0] and tc < target_scale.shape[1]: + print(f"\n[utils.py DEBUG] After FP8 conversion for tile [{tr},{tc}]:") + print(f" expanded_scale[{scale_row},{tc}] (FP32) = {expanded_scale[scale_row,tc].item()}") + print(f" fp8_view[{scale_row},{tc}] (uint8) = {fp8_view[scale_row,tc].item()}") + print(f" target_scale[{scale_row},{tc}] (uint8) = {target_scale[scale_row,tc].item()}") + # Also show as FP32 + fp8_as_fp32 = fp8_view[scale_row,tc].view(torch.float8_e4m3fn).to(torch.float32).item() + print(f" fp8_view[{scale_row},{tc}] as FP32 = {fp8_as_fp32}") + if target_amax is not None: target_amax.copy_(global_amaxes[idx : idx + 1]) @@ -672,24 +713,24 @@ def _cast_master_weights_to_nvfp4_2d( if master_weight is None or master_weight.numel() == 0: continue - # Debug: compare scales with fresh quantizer for full cast (start_offset=0) - if start_offset == 0 and idx == 2: - rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - print(f"[Rank {rank}] Layer {idx} SCALE DEBUG (start_offset=0):") - print(f"[Rank {rank}] per_block_decode_scale shape: {per_block_decode_scale.shape}") - print(f"[Rank {rank}] target_scale shape: {target_scale.shape}") - # Compare with fresh quantizer - use inline import to avoid scoping issues - from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer as FreshNVFP4Quantizer - fresh_q = FreshNVFP4Quantizer(with_2d_quantization=True) - ref = fresh_q(master_weight.reshape(model_weight.shape)) - print(f"[Rank {rank}] ref._rowwise_scale_inv shape: {ref._rowwise_scale_inv.shape}") - # Check if scales match - scale_match = torch.equal(target_scale, ref._rowwise_scale_inv) - print(f"[Rank 
{rank}] target_scale == ref._rowwise_scale_inv: {scale_match}") - if not scale_match: - mismatches = (target_scale != ref._rowwise_scale_inv).sum().item() - total = target_scale.numel() - print(f"[Rank {rank}] mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") + # # Debug: compare scales with fresh quantizer for full cast (start_offset=0) + # if start_offset == 0 and idx == 2: + # rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + # print(f"[Rank {rank}] Layer {idx} SCALE DEBUG (start_offset=0):") + # print(f"[Rank {rank}] per_block_decode_scale shape: {per_block_decode_scale.shape}") + # print(f"[Rank {rank}] target_scale shape: {target_scale.shape}") + # # Compare with fresh quantizer - use inline import to avoid scoping issues + # from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer as FreshNVFP4Quantizer + # fresh_q = FreshNVFP4Quantizer(with_2d_quantization=True) + # ref = fresh_q(master_weight.reshape(model_weight.shape)) + # print(f"[Rank {rank}] ref._rowwise_scale_inv shape: {ref._rowwise_scale_inv.shape}") + # # Check if scales match + # scale_match = torch.equal(target_scale, ref._rowwise_scale_inv) + # print(f"[Rank {rank}] target_scale == ref._rowwise_scale_inv: {scale_match}") + # if not scale_match: + # mismatches = (target_scale != ref._rowwise_scale_inv).sum().item() + # total = target_scale.numel() + # print(f"[Rank {rank}] mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") end_offset = start_offset + master_weight.numel() if not use_fsdp_shard_model_weights: From 433b1dbc2c46a0a2c2b120087eb1bb8bf4dcaa02 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 11 Dec 2025 21:42:38 +0000 Subject: [PATCH 20/40] clean up debugging code --- .../test_cast_master_weights_to_fp8.py | 448 +----------------- .../common/cast/nvfp4/core_nvfp4.cuh | 39 +- .../common/cast/nvfp4/quantize_nvfp4.cuh | 10 +- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 30 +- transformer_engine/pytorch/tensor/utils.py | 61 +-- 5 files changed, 26 insertions(+), 562 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 5812d77f0f..d97d982461 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -124,7 +124,6 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.offsets = [0] for weight in self.weights: self.offsets.append(self.offsets[-1] + weight.numel()) - print(f"offsets: {self.offsets}") # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may # not be the end range of the last weight. if self.offsets[-1] % self.world_size != 0: @@ -161,7 +160,8 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals # The start and end of this rank's local buffer in the global buffer rank_start = self.offsets[-1] // self.world_size * self.rank rank_end = rank_start + self.offsets[-1] // self.world_size - print(f"current rank: {self.rank}, rank_start: {rank_start}, rank_end: {rank_end}") + + # Needed for NVFP4 tensors which packs two values per byte. 
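# Illustrative sketch (not part of the patch): NVFP4 packs two 4-bit values per
# byte, so an element range [start, end) of a weight maps to the byte range
# [start // 2, (end + 1) // 2) of its packed rowwise storage. The helper below is
# hypothetical and only mirrors that bookkeeping arithmetic.
def element_range_to_byte_range(start, end):
    assert 0 <= start <= end
    byte_start = start // 2        # elements 2k and 2k + 1 share byte k
    byte_end = (end + 1) // 2      # round up so the last element's byte is included
    return byte_start, byte_end

print(element_range_to_byte_range(0, 8))    # (0, 4)
print(element_range_to_byte_range(5, 12))   # (2, 6)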
storage_rank_start = None storage_rank_end = None if self.weights_are_nvfp4: @@ -317,19 +317,6 @@ def step(self): storage_len = storage_overlap[1] - storage_overlap[0] weight_slice = weight[storage_start : storage_start + storage_len] overlapping_start, overlapping_end = storage_overlap - buffer_len = overlapping_end - overlapping_start - slice_len = weight_slice.numel() - if buffer_len != slice_len: - print( - "[MiniZero_1] copy mismatch:", - f"idx={i}", - f"buffer_len={buffer_len}", - f"slice_len={slice_len}", - f"weight_shape={tuple(weight.shape)}", - f"storage_start={storage_start}", - f"storage_len={storage_len}", - f"overlap=({overlapping_start},{overlapping_end})", - ) self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) continue elif isinstance(self.weights[i], QuantizedTensor): @@ -338,19 +325,6 @@ def step(self): weight = self.weights[i] weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] overlapping_start, overlapping_end = self.overlapping_areas[i] - buffer_len = overlapping_end - overlapping_start - slice_len = weight_slice.numel() - if buffer_len != slice_len: - print( - "[MiniZero_1] copy mismatch:", - f"idx={i}", - f"buffer_len={buffer_len}", - f"slice_len={slice_len}", - f"weight_shape={tuple(weight.shape)}", - f"start_offset={start_offset}", - f"master_numel={master_weight.numel()}", - f"overlap=({overlapping_start},{overlapping_end})", - ) self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) # ----------------------------------------------------------------------------------------- @@ -819,8 +793,8 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi rank = dist.get_rank(dp_group) world_size = dist.get_world_size(dp_group) - torch.manual_seed(71784) - torch.cuda.manual_seed(71784) + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] mock_group = mock_groups[rank] @@ -844,7 +818,7 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi te.Linear(256 * 3, 128, **linear_kwargs), ) - # Use 2048x2048 weights to avoid NVFP4 scale_inv padding issues + # Use 2048x2048 weights shape for testing. 
# with te.quantized_model_init( # enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True # ): @@ -874,20 +848,6 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi ) optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) - # Add hooks to capture intermediate activations - activations_nvfp4 = {} - activations_bf16 = {} - - def make_hook(storage, name): - def hook(module, input, output): - storage[name] = (input[0].clone(), output.clone()) - return hook - - hooks = [] - for idx, (layer_nvfp4, layer_bf16) in enumerate(zip(model_nvfp4, model)): - hooks.append(layer_nvfp4.register_forward_hook(make_hook(activations_nvfp4, f"layer_{idx}"))) - hooks.append(layer_bf16.register_forward_hook(make_hook(activations_bf16, f"layer_{idx}"))) - for i in range(100): for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): w_nvfp4.main_grad.zero_() @@ -899,24 +859,6 @@ def hook(module, input, output): ] x = inputs[rank] - # Debug: compare master weights before forward - if i in [0, 1, 7, 8] and rank == 0: - print(f"\n=== Debug iteration {i} ===") - for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): - # Compare master weights - master_nvfp4 = optimizer_nvfp4.master_weights[idx] - master_bf16 = optimizer.master_weights[idx] - if master_nvfp4 is None or master_bf16 is None: - print(f"Layer {idx}: master weights = None (nvfp4={master_nvfp4 is not None}, bf16={master_bf16 is not None})") - continue - print(f"Layer {idx}: master_nvfp4.dtype={master_nvfp4.dtype}, master_bf16.dtype={master_bf16.dtype}") - print(f"Layer {idx}: w_nvfp4.dtype={w_nvfp4.dtype}, w.dtype={w.dtype}") - master_match = torch.equal(master_nvfp4, master_bf16) - print(f"Layer {idx}: master weights match = {master_match}") - if not master_match: - diff = (master_nvfp4 - master_bf16).abs().max().item() - print(f" max diff = {diff}") - with te.autocast( enabled=True, recipe=nvfp4_recipe, @@ -931,187 +873,20 @@ def hook(module, input, output): ): y = model(x) - # Debug: compare forward outputs and weight properties - if i == 0 and rank == 0: - print(f"\n=== Forward outputs iteration {i} ===") - print(f"y_nvfp4 shape: {y_nvfp4.shape}, y shape: {y.shape}") - y_match = torch.equal(y_nvfp4, y) - print(f"Forward outputs match: {y_match}") - if not y_match: - diff = (y_nvfp4 - y).abs().max().item() - print(f" max diff: {diff}") - - # Compare intermediate activations - print("\n=== Intermediate activations ===") - for layer_name in activations_nvfp4.keys(): - inp_nvfp4, out_nvfp4 = activations_nvfp4[layer_name] - inp_bf16, out_bf16 = activations_bf16[layer_name] - inp_match = torch.equal(inp_nvfp4, inp_bf16) - out_match = torch.equal(out_nvfp4, out_bf16) - print(f"{layer_name}: input match = {inp_match}, output match = {out_match}") - if not inp_match: - diff = (inp_nvfp4 - inp_bf16).abs().max().item() - print(f" input max diff: {diff}") - if not out_match: - diff = (out_nvfp4 - out_bf16).abs().max().item() - print(f" output max diff: {diff}") - - # Compare quantizer states - print("\n=== Quantizer comparison ===") - for idx, (layer_nvfp4, layer_bf16) in enumerate(zip(model_nvfp4, model)): - print(f"Layer {idx}:") - print(f" nvfp4.fp8: {layer_nvfp4.fp8}, bf16.fp8: {layer_bf16.fp8}") - print(f" nvfp4.fp8_initialized: {layer_nvfp4.fp8_initialized}, bf16.fp8_initialized: {layer_bf16.fp8_initialized}") - if hasattr(layer_nvfp4, 'quantizers') and hasattr(layer_bf16, 'quantizers'): - q_nvfp4 = layer_nvfp4.quantizers.get('scaling_fwd', []) - q_bf16 
= layer_bf16.quantizers.get('scaling_fwd', []) - if q_nvfp4: - print(f" nvfp4 input quantizer type: {type(q_nvfp4[0])}") - if q_bf16: - print(f" bf16 input quantizer type: {type(q_bf16[0])}") - - # Compare NVFP4 tensor properties - print("\n=== NVFP4 weight properties ===") - for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): - print(f"Layer {idx}:") - print(f" w_nvfp4 type: {type(w_nvfp4).__name__}, w type: {type(w).__name__}") - if hasattr(w_nvfp4, '_amax_rowwise') and w_nvfp4._amax_rowwise is not None: - print(f" w_nvfp4._amax_rowwise: {w_nvfp4._amax_rowwise.item()}") - if hasattr(w_nvfp4, '_amax_columnwise') and w_nvfp4._amax_columnwise is not None: - print(f" w_nvfp4._amax_columnwise: {w_nvfp4._amax_columnwise.item()}") - # Compare dequantized values - if hasattr(w_nvfp4, 'dequantize'): - w_nvfp4_dequant = w_nvfp4.dequantize(dtype=torch.bfloat16) - dequant_match = torch.equal(w_nvfp4_dequant, w) - print(f" dequant(w_nvfp4) == w: {dequant_match}") - if not dequant_match: - diff = (w_nvfp4_dequant - w).abs().max().item() - print(f" max diff: {diff}") - targets = [torch.randn_like(y) for _ in range(world_size)] target = targets[rank] loss_nvfp4 = nn.MSELoss()(y_nvfp4, target) loss = nn.MSELoss()(y, target) - # Debug: check if losses are identical - if i == 0 and rank == 0: - print(f"\n=== Loss comparison iteration {i} ===") - print(f"loss_nvfp4: {loss_nvfp4.item()}, loss: {loss.item()}") - print(f"Losses bitwise equal: {torch.equal(loss_nvfp4, loss)}") - loss_nvfp4.backward() loss.backward() - # Debug: compare gradients before optimizer step - if i == 0 and rank == 0: - print(f"\n=== Gradients before step iteration {i} ===") - for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): - grad_nvfp4 = w_nvfp4.main_grad - grad_bf16 = w.main_grad - grad_match = torch.equal(grad_nvfp4, grad_bf16) - print(f"Layer {idx}: gradients match = {grad_match}") - if not grad_match: - diff = (grad_nvfp4 - grad_bf16).abs().max().item() - print(f" max diff: {diff}") - - # Test: run same model twice to check for non-determinism - print("\n=== Determinism test: run model_nvfp4 twice ===") - for w_nvfp4 in model_nvfp4.parameters(): - w_nvfp4.main_grad.zero_() - with te.autocast(enabled=True, recipe=nvfp4_recipe, amax_reduction_group=mock_group): - y_test1 = model_nvfp4(x) - loss_test1 = nn.MSELoss()(y_test1, target) - loss_test1.backward() - grads_run1 = [w.main_grad.clone() for w in model_nvfp4.parameters()] - - for w_nvfp4 in model_nvfp4.parameters(): - w_nvfp4.main_grad.zero_() - with te.autocast(enabled=True, recipe=nvfp4_recipe, amax_reduction_group=mock_group): - y_test2 = model_nvfp4(x) - loss_test2 = nn.MSELoss()(y_test2, target) - loss_test2.backward() - grads_run2 = [w.main_grad.clone() for w in model_nvfp4.parameters()] - - for idx, (g1, g2) in enumerate(zip(grads_run1, grads_run2)): - match = torch.equal(g1, g2) - print(f"Layer {idx}: same model, 2 runs match = {match}") - if not match: - diff = (g1 - g2).abs().max().item() - print(f" max diff: {diff}") - optimizer.step() optimizer_nvfp4.step() - - # Debug: compare weights after optimizer step (on all ranks) - if i == 0: - print(f"\n=== After optimizer step iteration {i} (rank {rank}) ===") - for idx, (w_nvfp4, w) in enumerate(zip(model_nvfp4.parameters(), model.parameters())): - # Compare master weights - master_nvfp4 = optimizer_nvfp4.master_weights[idx] - master_bf16 = optimizer.master_weights[idx] - if master_nvfp4 is not None and master_bf16 is not None: - master_match = 
torch.equal(master_nvfp4, master_bf16) - print(f"Layer {idx}: master weights match = {master_match}") - if not master_match: - diff = (master_nvfp4 - master_bf16).abs().max().item() - print(f" max diff: {diff}") - else: - print(f"Layer {idx}: master weights = None") - - # Compare model weights: quantize BF16 and compare with NVFP4 - if hasattr(w_nvfp4, '_rowwise_data') and hasattr(w_nvfp4, '_quantizer'): - # Create a fresh quantizer with same config to avoid state issues - from transformer_engine.pytorch.tensor import NVFP4Quantizer - fresh_quantizer = NVFP4Quantizer(with_2d_quantization=True) - w_bf16_quantized = fresh_quantizer(w) - - # # Debug: compare amax for layer 2 - # if idx == 2 and i == 0: - # print(f"[Rank {rank}] Layer {idx} DEBUG:") - # # Compare BF16 values: NVFP4 model's master weight vs BF16 model's weight - # master_nvfp4 = optimizer_nvfp4.master_weights[idx] - # master_bf16 = optimizer.master_weights[idx] - # if master_nvfp4 is not None and master_bf16 is not None: - # # Both ranks should have the master weight if they own this layer - # bf16_from_master_nvfp4 = master_nvfp4.to(w_nvfp4.dtype).view(w.shape) - # print(f"[Rank {rank}] BF16 weight w == master_bf16.to(bf16): {torch.equal(w, master_bf16.to(w.dtype).view(w.shape))}") - # print(f"[Rank {rank}] BF16 weight w == master_nvfp4.to(bf16): {torch.equal(w, bf16_from_master_nvfp4)}") - # else: - # print(f"[Rank {rank}] master_nvfp4={master_nvfp4 is not None}, master_bf16={master_bf16 is not None}") - # print(f"[Rank {rank}] w_nvfp4._amax_rowwise: {w_nvfp4._amax_rowwise}") - # print(f"[Rank {rank}] w_bf16_quantized._amax_rowwise: {w_bf16_quantized._amax_rowwise}") - # # Sample of scales - # print(f"[Rank {rank}] w_nvfp4._rowwise_scale_inv[0,:8]: {w_nvfp4._rowwise_scale_inv[0,:8].tolist()}") - # print(f"[Rank {rank}] w_bf16_quantized._rowwise_scale_inv[0,:8]: {w_bf16_quantized._rowwise_scale_inv[0,:8].tolist()}") - # # Check where scales differ - # scale_diff = (w_nvfp4._rowwise_scale_inv != w_bf16_quantized._rowwise_scale_inv) - # if scale_diff.any(): - # diff_indices = torch.nonzero(scale_diff, as_tuple=True) - # print(f"[Rank {rank}] First 5 scale diff positions: {list(zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()))}") - # for r, c in zip(diff_indices[0][:5].tolist(), diff_indices[1][:5].tolist()): - # print(f"[Rank {rank}] [{r},{c}]: nvfp4={w_nvfp4._rowwise_scale_inv[r,c].item()}, ref={w_bf16_quantized._rowwise_scale_inv[r,c].item()}") - - # Compare raw NVFP4 data - data_match = torch.equal(w_nvfp4._rowwise_data, w_bf16_quantized._rowwise_data) - print(f"Layer {idx}: _rowwise_data match = {data_match}") - if not data_match: - # Count mismatches - mismatches = (w_nvfp4._rowwise_data != w_bf16_quantized._rowwise_data).sum().item() - total = w_nvfp4._rowwise_data.numel() - print(f" mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") - - # Compare scales - scale_match = torch.equal(w_nvfp4._rowwise_scale_inv, w_bf16_quantized._rowwise_scale_inv) - print(f"Layer {idx}: _rowwise_scale_inv match = {scale_match}") - if not scale_match: - mismatches = (w_nvfp4._rowwise_scale_inv != w_bf16_quantized._rowwise_scale_inv).sum().item() - total = w_nvfp4._rowwise_scale_inv.numel() - print(f" mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) print("iter:", i, "loss matched") - def run_parallel_tests() -> None: """Run parallel tests""" @@ -1413,31 +1188,6 @@ def test_nvfp4_partial_cast_matches_full() -> None: print(f"[Rank 
{WORLD_RANK}] Multi-GPU NVFP4 partial cast test PASSED!") -def test_fp8_rounding_behavior(): - """Test PyTorch's FP8 rounding behavior for tie-breaking cases.""" - device = torch.device("cuda") - - # 336.0 is exactly halfway between 320.0 and 352.0 in FP8 E4M3 - test_val = torch.tensor([336.0000228881836], dtype=torch.float32, device=device) - fp8_val = test_val.to(torch.float8_e4m3fn) - back_to_fp32 = fp8_val.to(torch.float32) - - print(f"FP8 rounding test:") - print(f" Input FP32: {test_val.item()}") - print(f" FP8 as uint8: {fp8_val.view(torch.uint8).item()}") - print(f" Back to FP32: {back_to_fp32.item()}") - print(f" Expected round-to-even: 320.0 (uint8=122)") - print(f" PyTorch rounds to: {'even (320)' if back_to_fp32.item() == 320.0 else 'odd (352)'}") - - # Test a few more boundary cases - test_vals = [320.0, 336.0, 352.0, 304.0, 368.0] - for v in test_vals: - t = torch.tensor([v], dtype=torch.float32, device=device) - fp8 = t.to(torch.float8_e4m3fn) - back = fp8.to(torch.float32) - print(f" {v} -> FP8({fp8.view(torch.uint8).item()}) -> {back.item()}") - - def test_single_gpu_partial_cast_vs_full(): """ Single GPU test: compare cast_master_weights_to_nvfp4 (offset=0) vs quantizer(). @@ -1584,193 +1334,5 @@ def test_single_gpu_partial_cast_vs_full(): print("\nFAILURE: Results don't match!") -def test_scale_computation_matches_quantizer(): - """ - Test that our Python scale computation in utils.py matches what NVFP4Quantizer produces. - This isolates the scale computation issue outside of the optimizer. - """ - import math - from transformer_engine.pytorch.tensor import NVFP4Quantizer - from transformer_engine.common.recipe import NVFP4BlockScaling - import transformer_engine_torch as tex - - torch.manual_seed(77777) - device = torch.device("cuda") - - # Test with 2048x2048 like in the optimizer test - shape = (2048, 2048) - block_len = 16 - - # Create random BF16 tensor (simulating master weight converted to BF16) - master_weight = torch.randn(shape, dtype=torch.bfloat16, device=device) - - # === Reference: Use NVFP4Quantizer === - quantizer = NVFP4Quantizer(with_2d_quantization=True) - ref = quantizer(master_weight) - ref_scale = ref._rowwise_scale_inv.clone() - ref_data = ref._rowwise_data.clone() - ref_amax = ref._amax_rowwise.clone() - - print(f"Reference scale shape: {ref_scale.shape}") - print(f"Reference data shape: {ref_data.shape}") - print(f"Reference amax: {ref_amax}") - - # Print reference scale for specific tile [1, 83] - debug_tile_row, debug_tile_col = 1, 83 - # Scale buffer row = tile_row * 16 (first row of the tile's 16-row block) - ref_scale_buffer_row = debug_tile_row * 16 - ref_scale_uint8 = ref_scale[ref_scale_buffer_row, debug_tile_col].item() - ref_scale_fp32 = ref_scale[ref_scale_buffer_row, debug_tile_col].view(torch.float8_e4m3fn).to(torch.float32).item() - print(f"\n=== Reference (CUDA kernel output) for tile [{debug_tile_row},{debug_tile_col}] ===") - print(f" Buffer position: ref_scale[{ref_scale_buffer_row},{debug_tile_col}]") - print(f" uint8={ref_scale_uint8}, FP32={ref_scale_fp32}") - print(f" (CUDA debug should print buffer_row={ref_scale_buffer_row}, tile_col={debug_tile_col})") - - # === Our implementation: Replicate utils.py logic === - h, w = shape - tile_h = math.ceil(h / block_len) - tile_w = math.ceil(w / block_len) - tile_shape = (tile_h, tile_w) - - print(f"Tile shape: {tile_shape}") - - # Step 1: Compute per-block amax using CUDA kernel - amax_tensor = torch.zeros(tile_shape, dtype=torch.float32, device=device) - global_amax = 
torch.zeros(1, dtype=torch.float32, device=device) - - tex.nvfp4_2d_compute_partial_amax( - master_weight.view(-1), amax_tensor, h, w, 0, block_len - ) - tex.compute_amax(master_weight.view(-1), global_amax) - - print(f"Computed global_amax: {global_amax.item()}") - print(f"Reference global_amax: {ref_amax.item()}") - print(f"Global amax match: {torch.equal(global_amax, ref_amax)}") - - # Step 2: Compute scales - fp4_max = 6.0 - fp8_max = 448.0 - finfo = torch.finfo(torch.float32) - tiny = finfo.tiny - - # Test: use FP32 tensors for constants instead of Python floats - fp4_max_t = torch.tensor(6.0, dtype=torch.float32, device=device) - fp8_max_t = torch.tensor(448.0, dtype=torch.float32, device=device) - - # Try PyTorch with FP32 tensor constants - global_encode_scale_pytorch = (fp8_max_t * fp4_max_t) / global_amax - print(f"\n=== Testing PyTorch with FP32 tensor constants ===") - print(f" fp8_max_t * fp4_max_t = {(fp8_max_t * fp4_max_t).item()}") - print(f" global_amax = {global_amax.item()}") - print(f" PyTorch FP32 tensors: (448*6)/amax = {global_encode_scale_pytorch.item():.10f}") - print(f" CUDA kernel expects: 537.5999755859") - print(f" Match: {abs(global_encode_scale_pytorch.item() - 537.5999755859) < 1e-6}") - - # === USE PURE PYTORCH FP32 TENSORS === - # Key insight: Python float literals are float64, which causes precision differences. - # Using torch.float32 tensors for constants ensures FP32 computation throughout. - - # Compute global_encode_scale using PyTorch FP32 tensors - safe_global_amax_t = torch.clamp(global_amax, min=tiny) - global_encode_scale_t = torch.clamp((fp8_max_t * fp4_max_t) / safe_global_amax_t, max=finfo.max) - global_scale_t = global_encode_scale_t.item() - - print(f"\n=== Comparing global_encode_scale computation (PyTorch FP32 tensors) ===") - print(f" global_amax: {global_amax.item()}") - print(f" PyTorch FP32: (448*6)/amax = {global_scale_t}") - print(f" CUDA kernel reports S_enc = 537.5999755859") - print(f" Match: {abs(global_scale_t - 537.5999755859) < 1e-6}") - - # Compute per_block_decode_scale using PyTorch FP32 tensors - # Note: amax_tensor is already FP32, and fp4_max_t is FP32 tensor - per_block_decode_scale_t = torch.clamp( - (amax_tensor / fp4_max_t) * global_encode_scale_t, max=finfo.max - ) - - # Print Python-side values for specific tile to compare with CUDA - debug_tile_row, debug_tile_col = 1, 83 - block_amax_for_tile = amax_tensor[debug_tile_row, debug_tile_col] - print(f"\n amax_tensor shape: {amax_tensor.shape}") - print(f" amax_tensor[{debug_tile_row},{debug_tile_col}] = {block_amax_for_tile.item()}") - - # Compute S_dec_b using PyTorch FP32 tensors exactly like CUDA - s_dec_b_t = (block_amax_for_tile / fp4_max_t) * global_encode_scale_t - print(f"\n=== Python-side values for tile [{debug_tile_row},{debug_tile_col}] (using PyTorch FP32 tensors) ===") - print(f" block_amax = {block_amax_for_tile.item()}") - print(f" global_scale_t (S_enc) = {global_encode_scale_t.item()}") - print(f" fp4_max_t = {fp4_max_t.item()}") - print(f" S_dec_b = amax/fp4_max*S_enc = {s_dec_b_t.item()}") - - # Convert to FP8 - s_dec_b_fp8 = s_dec_b_t.to(torch.float8_e4m3fn) - print(f" FP8 result: FP32={s_dec_b_fp8.to(torch.float32).item()}, uint8={s_dec_b_fp8.view(torch.uint8).item()}") - print(f" CUDA reported: S_dec_b=336.0, FP32=320.0, uint8=122") - print(f" MATCH: {s_dec_b_fp8.view(torch.uint8).item() == 122}") - - # Use PyTorch FP32 computed per_block_decode_scale - per_block_decode_scale = per_block_decode_scale_t - print(f"\nper_block_decode_scale shape: 
{per_block_decode_scale.shape}") - print(f"per_block_decode_scale[{debug_tile_row},{debug_tile_col}] = {per_block_decode_scale[debug_tile_row, debug_tile_col].item()}") - - # Step 3: Expand to target_scale shape (replicate utils.py expansion) - # Get the expected scale shape from quantizer (rowwise, not columnwise) - target_scale_shape = quantizer.get_scale_shape(shape, columnwise=False) - print(f"Expected target_scale shape: {target_scale_shape}") - - target_scale = torch.zeros(target_scale_shape, dtype=torch.uint8, device=device) - expanded_scale = torch.zeros(target_scale_shape, dtype=torch.float32, device=device) - - tile_rows = tile_h - tile_col_cnt = tile_w - rows = h - chunk = block_len - - for tile_row_idx in range(tile_rows): - base_row = tile_row_idx * chunk - row_end = min(base_row + chunk, rows) - if base_row >= target_scale.shape[0]: - break - expanded_scale[base_row:row_end, :tile_col_cnt] = per_block_decode_scale[tile_row_idx] - - # Convert to FP8 and view as uint8 - fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) - target_scale.copy_(fp8_view) - - # === Compare === - print(f"\n=== Final comparison (using numpy FP32 computation) ===") - print(f"target_scale shape: {target_scale.shape}") - print(f"ref_scale shape: {ref_scale.shape}") - - # Check specific tile - print(f"\nTile [{debug_tile_row},{debug_tile_col}] comparison:") - buffer_row = debug_tile_row * 16 - print(f" target_scale[{buffer_row},{debug_tile_col}] = {target_scale[buffer_row, debug_tile_col].item()}") - print(f" ref_scale[{buffer_row},{debug_tile_col}] = {ref_scale[buffer_row, debug_tile_col].item()}") - - scale_match = torch.equal(target_scale, ref_scale) - print(f"\nScales match exactly: {scale_match}") - - if not scale_match: - mismatches = (target_scale != ref_scale).sum().item() - total = target_scale.numel() - print(f"Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") - - # Find first mismatch - diff = target_scale != ref_scale - if diff.any(): - diff_idx = torch.nonzero(diff, as_tuple=True) - r, c = diff_idx[0][0].item(), diff_idx[1][0].item() - print(f"\nFirst mismatch at [{r},{c}]:") - print(f" target: {target_scale[r,c].item()}") - print(f" ref: {ref_scale[r,c].item()}") - else: - print("\n*** SUCCESS: Numpy FP32 computation matches CUDA kernel exactly! 
***") - - if __name__ == "__main__": - #test_nvfp4_transpose_kernel() - #test_nvfp4_partial_cast_matches_full() - #test_scale_computation_matches_quantizer() - #test_fp8_rounding_behavior() - #test_single_gpu_partial_cast_vs_full() # Test the PyTorch FP32 fix main() diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index 96dc075169..cff8464903 100755 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -38,13 +38,7 @@ namespace quantization_and_transposition_SF { // Used in transpose variant // Compute per-block E4M3 encoding/decoding scaling factor __device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, - const float S_enc, - int tile_row = -1, - int tile_col = -1) { - // Print kernel name once (only for first tile, first thread) - if (tile_row == 0 && tile_col == 0) { - printf("[CUDA] Using quantization_and_transposition_SF (transpose kernel): block_amax=%.10f, S_enc=%.10f\n", block_amax, S_enc); - } + const float S_enc) { // constexpr float rcp_6f = 1.0f / 6.0f; // const float S_dec_b = block_amax * rcp_6f; // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); @@ -54,16 +48,7 @@ __device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const f using namespace detail; constexpr float fp4_max = TypeExtrema::max; // 6.0f; const float S_dec_b = block_amax / fp4_max * S_enc; - nvfp4_scale_t result = static_cast(fminf(S_dec_b, TypeExtrema::max)); - // Debug: print for specific tile - // tile_row is actually buffer row (0-2047), not tile index (0-127) - // For tile [1,83], buffer row is 16 (first row of tile 1) - if (tile_row == 16 && tile_col == 83) { - printf("[CUDA transpose] buffer_row=%d, tile_col=%d (tile_row=%d): block_amax=%.10f, S_enc=%.10f, fp4_max=%.10f, S_dec_b=%.10f, result_fp32=%.10f, result_uint8=%u\n", - tile_row, tile_col, tile_row / 16, block_amax, S_enc, fp4_max, S_dec_b, - static_cast(result), static_cast(*reinterpret_cast(&result))); - } - return result; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); } #endif // FP4_TYPE_SUPPORTED } // namespace quantization_and_transposition_SF @@ -73,28 +58,12 @@ namespace quantization_SF { // Used in non-transpose variant // Compute per-block E4M3 encoding/decoding scaling factor __device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, - const float S_enc, - int tile_row = -1, - int tile_col = -1) { - // Print kernel name once (only for first tile, first thread) - if (tile_row == 0 && tile_col == 0) { - printf("[CUDA] Using quantization_SF (non-transpose kernel): block_amax=%.10f, S_enc=%.10f\n", block_amax, S_enc); - } + const float S_enc) { constexpr float rcp_6f = 1.0f / 6.0f; // const float S_dec_b = block_amax * rcp_6f; // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); // return S_dec_b_fp8; - float S_dec_b = block_amax * rcp_6f * S_enc; - fp8e4m3 result = static_cast(S_dec_b); - // Debug: print for specific tile - // tile_row is actually buffer row (0-2047), not tile index (0-127) - // For tile [1,83], buffer row is 16 (first row of tile 1) - if (tile_row == 16 && tile_col == 83) { - printf("[CUDA quantize] buffer_row=%d, tile_col=%d (tile_row=%d): block_amax=%.10f, S_enc=%.10f, rcp_6f=%.10f, S_dec_b=%.10f, result_fp32=%.10f, result_uint8=%u\n", - tile_row, tile_col, tile_row / 16, block_amax, S_enc, rcp_6f, S_dec_b, - static_cast(result), 
static_cast(*reinterpret_cast(&result))); - } - return result; + return static_cast(block_amax * rcp_6f * S_enc); } #endif // FP4_TYPE_SUPPORTED } // namespace quantization_SF diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh index 8eda56b33a..83ad8fd40b 100755 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -413,16 +413,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } // 2. Compute E4M3 scaling factor - // Compute tile indices for debug - const int debug_tile_row = scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int debug_tile_col = scales_offset_X_rowwise; - const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc, debug_tile_row, debug_tile_col); + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); #if DIRECT_SCALING_FACTORS_STORE // Check boundaries if (rowwise_scale_is_within_bounds) { - const int scales_offset_Y = debug_tile_row; - const int scales_offset_X = debug_tile_col; + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; } diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 5e801e13c4..7322bf2655 100755 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -335,11 +335,8 @@ __global__ void __launch_bounds__(THREADS_NUM) } } // 2. Compute E4M3 scaling factor - // Compute tile indices for colwise (transpose direction) - const size_t colwise_tile_row = scales_offset_Y_t + stage * ITERATIONS_TRANSPOSE + it; - const size_t colwise_tile_col = scales_offset_X_t; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise, static_cast(colwise_tile_row), static_cast(colwise_tile_col)); + compute_decoding_scaling_factor(block_amax, S_enc_colwise); // Store scaling factors through SHMEM const size_t scale_idx_sh = @@ -511,14 +508,13 @@ __global__ void __launch_bounds__(THREADS_NUM) } // 2. Compute E4M3 scaling factor - // Compute tile indices first for debug - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise, static_cast(scales_offset_Y), static_cast(scales_offset_X)); + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; @@ -920,11 +916,8 @@ __global__ void __launch_bounds__(THREADS_NUM) } // 2. 
Compute E4M3 scaling factor - // Compute tile indices for colwise (transpose direction) in 2D kernel - const size_t colwise_tile_row_2d = scales_offset_Y_t + stage * ITERATIONS_TRANSPOSE + it; - const size_t colwise_tile_col_2d = scales_offset_X_t; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise, static_cast(colwise_tile_row_2d), static_cast(colwise_tile_col_2d)); + compute_decoding_scaling_factor(block_amax, S_enc_colwise); // // Store scaling factors through SHMEM const size_t scale_idx_sh = @@ -1048,15 +1041,14 @@ __global__ void __launch_bounds__(THREADS_NUM) } // 2. Compute E4M3 scaling factor - // Compute tile indices first for debug - const size_t scales_offset_Y_2d = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X_2d = scales_offset_X_rowwise; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise, static_cast(scales_offset_Y_2d), static_cast(scales_offset_X_2d)); + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); // Check boundaries - const size_t scale_idx_global = scales_offset_Y_2d * scale_stride + scales_offset_X_2d; + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; const bool rowwise_scale_is_within_bounds_Y = diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 241b09b5db..365f11bc50 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -6,7 +6,6 @@ from typing import Optional, Union, List import math -import os import torch import transformer_engine_torch as tex @@ -611,37 +610,17 @@ def _cast_master_weights_to_nvfp4_2d( global_amaxes > 0, global_encode_scales, torch.ones_like(global_encode_scales) ) global_scale_tensor.copy_(global_encode_scales) - print( - "[NVFP4 partial cast] global_encode_scales:", - [(idx, float(val)) for idx, val in enumerate(global_encode_scales.tolist())], - ) global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - for layer_idx, (amax_tensor, scale_tensor, global_scale) in enumerate(zip( + for amax_tensor, scale_tensor, global_scale in zip( amaxes, scales, global_scale_views - )): + ): # Use FP32 tensor division to match CUDA kernel exactly. 
# CUDA computes: float S_dec_b = block_amax / fp4_max * S_enc; per_block_decode_scale = torch.clamp( (amax_tensor / fp4_max_t) * global_scale, max=finfo.max ) scale_tensor.copy_(per_block_decode_scale) - import os - # Debug: print specific tile values if env var is set - debug_tile_row = os.environ.get("NVFP4_DEBUG_TILE_ROW") - debug_tile_col = os.environ.get("NVFP4_DEBUG_TILE_COL") - if debug_tile_row is not None and debug_tile_col is not None and layer_idx == 0: - tr, tc = int(debug_tile_row), int(debug_tile_col) - if tr < amax_tensor.shape[0] and tc < amax_tensor.shape[1]: - print(f"\n[utils.py DEBUG] Tile [{tr},{tc}] computation (using PyTorch FP32 tensors):") - print(f" amax_tensor[{tr},{tc}] = {amax_tensor[tr,tc].item()}") - print(f" global_scale = {global_scale.item()}") - print(f" fp4_max_t = {fp4_max_t.item()}") - intermediate = amax_tensor[tr,tc] / fp4_max_t - print(f" amax / fp4_max = {intermediate.item()}") - result = intermediate * global_scale - print(f" (amax / fp4_max) * global_scale = {result.item()}") - print(f" per_block_decode_scale[{tr},{tc}] = {per_block_decode_scale[tr,tc].item()}") else: global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] @@ -689,23 +668,6 @@ def _cast_master_weights_to_nvfp4_2d( fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) target_scale.copy_(fp8_view) - # Debug: print specific tile values after FP8 conversion - import os - debug_tile_row = os.environ.get("NVFP4_DEBUG_TILE_ROW") - debug_tile_col = os.environ.get("NVFP4_DEBUG_TILE_COL") - if debug_tile_row is not None and debug_tile_col is not None and idx == 0: - tr, tc = int(debug_tile_row), int(debug_tile_col) - # The scale buffer row is tile_row * 16 + offset within tile - scale_row = tr * block_len - if scale_row < target_scale.shape[0] and tc < target_scale.shape[1]: - print(f"\n[utils.py DEBUG] After FP8 conversion for tile [{tr},{tc}]:") - print(f" expanded_scale[{scale_row},{tc}] (FP32) = {expanded_scale[scale_row,tc].item()}") - print(f" fp8_view[{scale_row},{tc}] (uint8) = {fp8_view[scale_row,tc].item()}") - print(f" target_scale[{scale_row},{tc}] (uint8) = {target_scale[scale_row,tc].item()}") - # Also show as FP32 - fp8_as_fp32 = fp8_view[scale_row,tc].view(torch.float8_e4m3fn).to(torch.float32).item() - print(f" fp8_view[{scale_row},{tc}] as FP32 = {fp8_as_fp32}") - if target_amax is not None: target_amax.copy_(global_amaxes[idx : idx + 1]) @@ -713,25 +675,6 @@ def _cast_master_weights_to_nvfp4_2d( if master_weight is None or master_weight.numel() == 0: continue - # # Debug: compare scales with fresh quantizer for full cast (start_offset=0) - # if start_offset == 0 and idx == 2: - # rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - # print(f"[Rank {rank}] Layer {idx} SCALE DEBUG (start_offset=0):") - # print(f"[Rank {rank}] per_block_decode_scale shape: {per_block_decode_scale.shape}") - # print(f"[Rank {rank}] target_scale shape: {target_scale.shape}") - # # Compare with fresh quantizer - use inline import to avoid scoping issues - # from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer as FreshNVFP4Quantizer - # fresh_q = FreshNVFP4Quantizer(with_2d_quantization=True) - # ref = fresh_q(master_weight.reshape(model_weight.shape)) - # print(f"[Rank {rank}] ref._rowwise_scale_inv shape: {ref._rowwise_scale_inv.shape}") - # # Check if scales match - # scale_match = torch.equal(target_scale, ref._rowwise_scale_inv) - # print(f"[Rank {rank}] target_scale == ref._rowwise_scale_inv: 
{scale_match}") - # if not scale_match: - # mismatches = (target_scale != ref._rowwise_scale_inv).sum().item() - # total = target_scale.numel() - # print(f"[Rank {rank}] mismatches: {mismatches}/{total} ({100*mismatches/total:.2f}%)") - end_offset = start_offset + master_weight.numel() if not use_fsdp_shard_model_weights: rowwise_bytes = model_weight._rowwise_data.view(-1) From cc1d021339c8f6760d2f8459e5f30b3221858aa7 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 11 Dec 2025 22:12:09 +0000 Subject: [PATCH 21/40] minor --- transformer_engine/pytorch/tensor/utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 365f11bc50..c12aa16cec 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -588,10 +588,6 @@ def _cast_master_weights_to_nvfp4_2d( if global_amaxes.numel() > 0: torch.distributed.all_reduce(global_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) - print( - "[NVFP4 partial cast] global_amaxes:", - [(idx, float(val)) for idx, val in enumerate(global_amaxes.tolist())], - ) global_scale_tensor = global_amaxes.clone() if len(amaxes) > 0: From 48945edb9ff51bbda14c5c8f1e00edaa481a8bb0 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 11 Dec 2025 23:44:09 +0000 Subject: [PATCH 22/40] clean up again --- transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | 0 transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh | 0 .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 0 transformer_engine/common/gemm/cublaslt_gemm.cu | 3 +-- transformer_engine/pytorch/csrc/extensions/recipe.cpp | 2 -- transformer_engine/pytorch/module/base.py | 2 -- 6 files changed, 1 insertion(+), 6 deletions(-) mode change 100755 => 100644 transformer_engine/common/cast/nvfp4/core_nvfp4.cuh mode change 100755 => 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh mode change 100755 => 100644 transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh old mode 100755 new mode 100644 diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh old mode 100755 new mode 100644 diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh old mode 100755 new mode 100644 diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ed4be7acd9..d61f3ce219 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -736,8 +736,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); NVTE_CHECK(new_workspace_alignment % 256 == 0, "cuBLAS workspace pointer must be aligned to 256 bytes, got ", - new_workspace_alignment); - + new_workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, &returnedResults); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 278cfa205a..8d1d865604 100755 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -67,5 +67,3 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio } } // namespace transformer_engine::pytorch - -namespace transformer_engine::pytorch {} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index db9e8bd0a6..379f854e9a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -52,7 +52,6 @@ get_nvtx_range_context, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage -from ..tensor.nvfp4_tensor import NVFP4Quantizer from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -1277,7 +1276,6 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: quantizer.with_amax_reduction = True # Quantize parameter param = quantizer(param) - # Redo parameter wrap in case we broke it above # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already From d458d5b21b362066fdaf1c7c1d9914fe6e53ce57 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Fri, 12 Dec 2025 00:35:09 +0000 Subject: [PATCH 23/40] more restrict shape in test --- .../distributed/test_cast_master_weights_to_fp8.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index d97d982461..e3fb0b03b5 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -803,18 +803,18 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi # Disable stochastic rounding for deterministic gradients nvfp4_recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) - # Original shapes (commented out for debugging padding issues): with te.quantized_model_init( enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True ): model_nvfp4 = nn.Sequential( - te.Linear(128, 256, **linear_kwargs), - te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(128, 256+64, **linear_kwargs), + te.Linear(256+64, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) + # Create model with bf16 weights model = nn.Sequential( - te.Linear(128, 256, **linear_kwargs), - te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(128, 256+64, **linear_kwargs), + te.Linear(256+64, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) From d87daf9d5b6dee3b1c27da56bac07c30292f0494 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Tue, 30 Dec 2025 19:56:34 +0000 Subject: [PATCH 24/40] nvfp4 scale transpose --- .../include/transformer_engine/recipe.h | 15 ++++ transformer_engine/common/recipe/nvfp4.cu | 86 +++++++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/pybind.cpp | 4 + .../pytorch/csrc/extensions/transpose.cpp | 23 +++++ .../tensor/storage/nvfp4_tensor_storage.py | 42 +++------ 6 files changed, 140 insertions(+), 32 deletions(-) mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions.h mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/pybind.cpp diff --git 
a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 82cddcc9eb..47ae29464e 100755 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -331,6 +331,21 @@ void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTE */ void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Transpose NVFP4 tile-level scales from rowwise to columnwise format. + * + * Takes rowwise_scale_inv where scales are stored at every 16th row (tile boundaries) + * and produces columnwise_scale_inv where scales are repeated 16 times per tile row. + * Scale values are stored as E4M3 (fp8) in uint8 tensors. + * + * \param[in] input Input tensor with rowwise scales [M_padded, K_tiles], uint8 (E4M3). + * \param[out] output Output tensor with columnwise scales [K_padded, M_tiles], uint8 (E4M3). + * \param[in] M_tiles Number of tiles in M dimension. + * \param[in] K_tiles Number of tiles in K dimension. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, + size_t M_tiles, size_t K_tiles, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index c028c3b13e..76c1e743b7 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -423,9 +423,95 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { NVTE_CHECK_CUDA(cudaGetLastError()); } +/* + * --------------------------------------------------------------------------- + * NVFP4 SCALE TRANSPOSE KERNEL + * + * Transposes tile-level scales from rowwise to columnwise format. + * Scale values are stored as E4M3 (fp8) in uint8 tensors. + * + * Input (rowwise_scale_inv): [M_padded, K_tiles] where scales are stored + * at every 16th row (i.e., row 0, 16, 32, ... contain the actual scales, + * and each row i within a tile block has the same scale as row (i // 16) * 16). + * + * Output (columnwise_scale_inv): [K_padded, M_tiles] where scales are + * repeated 16 times per tile row. + * + * Mapping: + * output[k_tile * 16 + i, m_tile] = input[m_tile * 16, k_tile] + * for i in [0, 16) and valid (k_tile, m_tile) indices. 
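+ *
+ * Illustrative example (values assumed, not taken from the kernel): for a
+ * 32x48 logical weight, M_tiles = 2 and K_tiles = 3; the scale stored at
+ * input[16, 2] (first row of M-tile 1, K-tile 2) is replicated into
+ * output[32 + i, 1] for i in [0, 16).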
+ * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_scale_transpose_kernel( + const uint8_t* __restrict__ input, // [M_padded, K_tiles], E4M3 stored as uint8 + uint8_t* __restrict__ output, // [K_padded, M_tiles], E4M3 stored as uint8 + const size_t M_tiles, // Number of M tiles + const size_t K_tiles, // Number of K tiles + const size_t input_stride, // K_tiles (input row stride) + const size_t output_stride, // M_tiles (output row stride) + const size_t K_padded // Output height +) { + // Each thread handles one output element + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= K_padded || out_col >= M_tiles) return; + + // Determine which tile row this belongs to + const size_t k_tile = out_row / kTileDim; + + // Read from input: row = m_tile * 16 (first row of the tile), col = k_tile + // m_tile = out_col + if (k_tile < K_tiles) { + const size_t in_row = out_col * kTileDim; // m_tile * 16 + const uint8_t scale = input[in_row * input_stride + k_tile]; + output[out_row * output_stride + out_col] = scale; + } else { + output[out_row * output_stride + out_col] = 0; + } +} + +void nvfp4_scale_transpose(const Tensor input, Tensor output, + size_t M_tiles, size_t K_tiles, + cudaStream_t stream) { + NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); + NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 scale transpose output must be uint8 (E4M3)."); + + const auto in_shape = input.shape(); + const auto out_shape = output.shape(); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); + + const size_t input_stride = in_shape[1]; // K_tiles + const size_t output_stride = out_shape[1]; // M_tiles + const size_t K_padded = out_shape[0]; + + if (M_tiles == 0 || K_tiles == 0 || K_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((M_tiles + kBlockDim - 1) / kBlockDim, + (K_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_scale_transpose_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), + M_tiles, K_tiles, input_stride, output_stride, K_padded); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace nvfp4_recipe } // namespace transformer_engine +void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, + size_t M_tiles, size_t K_tiles, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_scale_transpose); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_scale_transpose(*convertNVTETensorCheck(input), + *convertNVTETensorCheck(output), + M_tiles, K_tiles, stream); +} + void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_nvfp4_transpose); using namespace transformer_engine; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h old mode 100644 new mode 100755 index 669bbc8df9..868a8d2d44 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -158,6 +158,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, at::Tensor nvfp4_transpose(at::Tensor input, std::optional output = std::nullopt); +void nvfp4_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, int64_t K_tiles); + at::Tensor swap_first_dims(at::Tensor 
tensor, std::optional out = std::nullopt); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp old mode 100644 new mode 100755 index 364f1484c4..61511d807d --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -257,6 +257,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvfp4_transpose", &transformer_engine::pytorch::nvfp4_transpose, "Transpose NVFP4 packed data with nibble repacking", py::arg("input"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("nvfp4_scale_transpose", &transformer_engine::pytorch::nvfp4_scale_transpose, + "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format", + py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"), + py::call_guard()); m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index ad6c50f065..898552e53b 100755 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -99,6 +99,29 @@ at::Tensor nvfp4_transpose(at::Tensor input, std::optional output) { return out; } +void nvfp4_scale_transpose(at::Tensor input, at::Tensor output, + int64_t M_tiles, int64_t K_tiles) { + init_extension(); + + // Input: rowwise_scale_inv [M_padded, K_tiles], uint8 (E4M3 stored as bytes) + // Output: columnwise_scale_inv [K_padded, M_tiles], uint8 (E4M3 stored as bytes) + const auto in_shape = getTensorShape(input); + const auto out_shape = getTensorShape(output); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); + NVTE_CHECK(input.scalar_type() == at::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); + NVTE_CHECK(output.scalar_type() == at::kByte, "NVFP4 scale transpose output must be uint8 (E4M3)."); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kByte); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), std::vector{out_shape[0], out_shape[1]}, DType::kByte); + + nvte_nvfp4_scale_transpose(input_cu.data(), output_cu.data(), + static_cast(M_tiles), static_cast(K_tiles), + at::cuda::getCurrentCUDAStream()); +} + at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { init_extension(); diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 0fb24fb812..0eb6b398db 100755 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -356,44 +356,22 @@ def _create_columnwise(self): # rowwise_scale_inv has shape [M_padded, K_tiles] where each tile's scale # is repeated 16 times (once per row in the 16x16 tile). - # We need to: - # 1. Extract tile-level scales (every 16th row) - # 2. Transpose - # 3. 
Expand to columnwise padded format (repeat 16 times per tile row) - + # columnwise_scale_inv has shape [K_padded, M_tiles] where scales are + # repeated 16 times per tile row. + # + # Use GPU kernel to efficiently transpose and expand the scales. TILE_SIZE = 16 - rowwise_scale_inv = self._rowwise_scale_inv - - # Get logical shape to compute tile counts logical_shape = self.size() M, K = logical_shape[0], logical_shape[-1] M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE - # Extract tile-level scales (take first row of each tile block) - # rowwise_scale_inv[0::16, :] gives us [M_tiles, K_tiles] (approximately) - tile_scales = rowwise_scale_inv[0:M_tiles * TILE_SIZE:TILE_SIZE, :K_tiles] # [M_tiles, K_tiles] - - # Transpose tile scales for columnwise layout - transposed_tile_scales = tile_scales.transpose(-2, -1) # [K_tiles, M_tiles] - - # Expand to columnwise padded format (repeat each tile row 16 times) - # columnwise_scale_inv has shape [K_padded, M_tiles_padded] - col_h = self._columnwise_scale_inv.shape[0] - col_w = self._columnwise_scale_inv.shape[1] - - # Zero out the columnwise scale first - self._columnwise_scale_inv.zero_() - - # Fill in the scale values, repeating each tile's scale 16 times - for tile_row in range(min(K_tiles, (col_h + TILE_SIZE - 1) // TILE_SIZE)): - row_start = tile_row * TILE_SIZE - row_end = min(row_start + TILE_SIZE, col_h) - cols_to_copy = min(M_tiles, col_w) - if row_start < col_h and cols_to_copy > 0: - # Repeat the tile row's scales for all 16 rows in this tile - self._columnwise_scale_inv[row_start:row_end, :cols_to_copy] = \ - transposed_tile_scales[tile_row, :cols_to_copy].unsqueeze(0).expand(row_end - row_start, -1) + tex.nvfp4_scale_transpose( + self._rowwise_scale_inv, + self._columnwise_scale_inv, + M_tiles, + K_tiles, + ) # Also set columnwise amax (same as rowwise since it's just transposed data) if self._amax_columnwise is None and self._amax_rowwise is not None: From 938079d98d09b864176aac3af57780207ad09a63 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Mon, 5 Jan 2026 19:01:19 +0000 Subject: [PATCH 25/40] partial cast optimize --- .../test_cast_master_weights_to_fp8.py | 3 +- .../include/transformer_engine/recipe.h | 45 +++++ transformer_engine/common/recipe/nvfp4.cu | 187 ++++++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 8 + .../pytorch/csrc/extensions/pybind.cpp | 13 ++ transformer_engine/pytorch/tensor/utils.py | 65 +++--- 6 files changed, 279 insertions(+), 42 deletions(-) mode change 100644 => 100755 tests/pytorch/distributed/test_cast_master_weights_to_fp8.py mode change 100644 => 100755 transformer_engine/pytorch/tensor/utils.py diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py old mode 100644 new mode 100755 index 53f77d9e4d..46fb51645a --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -1398,4 +1398,5 @@ def test_single_gpu_partial_cast_vs_full(): if __name__ == "__main__": - main() + #main() + test_nvfp4_transpose_kernel() \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 47ae29464e..8d0633e4b5 100755 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -346,6 +346,51 @@ void nvte_nvfp4_transpose(const 
NVTETensor input, NVTETensor output, cudaStream_ void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles, size_t K_tiles, cudaStream_t stream); +/*! \brief Expand tile-level scales to row-level scales and convert to FP8 E4M3. + * + * Each tile row's scale is repeated block_len times in the output. + * + * \param[in] input Input tensor with tile scales [tile_rows, tile_cols], float32. + * \param[out] output Output tensor with expanded scales [rows_padded, tile_cols], uint8 (E4M3). + * \param[in] tile_rows Number of tile rows. + * \param[in] tile_cols Number of tile columns. + * \param[in] rows_padded Padded row count in output. + * \param[in] block_len Block length (typically 16 for NVFP4). + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, + size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, + cudaStream_t stream); + +/*! \brief Compute per-block decode scale from block amax and global amax. + * + * Computes: + * global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * per_block_decode_scale = block_amax / fp4_max * global_scale + * + * This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh. + * + * \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32. + * \param[out] scale Output scale tensor [tile_rows, tile_cols], float32. + * \param[in] global_amax Global amax value (per-tensor, after all-reduce). + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, + float global_amax, cudaStream_t stream); + +/*! \brief Compute global encode scale from global amax. + * + * Computes: global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * If global_amax <= 0, returns 1.0. + * + * \param[in] global_amax Input global amax tensor [num_params], float32. + * \param[out] global_scale Output global scale tensor [num_params], float32. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 76c1e743b7..e0a5025f3e 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -500,9 +500,196 @@ void nvfp4_scale_transpose(const Tensor input, Tensor output, NVTE_CHECK_CUDA(cudaGetLastError()); } +/* + * --------------------------------------------------------------------------- + * NVFP4 SCALE EXPANSION KERNEL + * + * Expands tile-level scales to row-level scales and converts to FP8 E4M3. + * + * Input (per_block_decode_scale): [tile_rows, tile_cols] in float32 + * Output (target_scale): [rows_padded, tile_cols] in uint8 (E4M3) + * + * Each tile row's scale is repeated block_len times in the output. 
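+ *
+ * Reference mapping (illustrative sketch of what the kernel below computes):
+ *   for r in [0, rows_padded), c in [0, tile_cols):
+ *     t = r / block_len;
+ *     output[r][c] = fp8_e4m3(t < tile_rows ? input[t][c] : 0.0f);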
+ * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_expand_scale_to_fp8_kernel( + const float* __restrict__ input, // [tile_rows, tile_cols] + uint8_t* __restrict__ output, // [rows_padded, tile_cols] + const size_t tile_rows, + const size_t tile_cols, + const size_t rows_padded, + const size_t block_len +) { + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= rows_padded || out_col >= tile_cols) return; + + // Determine which tile row this output row belongs to + const size_t tile_row = out_row / block_len; + + float scale_val = 0.0f; + if (tile_row < tile_rows) { + scale_val = input[tile_row * tile_cols + out_col]; + } + + // Convert float32 to FP8 E4M3 + // Clamp to FP8 E4M3 range and convert + fp8e4m3 fp8_val = static_cast(scale_val); + output[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); +} + +void nvfp4_expand_scale_to_fp8(const Tensor input, Tensor output, + size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, + cudaStream_t stream) { + NVTE_CHECK(input.dtype() == DType::kFloat32, "Scale input must be float32."); + NVTE_CHECK(output.dtype() == DType::kByte, "Scale output must be uint8 (E4M3)."); + + if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, + (rows_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_expand_scale_to_fp8_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), + tile_rows, tile_cols, rows_padded, block_len); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 COMPUTE PER-BLOCK DECODE SCALE KERNEL + * + * Computes per-block decode scale from block amax and global amax: + * global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * per_block_decode_scale = block_amax / fp4_max * global_scale + * = block_amax * 448 / global_amax + * + * This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh + * + * Input (block_amax): [tile_rows, tile_cols] in float32 + * Input (global_amax): scalar float32 (per-tensor amax after all-reduce) + * Output (scale): [tile_rows, tile_cols] in float32 + * Output (global_scale_out): scalar float32 (the computed global encode scale) + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_compute_per_block_scale_kernel( + const float* __restrict__ block_amax, // [tile_rows, tile_cols] + float* __restrict__ scale, // [tile_rows, tile_cols] + const float global_amax, + const size_t numel +) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= numel) return; + + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + // Compute global encode scale: S_enc = (fp8_max * fp4_max) / global_amax + float safe_global_amax = fmaxf(global_amax, tiny); + float global_scale = (global_amax > 0.0f) ? 
+ fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; + + // Compute per-block decode scale: S_dec_b = block_amax / fp4_max * S_enc + float amax_val = block_amax[idx]; + float result = fminf((amax_val / fp4_max) * global_scale, flt_max); + scale[idx] = result; +} + +// Simple kernel to compute global encode scale from global amax +__global__ void nvfp4_compute_global_scale_kernel( + const float* __restrict__ global_amax, // [num_params] + float* __restrict__ global_scale, // [num_params] + const size_t num_params +) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_params) return; + + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + float amax = global_amax[idx]; + float safe_amax = fmaxf(amax, tiny); + float scale = (amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_amax, flt_max) : 1.0f; + global_scale[idx] = scale; +} + +void nvfp4_compute_per_block_scale(const Tensor block_amax, Tensor scale, + float global_amax, cudaStream_t stream) { + NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); + NVTE_CHECK(scale.dtype() == DType::kFloat32, "Scale must be float32."); + + size_t numel = block_amax.numel(); + if (numel == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (numel + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_per_block_scale_kernel<<>>( + reinterpret_cast(block_amax.data.dptr), + reinterpret_cast(scale.data.dptr), + global_amax, numel); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvfp4_compute_global_scale(const Tensor global_amax, Tensor global_scale, + cudaStream_t stream) { + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_scale.dtype() == DType::kFloat32, "Global scale must be float32."); + + size_t num_params = global_amax.numel(); + if (num_params == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (num_params + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_global_scale_kernel<<>>( + reinterpret_cast(global_amax.data.dptr), + reinterpret_cast(global_scale.data.dptr), + num_params); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace nvfp4_recipe } // namespace transformer_engine +void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, + size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_expand_scale_to_fp8); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_expand_scale_to_fp8(*convertNVTETensorCheck(input), + *convertNVTETensorCheck(output), + tile_rows, tile_cols, rows_padded, block_len, stream); +} + +void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, + float global_amax, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_per_block_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_compute_per_block_scale(*convertNVTETensorCheck(block_amax), + *convertNVTETensorCheck(scale), + global_amax, stream); +} + +void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_global_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_compute_global_scale(*convertNVTETensorCheck(global_amax), + *convertNVTETensorCheck(global_scale), + stream); +} + void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, 
size_t M_tiles, size_t K_tiles, cudaStream_t stream) { NVTE_API_CALL(nvte_nvfp4_scale_transpose); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 868a8d2d44..0c209dc77a 100755 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -160,6 +160,14 @@ at::Tensor nvfp4_transpose(at::Tensor input, std::optional output = void nvfp4_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, int64_t K_tiles); +void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, + int64_t tile_rows, int64_t tile_cols, + int64_t rows_padded, int64_t block_len); + +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, float global_amax); + +void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale); + at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 61511d807d..439566462a 100755 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -261,6 +261,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format", py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"), py::call_guard()); + m.def("nvfp4_expand_scale_to_fp8", &transformer_engine::pytorch::nvfp4_expand_scale_to_fp8, + "Expand tile-level scales to row-level scales and convert to FP8 E4M3", + py::arg("input"), py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"), + py::arg("rows_padded"), py::arg("block_len"), + py::call_guard()); + m.def("nvfp4_compute_per_block_scale", &transformer_engine::pytorch::nvfp4_compute_per_block_scale, + "Compute per-block decode scale from block amax and global amax", + py::arg("block_amax"), py::arg("scale"), py::arg("global_amax"), + py::call_guard()); + m.def("nvfp4_compute_global_scale", &transformer_engine::pytorch::nvfp4_compute_global_scale, + "Compute global encode scale from global amax", + py::arg("global_amax"), py::arg("global_scale"), + py::call_guard()); m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py old mode 100644 new mode 100755 index a93475b39e..00f3bd6757 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -221,11 +221,11 @@ def _cast_master_weights_to_fp8_delayed_scaling( for model_weight, master_weight, start_offset, shard_model_weight_raw in params: if not manual_post_all_gather_processing: - # Reset transpose cache for all model weights. + # Reset transpose cache for all model weights. # We cannot create transpose cache here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated currently. 
- model_weight._reset_caches() + model_weight._reset_caches() quantizer = model_weight._get_quantizer() @@ -367,11 +367,11 @@ def _cast_master_weights_to_fp8_current_scaling( params, scales ): if not manual_post_all_gather_processing: - # Reset transpose cache for all model weights. + # Reset transpose cache for all model weights. # We cannot create transpose cache here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated currently. - model_weight._reset_caches() + model_weight._reset_caches() # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. @@ -502,7 +502,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # We cannot create columnwise data here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated at this moment. - model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. @@ -596,34 +596,19 @@ def _cast_master_weights_to_nvfp4_2d( if global_amaxes.numel() > 0: torch.distributed.all_reduce(global_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) - global_scale_tensor = global_amaxes.clone() + # Use GPU kernel to compute global encode scales from global amaxes + # This replaces multiple Python tensor operations with a single kernel + global_scale_tensor = torch.empty_like(global_amaxes) if len(amaxes) > 0: - finfo = torch.finfo(torch.float32) - tiny = finfo.tiny - - # Use FP32 tensors for constants to match CUDA kernel's FP32 computation. - # Python float literals are float64, which causes precision differences. - # Example: 2688/5 = 537.6000366211 (PyTorch with float64) vs 537.5999755859 (CUDA FP32) - fp4_max_t = torch.tensor(6.0, dtype=torch.float32, device=device) - fp8_max_t = torch.tensor(448.0, dtype=torch.float32, device=device) - - safe_global_amax = torch.clamp(global_amaxes, min=tiny) - global_encode_scales = torch.clamp((fp8_max_t * fp4_max_t) / safe_global_amax, max=finfo.max) - global_encode_scales = torch.where( - global_amaxes > 0, global_encode_scales, torch.ones_like(global_encode_scales) - ) - global_scale_tensor.copy_(global_encode_scales) + tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - for amax_tensor, scale_tensor, global_scale in zip( - amaxes, scales, global_scale_views + # Use GPU kernel for computing per-block decode scales + # Takes global_amax (not global_scale) and computes scale internally + for amax_tensor, scale_tensor, global_amax_view in zip( + amaxes, scales, [global_amaxes[i : i + 1] for i in range(len(params))] ): - # Use FP32 tensor division to match CUDA kernel exactly. - # CUDA computes: float S_dec_b = block_amax / fp4_max * S_enc; - per_block_decode_scale = torch.clamp( - (amax_tensor / fp4_max_t) * global_scale, max=finfo.max - ) - scale_tensor.copy_(per_block_decode_scale) + tex.nvfp4_compute_per_block_scale(amax_tensor, scale_tensor, global_amax_view.item()) else: global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] @@ -657,19 +642,17 @@ def _cast_master_weights_to_nvfp4_2d( # Always write scales and amax for ALL layers (computed from all-reduced amax). 
# This ensures scales are correct even for layers not owned by this rank. + # Use GPU kernel to expand tile-level scales to row-level and convert to FP8 tile_rows = tile_shape[0] - expanded_scale = torch.zeros_like(target_scale, dtype=torch.float32) - chunk = block_len - for tile_row_idx in range(tile_rows): - base_row = tile_row_idx * chunk - row_end = min(base_row + chunk, rows) - if base_row >= target_scale.shape[0]: - break - expanded_scale[base_row:row_end, :tile_col_cnt] = per_block_decode_scale[ - tile_row_idx - ] - fp8_view = expanded_scale.to(dtype=torch.float8_e4m3fn).view(torch.uint8) - target_scale.copy_(fp8_view) + rows_padded = target_scale.shape[0] + tex.nvfp4_expand_scale_to_fp8( + per_block_decode_scale, + target_scale, + tile_rows, + tile_col_cnt, + rows_padded, + block_len, + ) if target_amax is not None: target_amax.copy_(global_amaxes[idx : idx + 1]) From 3ab7e2c7f77b8c57975cb1c5c8dd6695db0d8a2f Mon Sep 17 00:00:00 2001 From: qiyuw Date: Mon, 5 Jan 2026 19:05:52 +0000 Subject: [PATCH 26/40] partial cast optimize patch --- .../pytorch/csrc/extensions/transpose.cpp | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 898552e53b..11af519191 100755 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -122,6 +122,61 @@ void nvfp4_scale_transpose(at::Tensor input, at::Tensor output, at::cuda::getCurrentCUDAStream()); } +void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, + int64_t tile_rows, int64_t tile_cols, + int64_t rows_padded, int64_t block_len) { + init_extension(); + + // Input: per_block_decode_scale [tile_rows, tile_cols], float32 + // Output: target_scale [rows_padded, tile_cols], uint8 (E4M3) + const auto in_shape = getTensorShape(input); + const auto out_shape = getTensorShape(output); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 expand scale expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 expand scale expects 2D output."); + NVTE_CHECK(input.scalar_type() == at::kFloat, "NVFP4 expand scale input must be float32."); + NVTE_CHECK(output.scalar_type() == at::kByte, "NVFP4 expand scale output must be uint8 (E4M3)."); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kFloat32); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), std::vector{out_shape[0], out_shape[1]}, DType::kByte); + + nvte_nvfp4_expand_scale_to_fp8(input_cu.data(), output_cu.data(), + static_cast(tile_rows), + static_cast(tile_cols), + static_cast(rows_padded), + static_cast(block_len), + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, float global_amax) { + init_extension(); + + // block_amax and scale: [tile_rows, tile_cols], float32 + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(scale.scalar_type() == at::kFloat, "Scale must be float32."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto scale_cu = makeTransformerEngineTensor(scale); + + nvte_nvfp4_compute_per_block_scale(block_amax_cu.data(), scale_cu.data(), + global_amax, at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale) { + init_extension(); + + // global_amax and global_scale: [num_params], float32 + 
NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(global_scale.scalar_type() == at::kFloat, "Global scale must be float32."); + + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto global_scale_cu = makeTransformerEngineTensor(global_scale); + + nvte_nvfp4_compute_global_scale(global_amax_cu.data(), global_scale_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { init_extension(); From f2986d73e62930cf6f0f9b0a52790940439b992a Mon Sep 17 00:00:00 2001 From: qiyuw Date: Mon, 5 Jan 2026 19:13:59 +0000 Subject: [PATCH 27/40] fix typo --- .../pytorch/distributed/test_cast_master_weights_to_fp8.py | 3 ++- transformer_engine/pytorch/tensor/utils.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 46fb51645a..7d5cff65ef 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -1399,4 +1399,5 @@ def test_single_gpu_partial_cast_vs_full(): if __name__ == "__main__": #main() - test_nvfp4_transpose_kernel() \ No newline at end of file + #test_nvfp4_transpose_kernel() + test_single_gpu_partial_cast_vs_full() \ No newline at end of file diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 00f3bd6757..cf793a6a8a 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -225,7 +225,7 @@ def _cast_master_weights_to_fp8_delayed_scaling( # We cannot create transpose cache here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated currently. - model_weight._reset_caches() + model_weight._reset_caches() quantizer = model_weight._get_quantizer() @@ -371,7 +371,7 @@ def _cast_master_weights_to_fp8_current_scaling( # We cannot create transpose cache here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated currently. - model_weight._reset_caches() + model_weight._reset_caches() # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. @@ -502,7 +502,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # We cannot create columnwise data here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated at this moment. - model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. 
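
For reference, the scale math implemented by nvte_nvfp4_compute_global_scale and nvte_nvfp4_compute_per_block_scale above (the same math the removed Python tensor ops in utils.py computed) can be sketched in PyTorch as follows. This is a minimal reference only: the function and variable names are illustrative, and it does not reproduce the exact FP32 rounding behavior of the CUDA kernels.

    import torch

    FP4_MAX = 6.0    # largest NVFP4 (E2M1) magnitude
    FP8_MAX = 448.0  # largest FP8 E4M3 magnitude

    def global_encode_scale(global_amax: torch.Tensor) -> torch.Tensor:
        # S_enc = (fp8_max * fp4_max) / global_amax, clamped to FLT_MAX; 1.0 where amax == 0
        finfo = torch.finfo(torch.float32)
        safe = torch.clamp(global_amax, min=finfo.tiny)
        s_enc = torch.clamp((FP8_MAX * FP4_MAX) / safe, max=finfo.max)
        return torch.where(global_amax > 0, s_enc, torch.ones_like(s_enc))

    def per_block_decode_scale(block_amax: torch.Tensor, global_amax: torch.Tensor) -> torch.Tensor:
        # S_dec_b = (block_amax / fp4_max) * S_enc, clamped to FLT_MAX
        finfo = torch.finfo(torch.float32)
        return torch.clamp((block_amax / FP4_MAX) * global_encode_scale(global_amax), max=finfo.max)

Applied to the [num_params] amax vector, global_encode_scale corresponds to nvfp4_compute_global_scale; applied to a [tile_rows, tile_cols] block-amax grid together with that parameter's global amax, per_block_decode_scale corresponds to nvfp4_compute_per_block_scale.
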
From 0d29a197d6341d2dc0e5cd0af26c6571c4300b40 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Tue, 6 Jan 2026 23:49:21 +0000 Subject: [PATCH 28/40] reduce kernel launches --- transformer_engine/pytorch/tensor/utils.py | 58 +++++++++++++++------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index cf793a6a8a..c23fdaee9b 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -177,15 +177,40 @@ def cast_master_weights_to_nvfp4( else: use_fsdp_shard_model_weights = True + # Batch convert master_weights to model dtype (single kernel instead of N kernels) + # All NVFP4 model_weights should have the same dtype (BF16) + if len(model_weights) > 0: + target_dtype = model_weights[0].dtype + + # Collect non-None master_weights and their indices + non_none_indices = [] + non_none_weights = [] + sizes = [] + for i, mw in enumerate(master_weights): + if mw is not None: + non_none_indices.append(i) + non_none_weights.append(mw.view(-1)) + sizes.append(mw.numel()) + + if len(non_none_weights) > 0 and non_none_weights[0].dtype != target_dtype: + # Concatenate, convert once, then split + concatenated = torch.cat(non_none_weights) + converted = concatenated.to(target_dtype) + split_weights = torch.split(converted, sizes) + + # Rebuild master_weights list with converted tensors + converted_master_weights = list(master_weights) + for idx, split_w, orig_mw in zip(non_none_indices, split_weights, + [master_weights[i] for i in non_none_indices]): + converted_master_weights[idx] = split_w.view(orig_mw.shape) + master_weights = converted_master_weights + for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( model_weights, master_weights, start_offsets, fsdp_shard_model_weights ): if hasattr(model_weight, "clear_high_precision_init_val"): model_weight.clear_high_precision_init_val() - if master_weight is not None: - master_weight = master_weight.to(model_weight.dtype) - quantizer = model_weight._get_quantizer() if isinstance(quantizer, NVFP4Quantizer): nvfp4_params.append( @@ -537,8 +562,6 @@ def _cast_master_weights_to_nvfp4_2d( device = params[0][0].device block_len = NVFP4_BLOCK_SCALING_SIZE - dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device) - cu_amax_sizes = [0] tile_shapes: List[tuple[int, int]] = [] row_sizes: List[int] = [] @@ -562,6 +585,7 @@ def _cast_master_weights_to_nvfp4_2d( cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes) packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) + packed_scales = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) amaxes: List[torch.Tensor] = [] scales: List[torch.Tensor] = [] @@ -573,7 +597,7 @@ def _cast_master_weights_to_nvfp4_2d( for i, (model_weight, master_weight, start_offset, _) in enumerate(params): scale_shape = tile_shapes[i] amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) - scale = torch.empty(scale_shape, dtype=torch.float32, device=device) + scale = packed_scales[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) global_amax_view = global_amax_views[i] assert model_weight._rowwise_scale_inv is not None @@ -599,18 +623,16 @@ def _cast_master_weights_to_nvfp4_2d( # Use GPU kernel to compute global encode scales from global amaxes # This replaces multiple Python tensor operations with a single kernel global_scale_tensor = torch.empty_like(global_amaxes) - if len(amaxes) > 0: - 
tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) - global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - - # Use GPU kernel for computing per-block decode scales - # Takes global_amax (not global_scale) and computes scale internally - for amax_tensor, scale_tensor, global_amax_view in zip( - amaxes, scales, [global_amaxes[i : i + 1] for i in range(len(params))] - ): - tex.nvfp4_compute_per_block_scale(amax_tensor, scale_tensor, global_amax_view.item()) - else: - global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] + + tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) + global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] + + # Use GPU kernel for computing per-block decode scales + # Takes global_amax (not global_scale) and computes scale internally + for amax_tensor, scale_tensor, global_amax_view in zip( + amaxes, scales, [global_amaxes[i : i + 1] for i in range(len(params))] + ): + tex.nvfp4_compute_per_block_scale(amax_tensor, scale_tensor, global_amax_view.item()) zipped_meta = zip( tile_shapes, From 0191035aad057b580677e3a4b67a885428220bf7 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Wed, 7 Jan 2026 01:03:58 +0000 Subject: [PATCH 29/40] fused kernel --- .../test_cast_master_weights_to_fp8.py | 3 +- .../include/transformer_engine/recipe.h | 31 +++- transformer_engine/common/recipe/nvfp4.cu | 138 +++++++++++++++++- transformer_engine/pytorch/csrc/extensions.h | 8 +- .../pytorch/csrc/extensions/pybind.cpp | 6 + .../pytorch/csrc/extensions/transpose.cpp | 42 +++++- transformer_engine/pytorch/tensor/utils.py | 49 ++++--- 7 files changed, 244 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 7d5cff65ef..be37ffbffe 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -1400,4 +1400,5 @@ def test_single_gpu_partial_cast_vs_full(): if __name__ == "__main__": #main() #test_nvfp4_transpose_kernel() - test_single_gpu_partial_cast_vs_full() \ No newline at end of file + #test_single_gpu_partial_cast_vs_full() + test_nvfp4_partial_cast_matches_full() \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 8d0633e4b5..488dd922e9 100755 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -373,11 +373,38 @@ void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, * * \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32. * \param[out] scale Output scale tensor [tile_rows, tile_cols], float32. - * \param[in] global_amax Global amax value (per-tensor, after all-reduce). + * \param[in] global_amax Global amax tensor (single element), float32. Avoids D2H transfer. * \param[in] stream CUDA stream. */ void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, - float global_amax, cudaStream_t stream); + const NVTETensor global_amax, cudaStream_t stream); + +/*! \brief Fused kernel for NVFP4 scale computation. + * + * Fuses three operations into one kernel: + * 1. Compute per-block decode scales from block amax and global amax + * 2. Copy global amax to target tensor + * 3. 
Expand tile-level scales to row-level and convert to FP8 E4M3 + * + * Saves 2 kernel launches per parameter. + * + * \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32. + * \param[in] global_amax Global amax tensor [1], float32. + * \param[out] per_block_scale Output per-block scale [tile_rows, tile_cols], float32 (for partial_cast). + * \param[out] target_scale Output scale tensor [rows_padded, tile_cols], uint8 (E4M3). + * \param[out] target_amax Output amax tensor [1], float32 (copy of global_amax). + * \param[in] tile_rows Number of tile rows. + * \param[in] tile_cols Number of tile columns. + * \param[in] rows_padded Total padded rows in output. + * \param[in] block_len Block length (16 for NVFP4). + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, + NVTETensor per_block_scale, NVTETensor target_scale, + NVTETensor target_amax, + size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, + cudaStream_t stream); /*! \brief Compute global encode scale from global amax. * diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index e0a5025f3e..173730d7a4 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -580,7 +580,7 @@ void nvfp4_expand_scale_to_fp8(const Tensor input, Tensor output, __global__ void nvfp4_compute_per_block_scale_kernel( const float* __restrict__ block_amax, // [tile_rows, tile_cols] float* __restrict__ scale, // [tile_rows, tile_cols] - const float global_amax, + const float* __restrict__ global_amax_ptr, // Pointer to single float value (avoids D2H) const size_t numel ) { const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -591,6 +591,9 @@ __global__ void nvfp4_compute_per_block_scale_kernel( constexpr float flt_max = 3.402823466e+38f; constexpr float tiny = 1.17549435e-38f; // FLT_MIN + // Read global_amax from device memory (avoids D2H transfer) + float global_amax = *global_amax_ptr; + // Compute global encode scale: S_enc = (fp8_max * fp4_max) / global_amax float safe_global_amax = fmaxf(global_amax, tiny); float global_scale = (global_amax > 0.0f) ? @@ -623,9 +626,11 @@ __global__ void nvfp4_compute_global_scale_kernel( } void nvfp4_compute_per_block_scale(const Tensor block_amax, Tensor scale, - float global_amax, cudaStream_t stream) { + const Tensor global_amax, cudaStream_t stream) { NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); NVTE_CHECK(scale.dtype() == DType::kFloat32, "Scale must be float32."); + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); size_t numel = block_amax.numel(); if (numel == 0) return; @@ -636,7 +641,8 @@ void nvfp4_compute_per_block_scale(const Tensor block_amax, Tensor scale, nvfp4_compute_per_block_scale_kernel<<>>( reinterpret_cast(block_amax.data.dptr), reinterpret_cast(scale.data.dptr), - global_amax, numel); + reinterpret_cast(global_amax.data.dptr), + numel); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -658,6 +664,108 @@ void nvfp4_compute_global_scale(const Tensor global_amax, Tensor global_scale, NVTE_CHECK_CUDA(cudaGetLastError()); } +/* + * --------------------------------------------------------------------------- + * FUSED NVFP4 SCALE COMPUTATION KERNEL + * + * Fuses three operations into one kernel: + * 1. 
nvfp4_compute_per_block_scale: compute tile-level decode scales from block amax + * 2. target_amax.copy_: copy global amax to target tensor + * 3. nvfp4_expand_scale_to_fp8: expand to row-level and convert to FP8 E4M3 + * + * Input (block_amax): [tile_rows, tile_cols] float32 + * Input (global_amax): [1] float32 + * Output (per_block_scale): [tile_rows, tile_cols] float32 (intermediate, for partial_cast) + * Output (target_scale): [rows_padded, tile_cols] uint8 (E4M3) + * Output (target_amax): [1] float32 (copy of global_amax) + * + * Saves 2 kernel launches per parameter (eliminates nvfp4_compute_per_block_scale and + * nvfp4_expand_scale_to_fp8 as separate calls, plus the amax copy). + * --------------------------------------------------------------------------- + */ + __global__ void nvfp4_fused_scale_kernel( + const float* __restrict__ block_amax, // [tile_rows, tile_cols] + const float* __restrict__ global_amax, // [1] + float* __restrict__ per_block_scale, // [tile_rows, tile_cols] - for partial_cast + uint8_t* __restrict__ target_scale, // [rows_padded, tile_cols] + float* __restrict__ target_amax, // [1] + const size_t tile_rows, + const size_t tile_cols, + const size_t rows_padded, + const size_t block_len +) { + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + // Read global amax once per thread (broadcast) + const float g_amax = *global_amax; + + // Thread (0,0) copies global_amax to target_amax + if (out_row == 0 && out_col == 0) { + *target_amax = g_amax; + } + + if (out_row >= rows_padded || out_col >= tile_cols) return; + + // Determine which tile row this output row belongs to + const size_t tile_row = out_row / block_len; + + // Compute the scale value + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; + + float scale_val = 0.0f; + if (tile_row < tile_rows) { + float safe_global_amax = fmaxf(g_amax, tiny); + float global_scale = (g_amax > 0.0f) ? 
+ fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; + + // Read block amax and compute per-block decode scale + float amax_val = block_amax[tile_row * tile_cols + out_col]; + scale_val = fminf((amax_val / fp4_max) * global_scale, flt_max); + + // Write per-block scale (only once per tile, when out_row % block_len == 0) + if (out_row % block_len == 0) { + per_block_scale[tile_row * tile_cols + out_col] = scale_val; + } + } + + // Convert float32 to FP8 E4M3 and write expanded scale + fp8e4m3 fp8_val = static_cast(scale_val); + target_scale[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); +} + +void nvfp4_fused_scale(const Tensor block_amax, const Tensor global_amax, + Tensor per_block_scale, Tensor target_scale, Tensor target_amax, + size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, + cudaStream_t stream) { + NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.dtype() == DType::kFloat32, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.dtype() == DType::kByte, "Target scale must be uint8 (E4M3)."); + NVTE_CHECK(target_amax.dtype() == DType::kFloat32, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, + (rows_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_fused_scale_kernel<<>>( + reinterpret_cast(block_amax.data.dptr), + reinterpret_cast(global_amax.data.dptr), + reinterpret_cast(per_block_scale.data.dptr), + reinterpret_cast(target_scale.data.dptr), + reinterpret_cast(target_amax.data.dptr), + tile_rows, tile_cols, rows_padded, block_len); + NVTE_CHECK_CUDA(cudaGetLastError()); +} } // namespace nvfp4_recipe } // namespace transformer_engine @@ -673,12 +781,13 @@ void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, } void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, - float global_amax, cudaStream_t stream) { + const NVTETensor global_amax, cudaStream_t stream) { NVTE_API_CALL(nvte_nvfp4_compute_per_block_scale); using namespace transformer_engine; nvfp4_recipe::nvfp4_compute_per_block_scale(*convertNVTETensorCheck(block_amax), *convertNVTETensorCheck(scale), - global_amax, stream); + *convertNVTETensorCheck(global_amax), + stream); } void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale, @@ -755,6 +864,19 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r NVTE_CHECK_CUDA(cudaGetLastError()); } - - - +void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, + NVTETensor per_block_scale, NVTETensor target_scale, + NVTETensor target_amax, + size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_fused_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_fused_scale(*convertNVTETensorCheck(block_amax), + *convertNVTETensorCheck(global_amax), + *convertNVTETensorCheck(per_block_scale), + *convertNVTETensorCheck(target_scale), + *convertNVTETensorCheck(target_amax), + tile_rows, tile_cols, 
rows_padded, block_len, + stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0c209dc77a..04f81bc023 100755 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -164,7 +164,13 @@ void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, int64_t tile_cols, int64_t rows_padded, int64_t block_len); -void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, float global_amax); +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::Tensor global_amax); + +void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, + at::Tensor per_block_scale, at::Tensor target_scale, + at::Tensor target_amax, + int64_t tile_rows, int64_t tile_cols, + int64_t rows_padded, int64_t block_len); void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 439566462a..08e020abb9 100755 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -274,6 +274,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Compute global encode scale from global amax", py::arg("global_amax"), py::arg("global_scale"), py::call_guard()); + m.def("nvfp4_fused_scale", &transformer_engine::pytorch::nvfp4_fused_scale, + "Fused kernel: compute per-block decode scale, copy global amax, expand to row-level FP8", + py::arg("block_amax"), py::arg("global_amax"), py::arg("per_block_scale"), + py::arg("target_scale"), py::arg("target_amax"), + py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), py::arg("block_len"), + py::call_guard()); m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 11af519191..2e178314e0 100755 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -149,18 +149,56 @@ void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, at::cuda::getCurrentCUDAStream()); } -void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, float global_amax) { +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::Tensor global_amax) { init_extension(); // block_amax and scale: [tile_rows, tile_cols], float32 + // global_amax: single element tensor, float32 (avoids D2H transfer) NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); NVTE_CHECK(scale.scalar_type() == at::kFloat, "Scale must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); auto block_amax_cu = makeTransformerEngineTensor(block_amax); auto scale_cu = makeTransformerEngineTensor(scale); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); nvte_nvfp4_compute_per_block_scale(block_amax_cu.data(), scale_cu.data(), - global_amax, at::cuda::getCurrentCUDAStream()); + global_amax_cu.data(), at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, + 
at::Tensor per_block_scale, at::Tensor target_scale, + at::Tensor target_amax, + int64_t tile_rows, int64_t tile_cols, + int64_t rows_padded, int64_t block_len) { + init_extension(); + + // block_amax: [tile_rows, tile_cols], float32 + // global_amax: [1], float32 + // per_block_scale: [tile_rows, tile_cols], float32 (for partial_cast) + // target_scale: [rows_padded, tile_cols], uint8 (E4M3) + // target_amax: [1], float32 + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); + NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); + auto target_scale_cu = makeTransformerEngineTensor(target_scale); + auto target_amax_cu = makeTransformerEngineTensor(target_amax); + + nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), + per_block_scale_cu.data(), target_scale_cu.data(), + target_amax_cu.data(), + static_cast(tile_rows), static_cast(tile_cols), + static_cast(rows_padded), static_cast(block_len), + at::cuda::getCurrentCUDAStream()); } void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale) { diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index c23fdaee9b..57b07493a9 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -627,13 +627,8 @@ def _cast_master_weights_to_nvfp4_2d( tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - # Use GPU kernel for computing per-block decode scales - # Takes global_amax (not global_scale) and computes scale internally - for amax_tensor, scale_tensor, global_amax_view in zip( - amaxes, scales, [global_amaxes[i : i + 1] for i in range(len(params))] - ): - tex.nvfp4_compute_per_block_scale(amax_tensor, scale_tensor, global_amax_view.item()) - + # Main loop: use fused kernel for scale computation + expansion + amax copy + # This saves 2 kernel launches per parameter zipped_meta = zip( tile_shapes, row_sizes, @@ -641,6 +636,7 @@ def _cast_master_weights_to_nvfp4_2d( scale_targets, amax_targets, params, + amaxes, scales, global_scale_views, ) @@ -651,6 +647,7 @@ def _cast_master_weights_to_nvfp4_2d( target_scale, target_amax, (model_weight, master_weight, start_offset, model_weight_fragment), + block_amax, per_block_decode_scale, global_scale, ) in enumerate(zipped_meta): @@ -662,22 +659,36 @@ def _cast_master_weights_to_nvfp4_2d( # not updated currently. model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) - # Always write scales and amax for ALL layers (computed from all-reduced amax). - # This ensures scales are correct even for layers not owned by this rank. 
- # Use GPU kernel to expand tile-level scales to row-level and convert to FP8 + # Use fused kernel: computes per-block decode scale, copies global amax to target, + # and expands to row-level FP8 scale - all in one kernel launch tile_rows = tile_shape[0] rows_padded = target_scale.shape[0] - tex.nvfp4_expand_scale_to_fp8( - per_block_decode_scale, - target_scale, - tile_rows, - tile_col_cnt, - rows_padded, - block_len, - ) + global_amax_view = global_amaxes[idx : idx + 1] + # target_amax could be None if model_weight._amax_rowwise is None if target_amax is not None: - target_amax.copy_(global_amaxes[idx : idx + 1]) + tex.nvfp4_fused_scale( + block_amax, + global_amax_view, + per_block_decode_scale, + target_scale, + target_amax, + tile_rows, + tile_col_cnt, + rows_padded, + block_len, + ) + else: + # Fallback: compute scale and expand without amax copy + tex.nvfp4_compute_per_block_scale(block_amax, per_block_decode_scale, global_amax_view) + tex.nvfp4_expand_scale_to_fp8( + per_block_decode_scale, + target_scale, + tile_rows, + tile_col_cnt, + rows_padded, + block_len, + ) # Only cast data for layers owned by this rank if master_weight is None or master_weight.numel() == 0: From 27d9465b6fea3ef2cb908ec8472e748aa2145f2e Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 8 Jan 2026 04:39:13 +0000 Subject: [PATCH 30/40] optimize transpose kernel --- transformer_engine/common/recipe/nvfp4.cu | 153 ++++++++++++++++------ 1 file changed, 114 insertions(+), 39 deletions(-) diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 173730d7a4..ec929af938 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -338,51 +338,124 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, * - After transpose: elements [k, 2*m_packed] and [k, 2*m_packed+1] share a byte * which were originally [2*m_packed, k] and [2*m_packed+1, k] * - * So we need to read from two consecutive rows of the input, extract the same - * nibble position from each, and pack them into one output byte. + * Two implementations: + * 1. TMA version (SM100+/Hopper): Uses bulk async copy for ~3x bandwidth + * 2. 
Vectorized version (fallback): Uses uint2 loads/stores
+ *    (only this vectorized kernel is implemented below; the TMA variant is not part of this patch)
+ *
+ * Tile size: 64x64 logical FP4 elements = 64x32 input bytes = 64x32 output bytes
 * ---------------------------------------------------------------------------
 */
-__global__ void nvfp4_transpose_kernel(const uint8_t *input, uint8_t *output, const size_t M,
-                                       const size_t K) {
-  // Input: [M, K] logical, stored as [M, K/2] bytes
-  // Output: [K, M] logical, stored as [K, M/2] bytes
-
-  const size_t K_packed = K / 2;  // input packed width
-  const size_t M_packed = M / 2;  // output packed width
-
-  const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y;  // k index [0, K)
-  const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x;  // packed m index [0, M/2)
+// Vectorized transpose kernel parameters
+constexpr int TRANSPOSE_TILE_DIM = 64;     // Logical FP4 elements per tile dimension
+constexpr int TRANSPOSE_TILE_PACKED = 32;  // TILE_DIM / 2 bytes
+constexpr int TRANSPOSE_BLOCK_SIZE = 256;  // threads per block
 
-  if (out_row >= K || out_col >= M_packed) return;
+// Shared memory: store unpacked 4-bit values as bytes for easy transpose
+// Size: TILE_DIM x (TILE_DIM + 4) to avoid bank conflicts
+constexpr int TRANSPOSE_SHMEM_STRIDE = TRANSPOSE_TILE_DIM + 4;
 
-  // The two logical M positions this output byte covers
-  const size_t m0 = out_col * 2;      // for low nibble of output
-  const size_t m1 = out_col * 2 + 1;  // for high nibble of output
-  const size_t k = out_row;
+/*
+ * Vectorized transpose kernel with uint2 loads/stores (256 threads)
+ * Tile: 64x64 logical FP4 = 64x32 packed bytes
+ */
+__global__ void __launch_bounds__(TRANSPOSE_BLOCK_SIZE)
+nvfp4_transpose_kernel(const uint8_t* __restrict__ input,
+                       uint8_t* __restrict__ output,
+                       const size_t M, const size_t K) {
+  const size_t K_packed = K / 2;
+  const size_t M_packed = M / 2;
 
-  // In the input, both elements are at the same packed column
-  const size_t in_col = k / 2;
-  const int nibble_idx = k & 1;  // 0 = low nibble, 1 = high nibble
+  const size_t tile_m_start = blockIdx.x * TRANSPOSE_TILE_DIM;
+  const size_t tile_k_start = blockIdx.y * TRANSPOSE_TILE_DIM;
 
-  // Read the two input bytes from consecutive rows
-  const uint8_t in_byte_0 = input[m0 * K_packed + in_col];
-  const uint8_t in_byte_1 = input[m1 * K_packed + in_col];
+  __shared__ uint8_t shmem[TRANSPOSE_TILE_DIM][TRANSPOSE_SHMEM_STRIDE];
 
-  // Extract the appropriate nibbles
-  uint8_t val0, val1;
-  if (nibble_idx == 0) {
-    val0 = in_byte_0 & 0x0Fu;
-    val1 = in_byte_1 & 0x0Fu;
-  } else {
-    val0 = (in_byte_0 >> 4) & 0x0Fu;
-    val1 = (in_byte_1 >> 4) & 0x0Fu;
+  const int tid = threadIdx.x;
+
+  // Phase 1: Load input tile with VECTORIZED uint2 reads
+  // 256 threads, each loads 8 bytes (uint2) = 2048 bytes total
+  // Input tile: [64 rows, 32 cols] = 2048 bytes
+  {
+    const int thread_row = tid / 4;        // 64 rows, 4 threads per row
+    const int thread_col = (tid % 4) * 8;  // 4 x 8 = 32 bytes per row
+
+    const size_t global_m = tile_m_start + thread_row;
+    const size_t global_k_packed_base = tile_k_start / 2 + thread_col;
+
+    // Load 8 bytes as uint2
+    uint2 loaded = make_uint2(0, 0);
+    if (global_m < M && global_k_packed_base + 7 < K_packed) {
+      loaded = *reinterpret_cast<const uint2 *>(&input[global_m * K_packed + global_k_packed_base]);
+    } else if (global_m < M) {
+      // Boundary: scalar loads
+      uint8_t* bytes = reinterpret_cast<uint8_t *>(&loaded);
+      #pragma unroll
+      for (int b = 0; b < 8; ++b) {
+        size_t col = global_k_packed_base + b;
+        bytes[b] = (col < K_packed) ? input[global_m * K_packed + col] : 0;
+      }
+    }
+
+    // Unpack 8 bytes -> 16 nibbles and store to shared memory
+    const uint8_t* bytes = reinterpret_cast<const uint8_t *>(&loaded);
+    #pragma unroll
+    for (int b = 0; b < 8; ++b) {
+      const int k0 = thread_col * 2 + b * 2;
+      const int k1 = k0 + 1;
+      shmem[thread_row][k0] = bytes[b] & 0x0F;
+      shmem[thread_row][k1] = (bytes[b] >> 4) & 0x0F;
+    }
+  }
 
+  __syncthreads();
 
-  // Pack: val0 in low nibble, val1 in high nibble
-  const uint8_t out_byte = val0 | static_cast<uint8_t>(val1 << 4);
+  // Phase 2: Write output with VECTORIZED uint2 stores
+  // Output tile: [64 rows, 32 cols] = 2048 bytes
+  {
+    const int thread_row = tid / 4;             // output K dimension [0, 64)
+    const int thread_col_base = (tid % 4) * 8;  // output M_packed [0, 32) in steps of 8
+
+    const size_t global_k = tile_k_start + thread_row;
+    const size_t global_m_packed_base = tile_m_start / 2 + thread_col_base;
+
+    if (global_k >= K) return;
+
+    // Build 8 output bytes in registers
+    uint8_t out_bytes[8];
+
+    #pragma unroll
+    for (int b = 0; b < 8; ++b) {
+      const int out_m_packed = thread_col_base + b;
+
+      if (global_m_packed_base + b >= M_packed) {
+        out_bytes[b] = 0;
+        continue;
+      }
+
+      // Two M positions that pack into this output byte
+      const int m0 = out_m_packed * 2;
+      const int m1 = out_m_packed * 2 + 1;
+      const int k = thread_row;
+
+      // Read from shared memory (transposed access)
+      const uint8_t val0 = shmem[m0][k];
+      const uint8_t val1 = shmem[m1][k];
+
+      out_bytes[b] = val0 | (val1 << 4);
+    }
+
+    // Vectorized store as uint2
+    if (global_m_packed_base + 7 < M_packed) {
+      *reinterpret_cast<uint2 *>(&output[global_k * M_packed + global_m_packed_base]) =
+          *reinterpret_cast<const uint2 *>(out_bytes);
+    } else {
+      // Boundary: scalar stores
+      for (int b = 0; b < 8 && global_m_packed_base + b < M_packed; ++b) {
+        output[global_k * M_packed + global_m_packed_base + b] = out_bytes[b];
+      }
+    }
+  }
 }
 
 void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) {
@@ -412,14 +485,16 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) {
 
   if (M == 0 || K == 0) return;
 
-  // Launch kernel
-  constexpr int kBlockDim = 16;
-  dim3 block(kBlockDim, kBlockDim);
-  dim3 grid((M_packed + kBlockDim - 1) / kBlockDim, (K + kBlockDim - 1) / kBlockDim);
-
+  // Launch the vectorized transpose kernel: 64x64 logical FP4 tiles, 256 threads per block,
+  // uint2 loads/stores (the TMA variant mentioned above is not included in this patch)
+  dim3 block(TRANSPOSE_BLOCK_SIZE);
+  dim3 grid((M + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM,
+            (K + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM);
+
   nvfp4_transpose_kernel<<<grid, block, 0, stream>>>(
       reinterpret_cast<const uint8_t *>(input.data.dptr),
       reinterpret_cast<uint8_t *>(output.data.dptr), M, K);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 

From 2b45a3f769a5bb0204ea5a0e00e3234dca830086 Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Thu, 8 Jan 2026 17:37:17 +0000
Subject: [PATCH 31/40] multi-tensor apply for transpose

---
 .../test_cast_master_weights_to_fp8.py        |  6 +-
 transformer_engine/pytorch/csrc/extensions.h  |  8 ++
 .../pytorch/csrc/extensions/pybind.cpp        |  7 ++
 .../pytorch/csrc/extensions/transpose.cpp     | 67 ++++++++++++++
 transformer_engine/pytorch/tensor/utils.py    | 89 ++++++++++++++++++-
 5 files changed, 172 insertions(+), 5 deletions(-)

diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
index be37ffbffe..0f6d2aaa57 100755
--- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
+++ 
b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -1037,7 +1037,7 @@ def test_nvfp4_transpose_kernel() -> None: torch.manual_seed(1234) device = torch.device("cuda") - shape = (2048, 64) + shape = (2048, 5120) master_weight = torch.randn(shape, dtype=torch.float32, device=device) print("\n=== Testing NVFP4 transpose kernel ===") @@ -1399,6 +1399,6 @@ def test_single_gpu_partial_cast_vs_full(): if __name__ == "__main__": #main() - #test_nvfp4_transpose_kernel() + test_nvfp4_transpose_kernel() #test_single_gpu_partial_cast_vs_full() - test_nvfp4_partial_cast_matches_full() \ No newline at end of file + #test_nvfp4_partial_cast_matches_full() \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 04f81bc023..f1bee55232 100755 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -160,6 +160,14 @@ at::Tensor nvfp4_transpose(at::Tensor input, std::optional output = void nvfp4_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, int64_t K_tiles); +void nvfp4_multi_tensor_create_columnwise( + std::vector rowwise_data_list, + std::vector columnwise_data_list, + std::vector rowwise_scale_inv_list, + std::vector columnwise_scale_inv_list, + std::vector M_list, + std::vector K_list); + void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, int64_t tile_cols, int64_t rows_padded, int64_t block_len); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 08e020abb9..32905e2a4a 100755 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -280,6 +280,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("target_scale"), py::arg("target_amax"), py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), py::arg("block_len"), py::call_guard()); + m.def("nvfp4_multi_tensor_create_columnwise", + &transformer_engine::pytorch::nvfp4_multi_tensor_create_columnwise, + "Batched NVFP4 columnwise creation: transpose data and scales for multiple tensors", + py::arg("rowwise_data_list"), py::arg("columnwise_data_list"), + py::arg("rowwise_scale_inv_list"), py::arg("columnwise_scale_inv_list"), + py::arg("M_list"), py::arg("K_list"), + py::call_guard()); m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 2e178314e0..435a5e326d 100755 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -240,5 +240,72 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { return std::move(*out); } +void nvfp4_multi_tensor_create_columnwise( + std::vector rowwise_data_list, + std::vector columnwise_data_list, + std::vector rowwise_scale_inv_list, + std::vector columnwise_scale_inv_list, + std::vector M_list, + std::vector K_list) { + init_extension(); + + const size_t num_tensors = rowwise_data_list.size(); + NVTE_CHECK(columnwise_data_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(rowwise_scale_inv_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(columnwise_scale_inv_list.size() == num_tensors, 
"Tensor list size mismatch"); + NVTE_CHECK(M_list.size() == num_tensors, "M_list size mismatch"); + NVTE_CHECK(K_list.size() == num_tensors, "K_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + // Process each tensor - the main benefit is reduced Python overhead + // by doing the iteration in C++ rather than Python + constexpr size_t TILE_SIZE = 16; + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& rowwise_data = rowwise_data_list[i]; + auto& columnwise_data = columnwise_data_list[i]; + const auto& rowwise_scale_inv = rowwise_scale_inv_list[i]; + auto& columnwise_scale_inv = columnwise_scale_inv_list[i]; + const int64_t M = M_list[i]; + const int64_t K = K_list[i]; + + // Transpose data: [M, K/2] -> [K, M/2] + const auto data_shape = getTensorShape(rowwise_data); + NVTE_CHECK(data_shape.size() == 2, "NVFP4 data must be 2D."); + const size_t M_packed = static_cast(M) / 2; + const size_t K_packed = data_shape[1]; + + auto input_cu = makeTransformerEngineTensor( + rowwise_data.data_ptr(), std::vector{static_cast(M), K_packed}, + DType::kByte); + auto output_cu = makeTransformerEngineTensor( + columnwise_data.data_ptr(), std::vector{static_cast(K), M_packed}, + DType::kByte); + nvte_nvfp4_transpose(input_cu.data(), output_cu.data(), stream); + + // Transpose scales + const size_t M_tiles = (static_cast(M) + TILE_SIZE - 1) / TILE_SIZE; + const size_t K_tiles = (static_cast(K) + TILE_SIZE - 1) / TILE_SIZE; + + const auto scale_in_shape = getTensorShape(rowwise_scale_inv); + const auto scale_out_shape = getTensorShape(columnwise_scale_inv); + + auto scale_input_cu = makeTransformerEngineTensor( + rowwise_scale_inv.data_ptr(), + std::vector{scale_in_shape[0], scale_in_shape[1]}, DType::kByte); + auto scale_output_cu = makeTransformerEngineTensor( + columnwise_scale_inv.data_ptr(), + std::vector{scale_out_shape[0], scale_out_shape[1]}, DType::kByte); + + nvte_nvfp4_scale_transpose(scale_input_cu.data(), scale_output_cu.data(), + M_tiles, K_tiles, stream); + } +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 57b07493a9..5732d64b2f 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -845,9 +845,15 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten - Float8Tensor: may need to create a transposed view to match backend GEMM. - Float8BlockwiseQTensor: create column-wise storage. - Plain pytorch tensor: noop. + + For NVFP4 tensors, uses batched multi-tensor processing to reduce CPU overhead. """ if not isinstance(model_weights, list): model_weights = [model_weights] + + # Collect NVFP4 tensors for batched processing + nvfp4_tensors = [] + for model_weight in model_weights: if isinstance(model_weight, Float8Tensor): # Delayed scaling and per-tensor current scaling: if backend does not support @@ -858,13 +864,92 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten # Blockwise scaling: create column-wise storage. model_weight._create_columnwise() elif isinstance(model_weight, NVFP4Tensor): - # NVFP4 scaling: create column-wise storage. - model_weight._create_columnwise() + # Collect for batched processing + nvfp4_tensors.append(model_weight) elif isinstance(model_weight, MXFP8Tensor): # MXFP8 scaling: no need to do anything. 
pass elif isinstance(model_weight, QuantizedTensor): raise ValueError(f"post_processing for {type(model_weight)} is not supported") + + # Batch process all NVFP4 tensors with multi-tensor approach + if nvfp4_tensors: + _nvfp4_multi_tensor_create_columnwise(nvfp4_tensors) + + +def _nvfp4_multi_tensor_create_columnwise(nvfp4_tensors: List[NVFP4Tensor]): + """ + Batched columnwise creation for multiple NVFP4 tensors. + Reduces CPU overhead by collecting all tensor metadata and dispatching to C++. + """ + TILE_SIZE = 16 + + # Prepare tensor lists for batched C++ call + rowwise_data_list = [] + columnwise_data_list = [] + rowwise_scale_inv_list = [] + columnwise_scale_inv_list = [] + M_list = [] + K_list = [] + + for tensor in nvfp4_tensors: + rowwise_data = tensor._rowwise_data + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + tensor._rowwise_data = rowwise_data + + logical_shape = tensor.size() + M, K = logical_shape[0], logical_shape[-1] + M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE + K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE + + # Allocate columnwise_data if needed + if tensor._columnwise_data is None: + # Output shape: [K, M/2] packed bytes + columnwise_data = torch.empty( + (K, M // 2), + dtype=torch.uint8, + device=rowwise_data.device, + ) + tensor._columnwise_data = columnwise_data + else: + columnwise_data = tensor._columnwise_data + + # Allocate columnwise_scale_inv if needed + if tensor._columnwise_scale_inv is None: + assert tensor._quantizer is not None + columnwise_scale_inv_shape = tensor._quantizer.get_scale_shape(logical_shape, True) + columnwise_scale_inv = torch.empty( + columnwise_scale_inv_shape, + dtype=tensor._rowwise_scale_inv.dtype, + device=tensor._rowwise_scale_inv.device, + ) + tensor._columnwise_scale_inv = columnwise_scale_inv + else: + columnwise_scale_inv = tensor._columnwise_scale_inv + + rowwise_data_list.append(rowwise_data) + columnwise_data_list.append(columnwise_data) + rowwise_scale_inv_list.append(tensor._rowwise_scale_inv) + columnwise_scale_inv_list.append(columnwise_scale_inv) + M_list.append(M) + K_list.append(K) + + # Copy amax if needed + if tensor._amax_columnwise is None and tensor._amax_rowwise is not None: + tensor._amax_columnwise = tensor._amax_rowwise.clone() + elif tensor._amax_rowwise is not None: + tensor._amax_columnwise.copy_(tensor._amax_rowwise) + + # Dispatch to C++ multi-tensor kernel + tex.nvfp4_multi_tensor_create_columnwise( + rowwise_data_list, + columnwise_data_list, + rowwise_scale_inv_list, + columnwise_scale_inv_list, + M_list, + K_list, + ) def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: From 0dff43913261a1641811cfe81871ca0830986626 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 8 Jan 2026 18:11:09 +0000 Subject: [PATCH 32/40] multi-tensor for amax --- transformer_engine/pytorch/csrc/extensions.h | 9 +++ .../csrc/extensions/nvfp4_2d_partial_cast.cpp | 62 +++++++++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 7 +++ transformer_engine/pytorch/tensor/utils.py | 32 ++++++++-- 4 files changed, 105 insertions(+), 5 deletions(-) mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f1bee55232..ecf91d9833 100755 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -168,6 +168,15 @@ void nvfp4_multi_tensor_create_columnwise( std::vector M_list, 
std::vector K_list); +void nvfp4_multi_tensor_compute_partial_amax( + std::vector master_weight_list, + std::vector partial_amax_list, + std::vector global_amax_list, + std::vector h_list, + std::vector w_list, + std::vector start_offset_list, + int64_t block_len); + void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, int64_t tile_cols, int64_t rows_padded, int64_t block_len); diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp old mode 100644 new mode 100755 index f41a6eed29..9e994f1c46 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp @@ -50,6 +50,68 @@ void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tens at::cuda::getCurrentCUDAStream()); } +void nvfp4_multi_tensor_compute_partial_amax( + std::vector master_weight_list, + std::vector partial_amax_list, + std::vector global_amax_list, + std::vector h_list, + std::vector w_list, + std::vector start_offset_list, + int64_t block_len) { + + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + + const size_t num_tensors = master_weight_list.size(); + TORCH_CHECK(partial_amax_list.size() == num_tensors, "partial_amax_list size mismatch"); + TORCH_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch"); + TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch"); + TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch"); + TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& master_weight = master_weight_list[i]; + auto& partial_amax = partial_amax_list[i]; + auto& global_amax = global_amax_list[i]; + const size_t h = static_cast(h_list[i]); + const size_t w = static_cast(w_list[i]); + const size_t start_offset = static_cast(start_offset_list[i]); + + TORCH_CHECK(partial_amax.dim() == 2, "partial_amax must be a 2D tensor"); + TORCH_CHECK(partial_amax.scalar_type() == at::ScalarType::Float, + "partial_amax must be a float tensor"); + TORCH_CHECK(master_weight.scalar_type() == at::ScalarType::Float || + master_weight.scalar_type() == at::ScalarType::BFloat16, + "master_weight must be a float or bfloat16 tensor"); + TORCH_CHECK(global_amax.scalar_type() == at::ScalarType::Float, + "global_amax must be a float tensor"); + TORCH_CHECK(global_amax.numel() == 1, "global_amax must have exactly one element"); + + // Compute partial amax (per-block amax) + const TensorWrapper tensor_cu = makeTransformerEngineTensor(master_weight.contiguous()); + TensorWrapper amax_cu = makeTransformerEngineTensor(partial_amax); + + nvte_nvfp4_2d_compute_partial_amax( + tensor_cu.data(), amax_cu.data(), h, w, + partial_amax.stride(0), partial_amax.stride(1), + start_offset, static_cast(block_len), stream); + + // Compute global amax + auto* global_amax_ptr = global_amax.data_ptr(); + TensorWrapper fake_te_output( + /*dptr=*/nullptr, tensor_cu.shape(), + DType::kFloat32, + global_amax_ptr); + + nvte_compute_amax(tensor_cu.data(), fake_te_output.data(), stream); + } +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 
32905e2a4a..9c6ea2ba70 100755 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -315,6 +315,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Compute partial amax from master weights for NVFP4 2D", py::arg("tensor"), py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len") = 16, py::call_guard()); + m.def("nvfp4_multi_tensor_compute_partial_amax", + &transformer_engine::pytorch::nvfp4_multi_tensor_compute_partial_amax, + "Batched compute partial and global amax from master weights for NVFP4 2D", + py::arg("master_weight_list"), py::arg("partial_amax_list"), py::arg("global_amax_list"), + py::arg("h_list"), py::arg("w_list"), py::arg("start_offset_list"), + py::arg("block_len") = 16, + py::call_guard()); m.def("nvfp4_2d_partial_cast", &transformer_engine::pytorch::nvfp4_2d_partial_cast, "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"), py::arg("scale"), py::arg("global_scale"), py::arg("h"), py::arg("w"), diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 5732d64b2f..47f1022236 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -594,6 +594,14 @@ def _cast_master_weights_to_nvfp4_2d( global_amaxes[i : i + 1] for i in range(len(params)) ] + # Collect tensors for batched multi-tensor amax computation + master_weight_list: List[torch.Tensor] = [] + partial_amax_list: List[torch.Tensor] = [] + global_amax_list: List[torch.Tensor] = [] + h_list: List[int] = [] + w_list: List[int] = [] + start_offset_list: List[int] = [] + for i, (model_weight, master_weight, start_offset, _) in enumerate(params): scale_shape = tile_shapes[i] amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) @@ -608,11 +616,25 @@ def _cast_master_weights_to_nvfp4_2d( if master_weight is not None and master_weight.numel() > 0: assert len(model_weight.shape) == 2 h, w = model_weight.shape - # master_weight is already converted to model_weight.dtype (BF16) in the caller - tex.nvfp4_2d_compute_partial_amax( - master_weight, amax, h, w, start_offset, block_len - ) - tex.compute_amax(master_weight, global_amax_view) + # Collect for batched processing + master_weight_list.append(master_weight) + partial_amax_list.append(amax) + global_amax_list.append(global_amax_view) + h_list.append(h) + w_list.append(w) + start_offset_list.append(start_offset) + + # Batched multi-tensor call for partial and global amax computation + if master_weight_list: + tex.nvfp4_multi_tensor_compute_partial_amax( + master_weight_list, + partial_amax_list, + global_amax_list, + h_list, + w_list, + start_offset_list, + block_len, + ) if packed_amaxes.numel() > 0: torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) From dc3d495377286cf341b2e93b3610c5b6c87c7df2 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 8 Jan 2026 18:28:40 +0000 Subject: [PATCH 33/40] multi-tensor for partial cast --- transformer_engine/pytorch/csrc/extensions.h | 21 ++++ .../csrc/extensions/nvfp4_2d_partial_cast.cpp | 54 +++++++++ .../pytorch/csrc/extensions/pybind.cpp | 15 +++ .../pytorch/csrc/extensions/transpose.cpp | 59 ++++++++++ transformer_engine/pytorch/tensor/utils.py | 111 ++++++++++++------ 5 files changed, 222 insertions(+), 38 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 
ecf91d9833..d4c6a4d737 100755 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -189,6 +189,17 @@ void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, int64_t tile_rows, int64_t tile_cols, int64_t rows_padded, int64_t block_len); +void nvfp4_multi_tensor_fused_scale( + std::vector block_amax_list, + std::vector global_amax_list, + std::vector per_block_scale_list, + std::vector target_scale_list, + std::vector target_amax_list, + std::vector tile_rows_list, + std::vector tile_cols_list, + std::vector rows_padded_list, + int64_t block_len); + void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale); at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); @@ -376,6 +387,16 @@ void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, si void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, size_t block_len); + +void nvfp4_multi_tensor_2d_partial_cast( + std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, + std::vector w_list, + std::vector start_offset_list, + int64_t block_len); void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, at::Tensor amax_colwise, int rows, int cols, size_t start_offset); diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp index 9e994f1c46..2c73c577e2 100755 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp @@ -50,6 +50,60 @@ void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tens at::cuda::getCurrentCUDAStream()); } +void nvfp4_multi_tensor_2d_partial_cast( + std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, + std::vector w_list, + std::vector start_offset_list, + int64_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + + const size_t num_tensors = inp_list.size(); + TORCH_CHECK(out_list.size() == num_tensors, "out_list size mismatch"); + TORCH_CHECK(scale_list.size() == num_tensors, "scale_list size mismatch"); + TORCH_CHECK(global_scale_list.size() == num_tensors, "global_scale_list size mismatch"); + TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch"); + TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch"); + TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& inp = inp_list[i]; + const auto& out = out_list[i]; + const auto& scale = scale_list[i]; + const auto& global_scale = global_scale_list[i]; + const size_t h = static_cast(h_list[i]); + const size_t w = static_cast(w_list[i]); + const size_t start_offset = static_cast(start_offset_list[i]); + + TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); + TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); + TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); + TORCH_CHECK(global_scale.scalar_type() == 
at::ScalarType::Float, + "global_scale must be a float tensor"); + TORCH_CHECK(inp.scalar_type() == at::ScalarType::Float || + inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); + + const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); + const TensorWrapper out_cu = makeTransformerEngineTensor(out); + const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); + const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); + + nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), + global_scale_cu.data(), h, w, scale.stride(0), scale.stride(1), + start_offset, static_cast(block_len), stream); + } +} + void nvfp4_multi_tensor_compute_partial_amax( std::vector master_weight_list, std::vector partial_amax_list, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 9c6ea2ba70..5f0e2d9cf5 100755 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -280,6 +280,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("target_scale"), py::arg("target_amax"), py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), py::arg("block_len"), py::call_guard()); + m.def("nvfp4_multi_tensor_fused_scale", + &transformer_engine::pytorch::nvfp4_multi_tensor_fused_scale, + "Batched fused scale: compute per-block decode scale, copy global amax, expand to FP8 for multiple tensors", + py::arg("block_amax_list"), py::arg("global_amax_list"), py::arg("per_block_scale_list"), + py::arg("target_scale_list"), py::arg("target_amax_list"), + py::arg("tile_rows_list"), py::arg("tile_cols_list"), py::arg("rows_padded_list"), + py::arg("block_len"), + py::call_guard()); m.def("nvfp4_multi_tensor_create_columnwise", &transformer_engine::pytorch::nvfp4_multi_tensor_create_columnwise, "Batched NVFP4 columnwise creation: transpose data and scales for multiple tensors", @@ -326,6 +334,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"), py::arg("scale"), py::arg("global_scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len") = 16, py::call_guard()); + m.def("nvfp4_multi_tensor_2d_partial_cast", + &transformer_engine::pytorch::nvfp4_multi_tensor_2d_partial_cast, + "Batched partial cast from master weights for NVFP4 2D", + py::arg("inp_list"), py::arg("out_list"), py::arg("scale_list"), + py::arg("global_scale_list"), py::arg("h_list"), py::arg("w_list"), + py::arg("start_offset_list"), py::arg("block_len") = 16, + py::call_guard()); m.def("mxfp8_scaling_compute_partial_amax", &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax, "Compute partial amax from master weights for fp8 mxfp8 scaling", py::arg("input"), diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 435a5e326d..49d0c76c54 100755 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -201,6 +201,65 @@ void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::cuda::getCurrentCUDAStream()); } +void nvfp4_multi_tensor_fused_scale( + std::vector block_amax_list, + std::vector global_amax_list, + std::vector per_block_scale_list, + std::vector target_scale_list, + std::vector target_amax_list, + std::vector tile_rows_list, + 
std::vector tile_cols_list, + std::vector rows_padded_list, + int64_t block_len) { + init_extension(); + + const size_t num_tensors = block_amax_list.size(); + NVTE_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch"); + NVTE_CHECK(per_block_scale_list.size() == num_tensors, "per_block_scale_list size mismatch"); + NVTE_CHECK(target_scale_list.size() == num_tensors, "target_scale_list size mismatch"); + NVTE_CHECK(target_amax_list.size() == num_tensors, "target_amax_list size mismatch"); + NVTE_CHECK(tile_rows_list.size() == num_tensors, "tile_rows_list size mismatch"); + NVTE_CHECK(tile_cols_list.size() == num_tensors, "tile_cols_list size mismatch"); + NVTE_CHECK(rows_padded_list.size() == num_tensors, "rows_padded_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& block_amax = block_amax_list[i]; + const auto& global_amax = global_amax_list[i]; + auto& per_block_scale = per_block_scale_list[i]; + auto& target_scale = target_scale_list[i]; + auto& target_amax = target_amax_list[i]; + const size_t tile_rows = static_cast(tile_rows_list[i]); + const size_t tile_cols = static_cast(tile_cols_list[i]); + const size_t rows_padded = static_cast(rows_padded_list[i]); + + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); + NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); + auto target_scale_cu = makeTransformerEngineTensor(target_scale); + auto target_amax_cu = makeTransformerEngineTensor(target_amax); + + nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), + per_block_scale_cu.data(), target_scale_cu.data(), + target_amax_cu.data(), + tile_rows, tile_cols, rows_padded, + static_cast(block_len), stream); + } +} + void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale) { init_extension(); diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 47f1022236..1dd4b0b993 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -649,8 +649,26 @@ def _cast_master_weights_to_nvfp4_2d( tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] - # Main loop: use fused kernel for scale computation + expansion + amax copy - # This saves 2 kernel launches per parameter + # Collect tensors for batched fused scale kernel + fused_scale_block_amax_list: List[torch.Tensor] = [] + fused_scale_global_amax_list: List[torch.Tensor] = [] + fused_scale_per_block_scale_list: List[torch.Tensor] = [] + fused_scale_target_scale_list: List[torch.Tensor] = [] + fused_scale_target_amax_list: List[torch.Tensor] = [] + 
fused_scale_tile_rows_list: List[int] = [] + fused_scale_tile_cols_list: List[int] = [] + fused_scale_rows_padded_list: List[int] = [] + + # Collect tensors for batched partial cast kernel + partial_cast_inp_list: List[torch.Tensor] = [] + partial_cast_out_list: List[torch.Tensor] = [] + partial_cast_scale_list: List[torch.Tensor] = [] + partial_cast_global_scale_list: List[torch.Tensor] = [] + partial_cast_h_list: List[int] = [] + partial_cast_w_list: List[int] = [] + partial_cast_start_offset_list: List[int] = [] + + # First pass: collect all tensors and update usage zipped_meta = zip( tile_shapes, row_sizes, @@ -681,27 +699,22 @@ def _cast_master_weights_to_nvfp4_2d( # not updated currently. model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) - # Use fused kernel: computes per-block decode scale, copies global amax to target, - # and expands to row-level FP8 scale - all in one kernel launch tile_rows = tile_shape[0] rows_padded = target_scale.shape[0] global_amax_view = global_amaxes[idx : idx + 1] - - # target_amax could be None if model_weight._amax_rowwise is None + + # Collect for fused scale kernel (only if target_amax is not None) if target_amax is not None: - tex.nvfp4_fused_scale( - block_amax, - global_amax_view, - per_block_decode_scale, - target_scale, - target_amax, - tile_rows, - tile_col_cnt, - rows_padded, - block_len, - ) + fused_scale_block_amax_list.append(block_amax) + fused_scale_global_amax_list.append(global_amax_view) + fused_scale_per_block_scale_list.append(per_block_decode_scale) + fused_scale_target_scale_list.append(target_scale) + fused_scale_target_amax_list.append(target_amax) + fused_scale_tile_rows_list.append(tile_rows) + fused_scale_tile_cols_list.append(tile_col_cnt) + fused_scale_rows_padded_list.append(rows_padded) else: - # Fallback: compute scale and expand without amax copy + # Fallback: compute scale and expand without amax copy (rare case) tex.nvfp4_compute_per_block_scale(block_amax, per_block_decode_scale, global_amax_view) tex.nvfp4_expand_scale_to_fp8( per_block_decode_scale, @@ -712,27 +725,49 @@ def _cast_master_weights_to_nvfp4_2d( block_len, ) - # Only cast data for layers owned by this rank - if master_weight is None or master_weight.numel() == 0: - continue + # Collect for partial cast kernel (only for layers owned by this rank) + if master_weight is not None and master_weight.numel() > 0: + end_offset = start_offset + master_weight.numel() + if not use_fsdp_shard_model_weights: + rowwise_bytes = model_weight._rowwise_data.view(-1) + byte_start = start_offset // 2 + byte_end = (end_offset + 1) // 2 + model_weight_fragment = rowwise_bytes[byte_start:byte_end] + assert len(model_weight.shape) == 2 + h, w = model_weight.shape - end_offset = start_offset + master_weight.numel() - if not use_fsdp_shard_model_weights: - rowwise_bytes = model_weight._rowwise_data.view(-1) - byte_start = start_offset // 2 - byte_end = (end_offset + 1) // 2 - model_weight_fragment = rowwise_bytes[byte_start:byte_end] - assert len(model_weight.shape) == 2 - h, w = model_weight.shape - # master_weight is already converted to model_weight.dtype (BF16) in the caller - tex.nvfp4_2d_partial_cast( - master_weight, - model_weight_fragment, - per_block_decode_scale, - global_scale, - h, - w, - start_offset, + partial_cast_inp_list.append(master_weight) + partial_cast_out_list.append(model_weight_fragment) + partial_cast_scale_list.append(per_block_decode_scale) + partial_cast_global_scale_list.append(global_scale) + partial_cast_h_list.append(h) + 
partial_cast_w_list.append(w) + partial_cast_start_offset_list.append(start_offset) + + # Batched multi-tensor call for fused scale + if fused_scale_block_amax_list: + tex.nvfp4_multi_tensor_fused_scale( + fused_scale_block_amax_list, + fused_scale_global_amax_list, + fused_scale_per_block_scale_list, + fused_scale_target_scale_list, + fused_scale_target_amax_list, + fused_scale_tile_rows_list, + fused_scale_tile_cols_list, + fused_scale_rows_padded_list, + block_len, + ) + + # Batched multi-tensor call for partial cast + if partial_cast_inp_list: + tex.nvfp4_multi_tensor_2d_partial_cast( + partial_cast_inp_list, + partial_cast_out_list, + partial_cast_scale_list, + partial_cast_global_scale_list, + partial_cast_h_list, + partial_cast_w_list, + partial_cast_start_offset_list, block_len, ) From deda92e1845fb6e9527a2aa3507d524bc735b0a7 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Wed, 21 Jan 2026 06:21:28 +0000 Subject: [PATCH 34/40] clean up --- .../test_cast_master_weights_to_fp8.py | 177 +----------------- .../common/gemm/cublaslt_gemm.cu | 3 +- .../include/transformer_engine/recipe.h | 37 +++- transformer_engine/common/recipe/nvfp4.cu | 13 +- transformer_engine/pytorch/module/base.py | 1 + .../tensor/storage/nvfp4_tensor_storage.py | 7 +- transformer_engine/pytorch/tensor/utils.py | 18 +- 7 files changed, 52 insertions(+), 204 deletions(-) mode change 100644 => 100755 transformer_engine/common/gemm/cublaslt_gemm.cu mode change 100644 => 100755 transformer_engine/pytorch/module/base.py diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index e7d41e81e4..257ff9f0c8 100755 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -878,23 +878,6 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi te.Linear(256 * 3, 128, **linear_kwargs), ) - # Use 2048x2048 weights shape for testing. - # with te.quantized_model_init( - # enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True - # ): - # model_nvfp4 = nn.Sequential( - # te.Linear(2048, 2048, **linear_kwargs), - # te.Linear(2048, 2048, **linear_kwargs), - # te.Linear(2048, 2048, **linear_kwargs), - # ) - - # # BF16 model (created outside quantized_model_init) - # model = nn.Sequential( - # te.Linear(2048, 2048, **linear_kwargs), - # te.Linear(2048, 2048, **linear_kwargs), - # te.Linear(2048, 2048, **linear_kwargs), - # ) - for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): high_precision_init_val = w_nvfp4.get_high_precision_init_val() w.data.copy_(high_precision_init_val) @@ -913,7 +896,6 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi w_nvfp4.main_grad.zero_() w.main_grad.zero_() - # Original input shape: torch.randn(128, 128, ...) 
inputs = [ torch.randn(2048, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) ] @@ -972,7 +954,6 @@ def run_parallel_tests() -> None: quantizations = [] if is_fp8_available(): - print("fp8 available") quantizations.extend(["fp8", "fp8_cs"]) if is_fp8_block_scaling_available(): quantizations.append("fp8_block") @@ -990,7 +971,6 @@ def run_parallel_tests() -> None: print("starting cast master weights to nvfp4 test") nvfp4_available, _ = is_nvfp4_available(return_reason=True) if nvfp4_available: - #for post_ag_processing in manual_post_all_gather_processings: _test_cast_master_weights_to_nvfp4(dp_group, False) dist.destroy_process_group() @@ -1052,14 +1032,6 @@ def test_nvfp4_transpose_kernel() -> None: reference_columnwise_data = reference_tensor._columnwise_data.detach().clone() reference_columnwise_scale_inv = reference_tensor._columnwise_scale_inv.detach().clone() reference_columnwise_amax = reference_tensor._amax_columnwise.detach().clone() if reference_tensor._amax_columnwise is not None else None - print( - "reference columnwise_data shape:", - reference_columnwise_data.shape, - ) - print( - "reference columnwise_scale_inv shape:", - reference_columnwise_scale_inv.shape, - ) # Create tensor with only rowwise data, then call _create_columnwise() quantizer_rowwise_only = NVFP4Quantizer( @@ -1072,14 +1044,6 @@ def test_nvfp4_transpose_kernel() -> None: test_tensor.update_usage(rowwise_usage=True, columnwise_usage=True) assert test_tensor._columnwise_data is not None, "Test tensor should have columnwise data after _create_columnwise()" assert test_tensor._columnwise_scale_inv is not None, "Test tensor should have columnwise scale_inv after _create_columnwise()" - print( - "test_tensor columnwise_data shape after transpose:", - test_tensor._columnwise_data.shape, - ) - print( - "test_tensor columnwise_scale_inv shape after transpose:", - test_tensor._columnwise_scale_inv.shape, - ) # Compare columnwise data - should be bitwise identical torch.testing.assert_close( @@ -1089,22 +1053,6 @@ def test_nvfp4_transpose_kernel() -> None: rtol=0, msg="NVFP4 transpose kernel produced different columnwise data than reference!", ) - print("columnwise_data matches!") - - # Compare columnwise scale_inv - should be bitwise identical - print("reference columnwise_scale_inv:\n", reference_columnwise_scale_inv) - print("test columnwise_scale_inv:\n", test_tensor._columnwise_scale_inv) - print("reference rowwise_scale_inv shape:", reference_tensor._rowwise_scale_inv.shape) - print("test rowwise_scale_inv shape:", test_tensor._rowwise_scale_inv.shape) - - # Check if they match - scale_match = torch.equal(test_tensor._columnwise_scale_inv, reference_columnwise_scale_inv) - if not scale_match: - diff_mask = test_tensor._columnwise_scale_inv != reference_columnwise_scale_inv - print("Number of mismatches:", diff_mask.sum().item()) - print("Mismatch locations:", diff_mask.nonzero()[:10]) - print("Test values at mismatch:", test_tensor._columnwise_scale_inv[diff_mask][:10]) - print("Reference values at mismatch:", reference_columnwise_scale_inv[diff_mask][:10]) torch.testing.assert_close( test_tensor._columnwise_scale_inv, @@ -1113,21 +1061,14 @@ def test_nvfp4_transpose_kernel() -> None: rtol=0, msg="NVFP4 _create_columnwise produced different columnwise scale_inv than reference!", ) - print("columnwise_scale_inv matches!") - - # Compare columnwise amax if available - if reference_columnwise_amax is not None: - torch.testing.assert_close( - test_tensor._amax_columnwise, - 
reference_columnwise_amax, - atol=0, - rtol=0, - msg="NVFP4 _create_columnwise produced different columnwise amax than reference!", - ) - print("columnwise_amax matches!") - - print("NVFP4 transpose kernel test PASSED!") + torch.testing.assert_close( + test_tensor._amax_columnwise, + reference_columnwise_amax, + atol=0, + rtol=0, + msg="NVFP4 _create_columnwise produced different columnwise amax than reference!", + ) @pytest.mark.skipif( not torch.cuda.is_available(), reason="NVFP4 partial-cast test requires CUDA." @@ -1174,15 +1115,12 @@ def test_nvfp4_partial_cast_matches_full() -> None: reference_data = reference_tensor._rowwise_data.detach().clone() reference_scale = reference_tensor._rowwise_scale_inv.detach().clone() reference_amax = reference_tensor._amax_rowwise.detach().clone() - print(f"[Rank {WORLD_RANK}] reference_data shape: {reference_data.shape}") - print(f"[Rank {WORLD_RANK}] reference_scale shape: {reference_scale.shape}") # Split master weight evenly across ranks shard_size = total_elements // WORLD_SIZE start_offset = WORLD_RANK * shard_size end_offset = start_offset + shard_size master_weight_shard = full_master_weight.view(-1)[start_offset:end_offset].clone() - print(f"[Rank {WORLD_RANK}] shard: start_offset={start_offset}, end_offset={end_offset}, shard_size={shard_size}") # Create empty NVFP4 tensor for this rank (full shape, but we'll only fill our shard) nvfp4_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) @@ -1216,7 +1154,6 @@ def test_nvfp4_partial_cast_matches_full() -> None: # Reconstruct the full rowwise data gathered_data = torch.cat(gathered_shards, dim=0).view(reference_data.shape) - print(f"[Rank {WORLD_RANK}] gathered_data shape: {gathered_data.shape}") # Compare with reference torch.testing.assert_close( @@ -1226,7 +1163,6 @@ def test_nvfp4_partial_cast_matches_full() -> None: rtol=0, msg=f"[Rank {WORLD_RANK}] Gathered rowwise data does not match reference!", ) - print(f"[Rank {WORLD_RANK}] rowwise_data matches reference!") # Also verify scale matches (scale should be identical on all ranks after all-reduce) torch.testing.assert_close( @@ -1236,7 +1172,6 @@ def test_nvfp4_partial_cast_matches_full() -> None: rtol=0, msg=f"[Rank {WORLD_RANK}] Scale does not match reference!", ) - print(f"[Rank {WORLD_RANK}] scale matches reference!") # Verify amax matches torch.testing.assert_close( @@ -1246,10 +1181,6 @@ def test_nvfp4_partial_cast_matches_full() -> None: rtol=0, msg=f"[Rank {WORLD_RANK}] Amax does not match reference!", ) - print(f"[Rank {WORLD_RANK}] amax matches reference!") - - print(f"[Rank {WORLD_RANK}] Multi-GPU NVFP4 partial cast test PASSED!") - def test_single_gpu_partial_cast_vs_full(): """ @@ -1278,11 +1209,6 @@ def test_single_gpu_partial_cast_vs_full(): ref_scale = ref._rowwise_scale_inv.clone() ref_amax = ref._amax_rowwise.clone() - print(f"Reference:") - print(f" data shape: {ref_data.shape}") - print(f" scale shape: {ref_scale.shape}") - print(f" amax: {ref_amax}") - # === Test: Use cast_master_weights_to_nvfp4 with offset=0 (full tensor) === # Create empty NVFP4 tensor test_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) @@ -1296,7 +1222,6 @@ def test_single_gpu_partial_cast_vs_full(): dist.init_process_group(backend="nccl", init_method="env://", rank=0, world_size=1) mock_group = dist.new_group(ranks=[0]) - # First pass: find mismatch location cast_master_weights_to_nvfp4( [test_tensor], [master_weight.view(-1)], # Flatten as expected @@ -1304,101 +1229,17 @@ def 
test_single_gpu_partial_cast_vs_full(): mock_group, ) - # Check for mismatches and set debug env vars - scale_diff = (test_tensor._rowwise_scale_inv != ref_scale) - if scale_diff.any(): - diff_idx = torch.nonzero(scale_diff, as_tuple=True) - r, c = diff_idx[0][0].item(), diff_idx[1][0].item() - tile_row = r // 16 - tile_col = c - print(f"\n=== Found mismatch at scale[{r},{c}], tile[{tile_row},{tile_col}] ===") - print(f" Running second pass with debug enabled...") - - # Set env vars for debug - os.environ["NVFP4_DEBUG_TILE_ROW"] = str(tile_row) - os.environ["NVFP4_DEBUG_TILE_COL"] = str(tile_col) - - # Reset tensor and run again with debug - test_tensor._rowwise_data.zero_() - test_tensor._rowwise_scale_inv.zero_() - if test_tensor._amax_rowwise is not None: - test_tensor._amax_rowwise.zero_() - - cast_master_weights_to_nvfp4( - [test_tensor], - [master_weight.view(-1)], - [0], - mock_group, - ) - - # Clear env vars - del os.environ["NVFP4_DEBUG_TILE_ROW"] - del os.environ["NVFP4_DEBUG_TILE_COL"] - - print(f"\nTest (cast_master_weights_to_nvfp4 with offset=0):") - print(f" data shape: {test_tensor._rowwise_data.shape}") - print(f" scale shape: {test_tensor._rowwise_scale_inv.shape}") - print(f" amax: {test_tensor._amax_rowwise}") - - # === Compare === - print(f"\nComparison:") - # Compare amax amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax) - print(f" Amax match: {amax_match}") - if not amax_match: - print(f" test: {test_tensor._amax_rowwise}") - print(f" ref: {ref_amax}") # Compare scale scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale) - print(f" Scale match: {scale_match}") - if not scale_match: - mismatches = (test_tensor._rowwise_scale_inv != ref_scale).sum().item() - total = ref_scale.numel() - print(f" Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") - - # Find first mismatch location - diff = (test_tensor._rowwise_scale_inv != ref_scale) - if diff.any(): - diff_idx = torch.nonzero(diff, as_tuple=True) - r, c = diff_idx[0][0].item(), diff_idx[1][0].item() - print(f"\n === First mismatch at [{r},{c}] ===") - print(f" test scale (uint8): {test_tensor._rowwise_scale_inv[r,c].item()}") - print(f" ref scale (uint8): {ref_scale[r,c].item()}") - - # Convert to FP32 to see the actual values - test_fp32 = test_tensor._rowwise_scale_inv[r,c].view(torch.float8_e4m3fn).to(torch.float32).item() - ref_fp32 = ref_scale[r,c].view(torch.float8_e4m3fn).to(torch.float32).item() - print(f" test scale (FP32): {test_fp32}") - print(f" ref scale (FP32): {ref_fp32}") - - # Compute which tile this belongs to - tile_row = r // 16 - tile_col = c - print(f" Tile position: [{tile_row}, {tile_col}]") - - # Store this for utils.py to print debug info - import os - os.environ["NVFP4_DEBUG_TILE_ROW"] = str(tile_row) - os.environ["NVFP4_DEBUG_TILE_COL"] = str(tile_col) # Compare data data_match = torch.equal(test_tensor._rowwise_data, ref_data) - print(f" Data match: {data_match}") - if not data_match: - mismatches = (test_tensor._rowwise_data != ref_data).sum().item() - total = ref_data.numel() - print(f" Mismatches: {mismatches}/{total} ({100*mismatches/total:.4f}%)") - - if amax_match and scale_match and data_match: - print("\nSUCCESS: cast_master_weights_to_nvfp4 (offset=0) matches quantizer!") - else: - print("\nFAILURE: Results don't match!") - if __name__ == "__main__": - #main() - test_nvfp4_transpose_kernel() + main() + #test_nvfp4_transpose_kernel() #test_single_gpu_partial_cast_vs_full() #test_nvfp4_partial_cast_matches_full() \ No newline at end of file 
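
The shard bookkeeping in the tests and in _cast_master_weights_to_nvfp4_2d depends on NVFP4 packing two 4-bit values into each byte: element offsets into the flattened weight are halved when slicing the packed _rowwise_data buffer, and the end is rounded up so a shard whose boundary falls mid-byte still reaches the shared byte (the kernel then updates only the owned nibble via read-modify-write). The following is a minimal, self-contained sketch of that index arithmetic; the helper name packed_byte_range is hypothetical and used only for illustration.

    def packed_byte_range(start_offset: int, num_elems: int) -> tuple[int, int]:
        """Map an element range [start_offset, start_offset + num_elems) of an
        NVFP4 tensor (two FP4 values per byte) to a byte range in the packed
        rowwise buffer. The start rounds down and the end rounds up, so a
        boundary that falls mid-byte still covers the shared byte."""
        end_offset = start_offset + num_elems
        byte_start = start_offset // 2
        byte_end = (end_offset + 1) // 2
        return byte_start, byte_end

    # Example: a 2048 x 2048 weight split evenly across 4 ranks.
    total_elems = 2048 * 2048
    shard = total_elems // 4
    for rank in range(4):
        b0, b1 = packed_byte_range(rank * shard, shard)
        assert b1 - b0 == shard // 2  # each rank owns shard // 2 packed bytes
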
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu old mode 100644 new mode 100755 index 84c0460870..9fbdf585ff --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -759,7 +759,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); NVTE_CHECK(new_workspace_alignment % 256 == 0, "cuBLAS workspace pointer must be aligned to 256 bytes, got ", - new_workspace_alignment); + new_workspace_alignment); + const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, &returnedResults); diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 8504e28843..5a89109d5a 100755 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -309,12 +309,45 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, cudaStream_t stream); -// NVFP4 2D (16x16) partial-shard APIs +/*! \brief Compute tile-level amax for a partial shard of a 2D tensor. + * + * For NVFP4 2D quantization with 16x16 tiles. Computes the maximum absolute + * value within each tile, but only for elements in [start_offset, start_offset + len) + * of the flattened tensor. Used in distributed settings where each rank owns a shard. + * + * \param[in] inp Input tensor (partial shard, high-precision). + * \param[out] amax Output amax buffer [tile_rows, tile_cols], float32. + * \param[in] h Number of rows in the full 2D tensor. + * \param[in] w Number of columns in the full 2D tensor. + * \param[in] amax_stride_h Stride for amax in tile-row dimension. + * \param[in] amax_stride_w Stride for amax in tile-col dimension. + * \param[in] start_offset Starting element offset in the flattened tensor. + * \param[in] block_len Tile dimension (must be 16 for NVFP4 2D). + * \param[in] stream CUDA stream used for the operation. + */ void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, size_t amax_stride_h, size_t amax_stride_w, size_t start_offset, size_t block_len, cudaStream_t stream); +/*! \brief Cast a partial shard of a tensor to NVFP4 using 2D tile-based quantization. + * + * Quantizes elements in [start_offset, start_offset + len) of the flattened tensor + * using precomputed per-tile scales. Each 16x16 tile uses its own scale factor. + * Used in distributed settings where each rank casts its owned shard. + * + * \param[in] inp Input tensor (partial shard, high-precision). + * \param[out] out Output NVFP4 packed tensor (2 values per byte). + * \param[in] scale Per-tile scale factors [tile_rows, tile_cols], float32. + * \param[in] global_scale Global scale factor [1], float32. + * \param[in] h Number of rows in the full 2D tensor. + * \param[in] w Number of columns in the full 2D tensor. + * \param[in] scale_stride_h Stride for scale in tile-row dimension. + * \param[in] scale_stride_w Stride for scale in tile-col dimension. + * \param[in] start_offset Starting element offset in the flattened tensor. + * \param[in] block_len Tile dimension (must be 16 for NVFP4 2D). + * \param[in] stream CUDA stream used for the operation. 
+ */ void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, const NVTETensor global_scale, size_t h, size_t w, size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, @@ -346,7 +379,7 @@ void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_ void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles, size_t K_tiles, cudaStream_t stream); -/*! \brief Expand tile-level scales to row-level scales and convert to FP8 E4M3. +/*! \brief Expand tile-level scales to row-level scales and convert to FP8 E4M3, used in partial cast. * * Each tile row's scale is repeated block_len times in the output. * diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 65c909fc53..2fb0310301 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -60,11 +60,6 @@ namespace nvfp4_recipe { * Warp 1 -> rows 2-3 (groups of 4 elements per thread) * ... * Warp 7 -> rows 14-15 - * - * The host helper `_cast_master_weights_to_nvfp4_2d` reduces per-tile amax - * values, packs the resulting FP32 scales into the uint8 `_rowwise_scale_inv`, - * and launches `tex.nvfp4_2d_partial_cast`. The resulting bytes match TE’s full - * NVFP4 quantizer, so downstream GEMMs/checkpoints remain unchanged. * --------------------------------------------------------------------------- */ @@ -337,12 +332,6 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, * - Before transpose: elements [m, 2c] and [m, 2c+1] share a byte * - After transpose: elements [k, 2*m_packed] and [k, 2*m_packed+1] share a byte * which were originally [2*m_packed, k] and [2*m_packed+1, k] - * - * Two implementations: - * 1. TMA version (SM100+/Hopper): Uses bulk async copy for ~3x bandwidth - * 2. Vectorized version (fallback): Uses uint2 loads/stores - * - * Tile size: 64x64 logical FP4 elements = 64x32 input bytes = 64x32 output bytes * --------------------------------------------------------------------------- */ @@ -579,7 +568,7 @@ void nvfp4_scale_transpose(const Tensor input, Tensor output, * --------------------------------------------------------------------------- * NVFP4 SCALE EXPANSION KERNEL * - * Expands tile-level scales to row-level scales and converts to FP8 E4M3. + * Expands tile-level scales to row-level scales and converts to FP8 E4M3, used in partial cast. 
* * Input (per_block_decode_scale): [tile_rows, tile_cols] in float32 * Output (target_scale): [rows_padded, tile_cols] in uint8 (E4M3) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py old mode 100644 new mode 100755 index 941b1240af..875d245a8f --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1268,6 +1268,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: quantizer.with_amax_reduction = True # Quantize parameter param = quantizer(param) + # Redo parameter wrap in case we broke it above # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 25a138b8db..87e9c91780 100755 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -356,10 +356,7 @@ def _create_columnwise(self): # NVFP4 requires a specialized transpose that handles nibble repacking self._columnwise_data = tex.nvfp4_transpose(rowwise_data, out=self._columnwise_data) if self._columnwise_scale_inv is None: - assert self._quantizer is not None, ( - "._quantizer of Float8BlockwiseQTensor cannot be None because all the blockwise " - "quantized tensors are supposed to be generated from the quantizer." - ) + assert self._quantizer is not None, # Use logical shape (self.size()), not packed byte shape (rowwise_data.shape) # NVFP4 packs 2 elements per byte, so rowwise_data.shape[-1] is K/2 logical_shape = self.size() @@ -376,8 +373,6 @@ def _create_columnwise(self): # is repeated 16 times (once per row in the 16x16 tile). # columnwise_scale_inv has shape [K_padded, M_tiles] where scales are # repeated 16 times per tile row. - # - # Use GPU kernel to efficiently transpose and expand the scales. TILE_SIZE = 16 logical_shape = self.size() M, K = logical_shape[0], logical_shape[-1] diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 834406cae5..09ff61c8db 100755 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -246,7 +246,7 @@ def _cast_master_weights_to_fp8_delayed_scaling( for model_weight, master_weight, start_offset, shard_model_weight_raw in params: if not manual_post_all_gather_processing: - # Reset transpose cache for all model weights. + # Reset transpose cache for all model weights. # We cannot create transpose cache here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated currently. @@ -392,7 +392,7 @@ def _cast_master_weights_to_fp8_current_scaling( params, scales ): if not manual_post_all_gather_processing: - # Reset transpose cache for all model weights. + # Reset transpose cache for all model weights. # We cannot create transpose cache here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated currently. @@ -523,7 +523,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( params, scales ): if not manual_post_all_gather_processing: - # Reset transpose cache for all model weights. + # Clear columnwise data for all model weights. 
# We cannot create columnwise data here because users (like megatron) may want to # overlap the all-gather of model weights and forward process, so the model weight is # not updated at this moment. @@ -544,7 +544,6 @@ def _cast_master_weights_to_fp8_blockwise_scaling( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype ) -# revisit this later def _cast_master_weights_to_nvfp4_2d( params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False ): @@ -713,17 +712,6 @@ def _cast_master_weights_to_nvfp4_2d( fused_scale_tile_rows_list.append(tile_rows) fused_scale_tile_cols_list.append(tile_col_cnt) fused_scale_rows_padded_list.append(rows_padded) - else: - # Fallback: compute scale and expand without amax copy (rare case) - tex.nvfp4_compute_per_block_scale(block_amax, per_block_decode_scale, global_amax_view) - tex.nvfp4_expand_scale_to_fp8( - per_block_decode_scale, - target_scale, - tile_rows, - tile_col_cnt, - rows_padded, - block_len, - ) # Collect for partial cast kernel (only for layers owned by this rank) if master_weight is not None and master_weight.numel() > 0: From 36b88026e462d4756d24fde990222f2c13684536 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Tue, 27 Jan 2026 18:30:14 +0000 Subject: [PATCH 35/40] fix syntax error --- .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 87e9c91780..84fff4cf2e 100755 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -356,7 +356,7 @@ def _create_columnwise(self): # NVFP4 requires a specialized transpose that handles nibble repacking self._columnwise_data = tex.nvfp4_transpose(rowwise_data, out=self._columnwise_data) if self._columnwise_scale_inv is None: - assert self._quantizer is not None, + assert self._quantizer is not None # Use logical shape (self.size()), not packed byte shape (rowwise_data.shape) # NVFP4 packs 2 elements per byte, so rowwise_data.shape[-1] is K/2 logical_shape = self.size() From 3c4adf44fe712b56183e24d60e2dfac6edff6488 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Wed, 28 Jan 2026 00:51:29 +0000 Subject: [PATCH 36/40] bf16 to fp4 quant --- transformer_engine/common/recipe/nvfp4.cu | 63 +++++++++++++---------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 2fb0310301..d782612c9d 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -40,9 +40,10 @@ namespace nvfp4_recipe { * +------------------+ * * 2) Partial Cast (`nvfp4_2d_partial_cast_kernel`) - * - Stage the tile into shared memory (same pattern as FP8). - * - For each 4-value group, build float2 pairs and call - * `ptx::mul_cvt_fp32_to_fp4_4x`, producing packed FP4 nibbles. + * - Stage the BF16 tile into shared memory (same pattern as FP8). + * - For each 4-value group, pack BF16 values and call + * `ptx::mul_cvt_bf16_to_fp4_4x`, producing packed FP4 nibbles. + * - Uses direct BF16->FP4 conversion to match quantize_transpose kernel. 
* - Compute a shard-local byte index and update only the owned nibble(s) * using read-modify-write: * @@ -124,9 +125,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } } -template +template __global__ void __launch_bounds__(kThreadsPerBlock) - nvfp4_2d_partial_cast_kernel(const IType *input, uint8_t *output, const float *decode_scale_ptr, + nvfp4_2d_partial_cast_kernel(const bf16 *input, uint8_t *output, const float *decode_scale_ptr, const size_t scale_stride_h, const size_t scale_stride_w, const float *global_scale_ptr, const size_t h, const size_t w, const size_t start_offset, const size_t len) { @@ -136,12 +137,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; constexpr int kRowsPerWarp = (kTileDim + kNumWarps - 1) / kNumWarps; - __shared__ float smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; + // Use bf16 for shared memory to enable direct BF16->FP4 conversion. + // This matches the quantize_transpose kernel's NO_ACTIVATIONS_NOT_FP32_INPUT path. + __shared__ bf16 smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; const int tile_w = blockIdx.x; const int tile_h = blockIdx.y; const size_t shard_end = start_offset + len; - const IType *input_minus_offset = input - start_offset; + const bf16 *input_minus_offset = input - start_offset; float global_encode_scale = global_scale_ptr[0]; if (global_encode_scale <= 0.f) { @@ -172,7 +175,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const size_t idx_in_input = static_cast(h_in_input) * w + w_in_input; if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && idx_in_input < shard_end) { - smem[h_in_smem][w_in_smem] = static_cast(input_minus_offset[idx_in_input]); + // Store input directly without casting to float + smem[h_in_smem][w_in_smem] = input_minus_offset[idx_in_input]; skip_store = false; } } @@ -199,7 +203,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } const int col_in_output = tile_w * kTileDim + col_in_smem; - float vals[kNumOutputElemsPerBank]; + bf16 vals[kNumOutputElemsPerBank]; bool mask[kNumOutputElemsPerBank]; size_t elem_idx[kNumOutputElemsPerBank]; bool any_valid = false; @@ -212,9 +216,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const bool in_shard = in_width && idx >= start_offset && idx < shard_end; mask[j] = in_shard; const bool in_tile = (col_in_smem + j) < kTileDim; - const float tile_val = - in_tile ? smem[row_in_smem][col_in_smem + j] : 0.0f; - vals[j] = in_shard ? tile_val : 0.0f; + const bf16 zero_bf16 = __float2bfloat16(0.0f); + const bf16 tile_val = in_tile ? smem[row_in_smem][col_in_smem + j] : zero_bf16; + vals[j] = in_shard ? tile_val : zero_bf16; any_valid |= in_shard; } @@ -222,10 +226,17 @@ __global__ void __launch_bounds__(kThreadsPerBlock) continue; } - const float2 in01 = make_float2(vals[0], vals[1]); - const float2 in23 = make_float2(vals[2], vals[3]); + // Use direct BF16->FP4 conversion to match quantize_transpose kernel. + // This avoids the BF16->FP32->FP4 round-trip that causes numerical differences. 
+ // Pack 4 BF16 values into uint64_t for mul_cvt_bf16_to_fp4_4x + uint64_t in_4x; + uint16_t *in_ptr = reinterpret_cast(&in_4x); + in_ptr[0] = reinterpret_cast(vals[0]); + in_ptr[1] = reinterpret_cast(vals[1]); + in_ptr[2] = reinterpret_cast(vals[2]); + in_ptr[3] = reinterpret_cast(vals[3]); const auto packed = - transformer_engine::ptx::mul_cvt_fp32_to_fp4_4x(in01, in23, scale_vec, 0); + transformer_engine::ptx::mul_cvt_bf16_to_fp4_4x(in_4x, scale_vec, 0); const uint16_t packed_bits = reinterpret_cast(packed); for (int pair = 0; pair < 2; ++pair) { @@ -296,6 +307,8 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, cudaStream_t stream) { NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); NVTE_CHECK(out.dtype() == DType::kByte, "NVFP4 rowwise data must be uint8."); + NVTE_CHECK(inp.dtype() == DType::kBFloat16, + "NVFP4 partial cast requires BF16 input for direct BF16->FP4 conversion."); size_t len = inp.numel(); @@ -309,17 +322,15 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, assert(blocks_y <= std::numeric_limits::max()); dim3 grid(blocks_x, blocks_y); - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - inp.dtype(), inp_dtype, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - w % kTileDim == 0, kWidthAligned, - nvfp4_2d_partial_cast_kernel - <<>>( - reinterpret_cast(inp.data.dptr), - reinterpret_cast(out.data.dptr), - reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, - reinterpret_cast(global_scale.data.dptr), h, w, start_offset, - len);)) + TRANSFORMER_ENGINE_SWITCH_CONDITION( + w % kTileDim == 0, kWidthAligned, + nvfp4_2d_partial_cast_kernel + <<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), + reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, + reinterpret_cast(global_scale.data.dptr), h, w, start_offset, + len);) NVTE_CHECK_CUDA(cudaGetLastError()); } From ed65cdf1e55fba07bb1e66c7b2a0c2948d23779a Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 19 Feb 2026 18:28:45 +0000 Subject: [PATCH 37/40] Revert "bf16 to fp4 quant" This reverts commit 3c4adf44fe712b56183e24d60e2dfac6edff6488. --- transformer_engine/common/recipe/nvfp4.cu | 63 ++++++++++------------- 1 file changed, 26 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index d782612c9d..2fb0310301 100755 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -40,10 +40,9 @@ namespace nvfp4_recipe { * +------------------+ * * 2) Partial Cast (`nvfp4_2d_partial_cast_kernel`) - * - Stage the BF16 tile into shared memory (same pattern as FP8). - * - For each 4-value group, pack BF16 values and call - * `ptx::mul_cvt_bf16_to_fp4_4x`, producing packed FP4 nibbles. - * - Uses direct BF16->FP4 conversion to match quantize_transpose kernel. + * - Stage the tile into shared memory (same pattern as FP8). + * - For each 4-value group, build float2 pairs and call + * `ptx::mul_cvt_fp32_to_fp4_4x`, producing packed FP4 nibbles. 
* - Compute a shard-local byte index and update only the owned nibble(s) * using read-modify-write: * @@ -125,9 +124,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } } -template +template __global__ void __launch_bounds__(kThreadsPerBlock) - nvfp4_2d_partial_cast_kernel(const bf16 *input, uint8_t *output, const float *decode_scale_ptr, + nvfp4_2d_partial_cast_kernel(const IType *input, uint8_t *output, const float *decode_scale_ptr, const size_t scale_stride_h, const size_t scale_stride_w, const float *global_scale_ptr, const size_t h, const size_t w, const size_t start_offset, const size_t len) { @@ -137,14 +136,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; constexpr int kRowsPerWarp = (kTileDim + kNumWarps - 1) / kNumWarps; - // Use bf16 for shared memory to enable direct BF16->FP4 conversion. - // This matches the quantize_transpose kernel's NO_ACTIVATIONS_NOT_FP32_INPUT path. - __shared__ bf16 smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; + __shared__ float smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; const int tile_w = blockIdx.x; const int tile_h = blockIdx.y; const size_t shard_end = start_offset + len; - const bf16 *input_minus_offset = input - start_offset; + const IType *input_minus_offset = input - start_offset; float global_encode_scale = global_scale_ptr[0]; if (global_encode_scale <= 0.f) { @@ -175,8 +172,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const size_t idx_in_input = static_cast(h_in_input) * w + w_in_input; if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && idx_in_input < shard_end) { - // Store input directly without casting to float - smem[h_in_smem][w_in_smem] = input_minus_offset[idx_in_input]; + smem[h_in_smem][w_in_smem] = static_cast(input_minus_offset[idx_in_input]); skip_store = false; } } @@ -203,7 +199,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } const int col_in_output = tile_w * kTileDim + col_in_smem; - bf16 vals[kNumOutputElemsPerBank]; + float vals[kNumOutputElemsPerBank]; bool mask[kNumOutputElemsPerBank]; size_t elem_idx[kNumOutputElemsPerBank]; bool any_valid = false; @@ -216,9 +212,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const bool in_shard = in_width && idx >= start_offset && idx < shard_end; mask[j] = in_shard; const bool in_tile = (col_in_smem + j) < kTileDim; - const bf16 zero_bf16 = __float2bfloat16(0.0f); - const bf16 tile_val = in_tile ? smem[row_in_smem][col_in_smem + j] : zero_bf16; - vals[j] = in_shard ? tile_val : zero_bf16; + const float tile_val = + in_tile ? smem[row_in_smem][col_in_smem + j] : 0.0f; + vals[j] = in_shard ? tile_val : 0.0f; any_valid |= in_shard; } @@ -226,17 +222,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) continue; } - // Use direct BF16->FP4 conversion to match quantize_transpose kernel. - // This avoids the BF16->FP32->FP4 round-trip that causes numerical differences. 
- // Pack 4 BF16 values into uint64_t for mul_cvt_bf16_to_fp4_4x - uint64_t in_4x; - uint16_t *in_ptr = reinterpret_cast(&in_4x); - in_ptr[0] = reinterpret_cast(vals[0]); - in_ptr[1] = reinterpret_cast(vals[1]); - in_ptr[2] = reinterpret_cast(vals[2]); - in_ptr[3] = reinterpret_cast(vals[3]); + const float2 in01 = make_float2(vals[0], vals[1]); + const float2 in23 = make_float2(vals[2], vals[3]); const auto packed = - transformer_engine::ptx::mul_cvt_bf16_to_fp4_4x(in_4x, scale_vec, 0); + transformer_engine::ptx::mul_cvt_fp32_to_fp4_4x(in01, in23, scale_vec, 0); const uint16_t packed_bits = reinterpret_cast(packed); for (int pair = 0; pair < 2; ++pair) { @@ -307,8 +296,6 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, cudaStream_t stream) { NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); NVTE_CHECK(out.dtype() == DType::kByte, "NVFP4 rowwise data must be uint8."); - NVTE_CHECK(inp.dtype() == DType::kBFloat16, - "NVFP4 partial cast requires BF16 input for direct BF16->FP4 conversion."); size_t len = inp.numel(); @@ -322,15 +309,17 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, assert(blocks_y <= std::numeric_limits::max()); dim3 grid(blocks_x, blocks_y); - TRANSFORMER_ENGINE_SWITCH_CONDITION( - w % kTileDim == 0, kWidthAligned, - nvfp4_2d_partial_cast_kernel - <<>>( - reinterpret_cast(inp.data.dptr), - reinterpret_cast(out.data.dptr), - reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, - reinterpret_cast(global_scale.data.dptr), h, w, start_offset, - len);) + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + inp.dtype(), inp_dtype, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + w % kTileDim == 0, kWidthAligned, + nvfp4_2d_partial_cast_kernel + <<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), + reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, + reinterpret_cast(global_scale.data.dptr), h, w, start_offset, + len);)) NVTE_CHECK_CUDA(cudaGetLastError()); } From f9d4e890ccdc52e3d5189d46dc592fa2798c1bd5 Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 19 Feb 2026 18:48:54 +0000 Subject: [PATCH 38/40] file permission and minor --- .../distributed/test_cast_master_weights_to_fp8.py | 10 +++++----- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- .../common/include/transformer_engine/recipe.h | 0 transformer_engine/common/recipe/nvfp4.cu | 0 transformer_engine/pytorch/csrc/extensions.h | 0 .../pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp | 0 transformer_engine/pytorch/csrc/extensions/pybind.cpp | 0 transformer_engine/pytorch/csrc/extensions/recipe.cpp | 0 .../pytorch/csrc/extensions/transpose.cpp | 0 transformer_engine/pytorch/module/base.py | 0 .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 0 transformer_engine/pytorch/tensor/utils.py | 0 12 files changed, 6 insertions(+), 6 deletions(-) mode change 100755 => 100644 tests/pytorch/distributed/test_cast_master_weights_to_fp8.py mode change 100755 => 100644 transformer_engine/common/gemm/cublaslt_gemm.cu mode change 100755 => 100644 transformer_engine/common/include/transformer_engine/recipe.h mode change 100755 => 100644 transformer_engine/common/recipe/nvfp4.cu mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions.h mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/pybind.cpp mode change 100755 => 100644 
transformer_engine/pytorch/csrc/extensions/recipe.cpp mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/transpose.cpp mode change 100755 => 100644 transformer_engine/pytorch/module/base.py mode change 100755 => 100644 transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py mode change 100755 => 100644 transformer_engine/pytorch/tensor/utils.py diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py old mode 100755 new mode 100644 index 257ff9f0c8..49ceff4ad8 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -853,8 +853,8 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi rank = dist.get_rank(dp_group) world_size = dist.get_world_size(dp_group) - torch.manual_seed(12345) - torch.cuda.manual_seed(12345) + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] mock_group = mock_groups[rank] @@ -1098,7 +1098,7 @@ def test_nvfp4_partial_cast_matches_full() -> None: if not available: pytest.skip(reason) - torch.manual_seed(77777) + torch.manual_seed(1234) device = torch.device("cuda") # Shape must be divisible by WORLD_SIZE for even splitting # Also ensure dimensions are multiples of 16 for NVFP4 tiles @@ -1193,7 +1193,7 @@ def test_single_gpu_partial_cast_vs_full(): from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4 import transformer_engine_torch as tex - torch.manual_seed(77777) + torch.manual_seed(1234) device = torch.device("cuda") # Test with same shape as the optimizer test @@ -1242,4 +1242,4 @@ def test_single_gpu_partial_cast_vs_full(): main() #test_nvfp4_transpose_kernel() #test_single_gpu_partial_cast_vs_full() - #test_nvfp4_partial_cast_matches_full() \ No newline at end of file + #test_nvfp4_partial_cast_matches_full() diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu old mode 100755 new mode 100644 index 0344052f8e..c58c3cb47a --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -757,7 +757,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(new_workspace_alignment % 256 == 0, "cuBLAS workspace pointer must be aligned to 256 bytes, got ", new_workspace_alignment); - + const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, &returnedResults); diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h old mode 100755 new mode 100644 diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp 
b/transformer_engine/pytorch/csrc/extensions/recipe.cpp old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py old mode 100755 new mode 100644 From 371c229185b1fa0c1a6c76921c04eb8eb857e2ce Mon Sep 17 00:00:00 2001 From: qiyuw Date: Thu, 19 Feb 2026 18:53:41 +0000 Subject: [PATCH 39/40] remove prints --- tests/pytorch/distributed/test_cast_master_weights_to_fp8.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 49ceff4ad8..75ed2f50a0 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -891,7 +891,7 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi ) optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) - for i in range(100): + for i in range(500): for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): w_nvfp4.main_grad.zero_() w.main_grad.zero_() @@ -927,7 +927,6 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi optimizer_nvfp4.step() torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) - print("iter:", i, "loss matched") def run_parallel_tests() -> None: """Run parallel tests""" From 687c8b65d3ef0941119deacf906a422a6e908fe8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Feb 2026 18:58:33 +0000 Subject: [PATCH 40/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_cast_master_weights_to_fp8.py | 79 ++- .../include/transformer_engine/recipe.h | 20 +- transformer_engine/common/recipe/nvfp4.cu | 622 +++++++++--------- transformer_engine/pytorch/csrc/extensions.h | 65 +- .../csrc/extensions/nvfp4_2d_partial_cast.cpp | 70 +- .../pytorch/csrc/extensions/pybind.cpp | 60 +- .../pytorch/csrc/extensions/transpose.cpp | 106 ++- .../tensor/storage/nvfp4_tensor_storage.py | 8 +- transformer_engine/pytorch/tensor/utils.py | 63 +- 9 files changed, 513 insertions(+), 580 deletions(-) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 75ed2f50a0..89561e5259 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -225,7 +225,10 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals for idx, (weight, storage_offset, storage_size) in enumerate( zip(self.weights, self.storage_offsets[:-1], self.storage_sizes) ): - if storage_offset >= storage_rank_end or (storage_offset + storage_size) <= storage_rank_start: + if ( + storage_offset >= storage_rank_end + or (storage_offset + storage_size) <= storage_rank_start + ): continue overlap_start = max(storage_rank_start, storage_offset) 
overlap_end = min(storage_rank_end, storage_offset + storage_size) @@ -867,14 +870,14 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True ): model_nvfp4 = nn.Sequential( - te.Linear(128, 256+64, **linear_kwargs), - te.Linear(256+64, 256 * 3, **linear_kwargs), + te.Linear(128, 256 + 64, **linear_kwargs), + te.Linear(256 + 64, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) # Create model with bf16 weights model = nn.Sequential( - te.Linear(128, 256+64, **linear_kwargs), - te.Linear(256+64, 256 * 3, **linear_kwargs), + te.Linear(128, 256 + 64, **linear_kwargs), + te.Linear(256 + 64, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) @@ -925,9 +928,10 @@ def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processi optimizer.step() optimizer_nvfp4.step() - + torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) + def run_parallel_tests() -> None: """Run parallel tests""" @@ -1005,9 +1009,7 @@ def main() -> None: run_parallel_tests() -@pytest.mark.skipif( - not torch.cuda.is_available(), reason="NVFP4 transpose test requires CUDA." -) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="NVFP4 transpose test requires CUDA.") def test_nvfp4_transpose_kernel() -> None: """Test that nvfp4_transpose kernel produces bitwise identical results to reference.""" available, reason = is_nvfp4_available(return_reason=True) @@ -1027,10 +1029,16 @@ def test_nvfp4_transpose_kernel() -> None: ) reference_tensor = quantizer_with_colwise(master_weight.to(torch.bfloat16)) assert reference_tensor._columnwise_data is not None, "Reference should have columnwise data" - assert reference_tensor._columnwise_scale_inv is not None, "Reference should have columnwise scale_inv" + assert ( + reference_tensor._columnwise_scale_inv is not None + ), "Reference should have columnwise scale_inv" reference_columnwise_data = reference_tensor._columnwise_data.detach().clone() reference_columnwise_scale_inv = reference_tensor._columnwise_scale_inv.detach().clone() - reference_columnwise_amax = reference_tensor._amax_columnwise.detach().clone() if reference_tensor._amax_columnwise is not None else None + reference_columnwise_amax = ( + reference_tensor._amax_columnwise.detach().clone() + if reference_tensor._amax_columnwise is not None + else None + ) # Create tensor with only rowwise data, then call _create_columnwise() quantizer_rowwise_only = NVFP4Quantizer( @@ -1041,8 +1049,12 @@ def test_nvfp4_transpose_kernel() -> None: # Now call _create_columnwise() which uses our nvfp4_transpose kernel test_tensor.update_usage(rowwise_usage=True, columnwise_usage=True) - assert test_tensor._columnwise_data is not None, "Test tensor should have columnwise data after _create_columnwise()" - assert test_tensor._columnwise_scale_inv is not None, "Test tensor should have columnwise scale_inv after _create_columnwise()" + assert ( + test_tensor._columnwise_data is not None + ), "Test tensor should have columnwise data after _create_columnwise()" + assert ( + test_tensor._columnwise_scale_inv is not None + ), "Test tensor should have columnwise scale_inv after _create_columnwise()" # Compare columnwise data - should be bitwise identical torch.testing.assert_close( @@ -1052,7 +1064,7 @@ def test_nvfp4_transpose_kernel() -> None: rtol=0, msg="NVFP4 transpose kernel produced different columnwise data than reference!", ) - + torch.testing.assert_close( 
test_tensor._columnwise_scale_inv, reference_columnwise_scale_inv, @@ -1069,9 +1081,8 @@ def test_nvfp4_transpose_kernel() -> None: msg="NVFP4 _create_columnwise produced different columnwise amax than reference!", ) -@pytest.mark.skipif( - not torch.cuda.is_available(), reason="NVFP4 partial-cast test requires CUDA." -) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="NVFP4 partial-cast test requires CUDA.") def test_nvfp4_partial_cast_matches_full() -> None: """Test multi-GPU partial cast: split master weight, partial cast on each rank, all-gather, compare.""" WORLD_RANK = int(os.getenv("RANK", "0")) @@ -1140,17 +1151,17 @@ def test_nvfp4_partial_cast_matches_full() -> None: # Each rank has the full tensor but only its shard is filled # We need to all-gather the shards rowwise_data_flat = nvfp4_tensor._rowwise_data.view(-1) - + # For NVFP4, 2 elements are packed per byte, so byte shard size is shard_size // 2 byte_shard_size = shard_size // 2 byte_start = WORLD_RANK * byte_shard_size byte_end = byte_start + byte_shard_size my_shard_bytes = rowwise_data_flat[byte_start:byte_end].contiguous() - + # Gather all shards gathered_shards = [torch.empty_like(my_shard_bytes) for _ in range(WORLD_SIZE)] dist.all_gather(gathered_shards, my_shard_bytes, group=dp_group) - + # Reconstruct the full rowwise data gathered_data = torch.cat(gathered_shards, dim=0).view(reference_data.shape) @@ -1181,6 +1192,7 @@ def test_nvfp4_partial_cast_matches_full() -> None: msg=f"[Rank {WORLD_RANK}] Amax does not match reference!", ) + def test_single_gpu_partial_cast_vs_full(): """ Single GPU test: compare cast_master_weights_to_nvfp4 (offset=0) vs quantizer(). @@ -1191,23 +1203,23 @@ def test_single_gpu_partial_cast_vs_full(): from transformer_engine.pytorch.tensor import NVFP4Quantizer from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4 import transformer_engine_torch as tex - + torch.manual_seed(1234) device = torch.device("cuda") - + # Test with same shape as the optimizer test shape = (2048, 2048) - + # Create BF16 master weight master_weight = torch.randn(shape, dtype=torch.bfloat16, device=device) - + # === Reference: Use NVFP4Quantizer directly === quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) ref = quantizer(master_weight) ref_data = ref._rowwise_data.clone() ref_scale = ref._rowwise_scale_inv.clone() ref_amax = ref._amax_rowwise.clone() - + # === Test: Use cast_master_weights_to_nvfp4 with offset=0 (full tensor) === # Create empty NVFP4 tensor test_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) @@ -1215,30 +1227,31 @@ def test_single_gpu_partial_cast_vs_full(): test_tensor._rowwise_scale_inv.zero_() if test_tensor._amax_rowwise is not None: test_tensor._amax_rowwise.zero_() - + # Create a mock distributed group for single GPU if not dist.is_initialized(): dist.init_process_group(backend="nccl", init_method="env://", rank=0, world_size=1) mock_group = dist.new_group(ranks=[0]) - + cast_master_weights_to_nvfp4( [test_tensor], [master_weight.view(-1)], # Flatten as expected [0], # offset=0 means full tensor mock_group, ) - + # Compare amax amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax) - + # Compare scale scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale) - + # Compare data data_match = torch.equal(test_tensor._rowwise_data, ref_data) + if __name__ == "__main__": main() - #test_nvfp4_transpose_kernel() - #test_single_gpu_partial_cast_vs_full() - 
#test_nvfp4_partial_cast_matches_full() + # test_nvfp4_transpose_kernel() + # test_single_gpu_partial_cast_vs_full() + # test_nvfp4_partial_cast_matches_full() diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 5a89109d5a..8d0abe5599 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -327,8 +327,7 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r */ void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, size_t amax_stride_h, size_t amax_stride_w, - size_t start_offset, size_t block_len, - cudaStream_t stream); + size_t start_offset, size_t block_len, cudaStream_t stream); /*! \brief Cast a partial shard of a tensor to NVFP4 using 2D tile-based quantization. * @@ -376,8 +375,8 @@ void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_ * \param[in] K_tiles Number of tiles in K dimension. * \param[in] stream CUDA stream. */ -void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, - size_t M_tiles, size_t K_tiles, cudaStream_t stream); +void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles, + size_t K_tiles, cudaStream_t stream); /*! \brief Expand tile-level scales to row-level scales and convert to FP8 E4M3, used in partial cast. * @@ -391,14 +390,13 @@ void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, * \param[in] block_len Block length (typically 16 for NVFP4). * \param[in] stream CUDA stream. */ -void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, - size_t tile_rows, size_t tile_cols, - size_t rows_padded, size_t block_len, +void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, cudaStream_t stream); /*! \brief Compute per-block decode scale from block amax and global amax. * - * Computes: + * Computes: * global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax * per_block_decode_scale = block_amax / fp4_max * global_scale * @@ -434,10 +432,8 @@ void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor */ void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, NVTETensor per_block_scale, NVTETensor target_scale, - NVTETensor target_amax, - size_t tile_rows, size_t tile_cols, - size_t rows_padded, size_t block_len, - cudaStream_t stream); + NVTETensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream); /*! \brief Compute global encode scale from global amax. 
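The scale arithmetic documented in the header comments above is easy to sanity-check host-side. The sketch below is an editorial reference in plain PyTorch, not part of the patch and not the kernel: the function name and the use of torch are illustrative assumptions, and the constants simply restate the documented formula global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax with per_block_decode_scale = block_amax / fp4_max * global_scale.

import torch

FP4_MAX = 6.0    # max magnitude representable in NVFP4 (E2M1)
FP8_MAX = 448.0  # max magnitude representable in FP8 E4M3

def reference_per_block_decode_scale(block_amax: torch.Tensor,
                                     global_amax: torch.Tensor) -> torch.Tensor:
    # global_scale = (FP8_MAX * FP4_MAX) / global_amax = 2688 / global_amax,
    # guarded against zero amax and float overflow.
    tiny = torch.finfo(torch.float32).tiny
    flt_max = torch.finfo(torch.float32).max
    safe_amax = torch.clamp(global_amax, min=tiny)
    global_scale = torch.where(
        global_amax > 0,
        torch.clamp((FP8_MAX * FP4_MAX) / safe_amax, max=flt_max),
        torch.ones_like(global_amax),
    )
    # per_block_decode_scale = block_amax / FP4_MAX * global_scale
    return torch.clamp(block_amax / FP4_MAX * global_scale, max=flt_max)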
* diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 2fb0310301..37d7830147 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -10,8 +10,8 @@ #include #include "../common.h" -#include "../utils.cuh" #include "../util/ptx.cuh" +#include "../utils.cuh" namespace transformer_engine { namespace nvfp4_recipe { @@ -149,13 +149,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } const float global_decode_scale = 1.0f / global_encode_scale; - float tile_decode_scale = - decode_scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + float tile_decode_scale = decode_scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; tile_decode_scale = static_cast(static_cast(tile_decode_scale)); constexpr float kFp32Max = 3.402823466e+38F; - float tile_encode_val = (tile_decode_scale > 0.f) - ? 1.0f / (tile_decode_scale * global_decode_scale) - : kFp32Max; + float tile_encode_val = + (tile_decode_scale > 0.f) ? 1.0f / (tile_decode_scale * global_decode_scale) : kFp32Max; tile_encode_val = fminf(tile_encode_val, kFp32Max); const float2 scale_vec = make_float2(tile_encode_val, tile_encode_val); @@ -212,8 +210,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const bool in_shard = in_width && idx >= start_offset && idx < shard_end; mask[j] = in_shard; const bool in_tile = (col_in_smem + j) < kTileDim; - const float tile_val = - in_tile ? smem[row_in_smem][col_in_smem + j] : 0.0f; + const float tile_val = in_tile ? smem[row_in_smem][col_in_smem + j] : 0.0f; vals[j] = in_shard ? tile_val : 0.0f; any_valid |= in_shard; } @@ -239,8 +236,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) uint8_t byte = output[byte_idx]; if (mask[first]) { - const uint8_t nibble = - static_cast((packed_bits >> (4 * first)) & 0xF); + const uint8_t nibble = static_cast((packed_bits >> (4 * first)) & 0xF); if ((elem_idx[first] & 1u) == 0) { byte = static_cast((byte & 0xF0u) | nibble); } else { @@ -249,8 +245,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } if (mask[second]) { - const uint8_t nibble = - static_cast((packed_bits >> (4 * second)) & 0xF); + const uint8_t nibble = static_cast((packed_bits >> (4 * second)) & 0xF); if ((elem_idx[second] & 1u) == 0) { byte = static_cast((byte & 0xF0u) | nibble); } else { @@ -264,8 +259,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w, - size_t amax_stride_h, size_t amax_stride_w, - size_t start_offset, size_t block_len, cudaStream_t stream) { + size_t amax_stride_h, size_t amax_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream) { NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); size_t len = inp.numel(); @@ -282,11 +277,10 @@ void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( inp.dtype(), inp_dtype, - nvfp4_2d_compute_partial_amax_kernel - <<>>(reinterpret_cast(inp.data.dptr), - reinterpret_cast(amax.data.dptr), - amax_stride_h, amax_stride_w, h, w, start_offset, - len);) + nvfp4_2d_compute_partial_amax_kernel<<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(amax.data.dptr), amax_stride_h, amax_stride_w, h, w, + start_offset, len);) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -314,12 +308,11 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, TRANSFORMER_ENGINE_SWITCH_CONDITION( w % 
kTileDim == 0, kWidthAligned, nvfp4_2d_partial_cast_kernel - <<>>( - reinterpret_cast(inp.data.dptr), - reinterpret_cast(out.data.dptr), - reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, - reinterpret_cast(global_scale.data.dptr), h, w, start_offset, - len);)) + <<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), + reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, + reinterpret_cast(global_scale.data.dptr), h, w, start_offset, len);)) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -336,9 +329,9 @@ void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, */ // Vectorized transpose kernel parameters -constexpr int TRANSPOSE_TILE_DIM = 64; // Logical FP4 elements per tile dimension -constexpr int TRANSPOSE_TILE_PACKED = 32; // TILE_DIM / 2 bytes -constexpr int TRANSPOSE_BLOCK_SIZE = 256; // threads per block +constexpr int TRANSPOSE_TILE_DIM = 64; // Logical FP4 elements per tile dimension +constexpr int TRANSPOSE_TILE_PACKED = 32; // TILE_DIM / 2 bytes +constexpr int TRANSPOSE_BLOCK_SIZE = 256; // threads per block // Shared memory: store unpacked 4-bit values as bytes for easy transpose // Size: TILE_DIM x (TILE_DIM + 4) to avoid bank conflicts @@ -349,9 +342,8 @@ constexpr int TRANSPOSE_SHMEM_STRIDE = TRANSPOSE_TILE_DIM + 4; * Tile: 64x64 logical FP4 = 64x32 packed bytes */ __global__ void __launch_bounds__(TRANSPOSE_BLOCK_SIZE) -nvfp4_transpose_kernel(const uint8_t* __restrict__ input, - uint8_t* __restrict__ output, - const size_t M, const size_t K) { + nvfp4_transpose_kernel(const uint8_t *__restrict__ input, uint8_t *__restrict__ output, + const size_t M, const size_t K) { const size_t K_packed = K / 2; const size_t M_packed = M / 2; @@ -361,34 +353,34 @@ nvfp4_transpose_kernel(const uint8_t* __restrict__ input, __shared__ uint8_t shmem[TRANSPOSE_TILE_DIM][TRANSPOSE_SHMEM_STRIDE]; const int tid = threadIdx.x; - + // Phase 1: Load input tile with VECTORIZED uint2 reads // 256 threads, each loads 8 bytes (uint2) = 2048 bytes total // Input tile: [64 rows, 32 cols] = 2048 bytes { - const int thread_row = tid / 4; // 64 rows, 4 threads per row - const int thread_col = (tid % 4) * 8; // 4 x 8 = 32 bytes per row - + const int thread_row = tid / 4; // 64 rows, 4 threads per row + const int thread_col = (tid % 4) * 8; // 4 x 8 = 32 bytes per row + const size_t global_m = tile_m_start + thread_row; const size_t global_k_packed_base = tile_k_start / 2 + thread_col; - + // Load 8 bytes as uint2 uint2 loaded = make_uint2(0, 0); if (global_m < M && global_k_packed_base + 7 < K_packed) { - loaded = *reinterpret_cast(&input[global_m * K_packed + global_k_packed_base]); + loaded = *reinterpret_cast(&input[global_m * K_packed + global_k_packed_base]); } else if (global_m < M) { // Boundary: scalar loads - uint8_t* bytes = reinterpret_cast(&loaded); - #pragma unroll + uint8_t *bytes = reinterpret_cast(&loaded); +#pragma unroll for (int b = 0; b < 8; ++b) { size_t col = global_k_packed_base + b; bytes[b] = (col < K_packed) ? 
input[global_m * K_packed + col] : 0; } } - + // Unpack 8 bytes -> 16 nibbles and store to shared memory - const uint8_t* bytes = reinterpret_cast(&loaded); - #pragma unroll + const uint8_t *bytes = reinterpret_cast(&loaded); +#pragma unroll for (int b = 0; b < 8; ++b) { const int k0 = thread_col * 2 + b * 2; const int k1 = k0 + 1; @@ -402,42 +394,42 @@ nvfp4_transpose_kernel(const uint8_t* __restrict__ input, // Phase 2: Write output with VECTORIZED uint2 stores // Output tile: [64 rows, 32 cols] = 2048 bytes { - const int thread_row = tid / 4; // output K dimension [0, 64) - const int thread_col_base = (tid % 4) * 8; // output M_packed [0, 32) in steps of 8 - + const int thread_row = tid / 4; // output K dimension [0, 64) + const int thread_col_base = (tid % 4) * 8; // output M_packed [0, 32) in steps of 8 + const size_t global_k = tile_k_start + thread_row; const size_t global_m_packed_base = tile_m_start / 2 + thread_col_base; - + if (global_k >= K) return; - + // Build 8 output bytes in registers uint8_t out_bytes[8]; - - #pragma unroll + +#pragma unroll for (int b = 0; b < 8; ++b) { const int out_m_packed = thread_col_base + b; - + if (global_m_packed_base + b >= M_packed) { out_bytes[b] = 0; continue; } - + // Two M positions that pack into this output byte const int m0 = out_m_packed * 2; const int m1 = out_m_packed * 2 + 1; const int k = thread_row; - + // Read from shared memory (transposed access) const uint8_t val0 = shmem[m0][k]; const uint8_t val1 = shmem[m1][k]; - + out_bytes[b] = val0 | (val1 << 4); } - + // Vectorized store as uint2 if (global_m_packed_base + 7 < M_packed) { - *reinterpret_cast(&output[global_k * M_packed + global_m_packed_base]) = - *reinterpret_cast(out_bytes); + *reinterpret_cast(&output[global_k * M_packed + global_m_packed_base]) = + *reinterpret_cast(out_bytes); } else { // Boundary: scalar stores for (int b = 0; b < 8 && global_m_packed_base + b < M_packed; ++b) { @@ -457,7 +449,8 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { // Get dimensions from packed storage // input.shape() = [M, K/2], so M = shape[0], K = shape[1] * 2 const auto in_shape = input.shape(); - NVTE_CHECK(in_shape.size() == 2, "NVFP4 transpose expects 2D input (packed), got ", in_shape.size(), "D."); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 transpose expects 2D input (packed), got ", + in_shape.size(), "D."); const size_t M = in_shape[0]; const size_t K_packed = in_shape[1]; const size_t K = K_packed * 2; @@ -469,8 +462,8 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { const auto out_shape = output.shape(); NVTE_CHECK(out_shape.size() == 2, "NVFP4 transpose expects 2D output."); NVTE_CHECK(out_shape[0] == K && out_shape[1] == M_packed, - "NVFP4 transpose output shape mismatch. Expected [", K, ", ", M_packed, - "], got [", out_shape[0], ", ", out_shape[1], "]."); + "NVFP4 transpose output shape mismatch. 
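As a bitwise cross-check of the packed-nibble layout used by the transpose kernel above (low nibble holds the even logical column, high nibble the odd one, two FP4 values per byte), the same operation can be written as a short host-side PyTorch reference. This is an editorial sketch under those packing assumptions; the function name is made up for illustration and it is not the CUDA kernel.

import torch

def packed_fp4_transpose_reference(packed: torch.Tensor, M: int, K: int) -> torch.Tensor:
    # packed: uint8 tensor of shape [M, K // 2]; returns uint8 [K, M // 2].
    assert packed.shape == (M, K // 2) and packed.dtype == torch.uint8
    low = packed & 0x0F   # even logical columns
    high = packed >> 4    # odd logical columns
    logical = torch.empty(M, K, dtype=torch.uint8, device=packed.device)
    logical[:, 0::2] = low
    logical[:, 1::2] = high
    logical_t = logical.t().contiguous()                   # [K, M] logical FP4 elements
    return logical_t[:, 0::2] | (logical_t[:, 1::2] << 4)  # repack nibbles -> [K, M // 2]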
Expected [", K, ", ", M_packed, "], got [", + out_shape[0], ", ", out_shape[1], "]."); if (M == 0 || K == 0) return; @@ -479,11 +472,11 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { dim3 block(TRANSPOSE_BLOCK_SIZE); dim3 grid((M + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM, (K + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM); - + nvfp4_transpose_kernel<<>>( reinterpret_cast(input.data.dptr), reinterpret_cast(output.data.dptr), M, K); - + NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -493,7 +486,7 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { * * Transposes tile-level scales from rowwise to columnwise format. * Scale values are stored as E4M3 (fp8) in uint8 tensors. - * + * * Input (rowwise_scale_inv): [M_padded, K_tiles] where scales are stored * at every 16th row (i.e., row 0, 16, 32, ... contain the actual scales, * and each row i within a tile block has the same scale as row (i // 16) * 16). @@ -507,61 +500,59 @@ void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { * --------------------------------------------------------------------------- */ __global__ void nvfp4_scale_transpose_kernel( - const uint8_t* __restrict__ input, // [M_padded, K_tiles], E4M3 stored as uint8 - uint8_t* __restrict__ output, // [K_padded, M_tiles], E4M3 stored as uint8 - const size_t M_tiles, // Number of M tiles - const size_t K_tiles, // Number of K tiles - const size_t input_stride, // K_tiles (input row stride) - const size_t output_stride, // M_tiles (output row stride) - const size_t K_padded // Output height + const uint8_t *__restrict__ input, // [M_padded, K_tiles], E4M3 stored as uint8 + uint8_t *__restrict__ output, // [K_padded, M_tiles], E4M3 stored as uint8 + const size_t M_tiles, // Number of M tiles + const size_t K_tiles, // Number of K tiles + const size_t input_stride, // K_tiles (input row stride) + const size_t output_stride, // M_tiles (output row stride) + const size_t K_padded // Output height ) { - // Each thread handles one output element - const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; - const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; - - if (out_row >= K_padded || out_col >= M_tiles) return; - - // Determine which tile row this belongs to - const size_t k_tile = out_row / kTileDim; - - // Read from input: row = m_tile * 16 (first row of the tile), col = k_tile - // m_tile = out_col - if (k_tile < K_tiles) { - const size_t in_row = out_col * kTileDim; // m_tile * 16 - const uint8_t scale = input[in_row * input_stride + k_tile]; - output[out_row * output_stride + out_col] = scale; - } else { - output[out_row * output_stride + out_col] = 0; - } + // Each thread handles one output element + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= K_padded || out_col >= M_tiles) return; + + // Determine which tile row this belongs to + const size_t k_tile = out_row / kTileDim; + + // Read from input: row = m_tile * 16 (first row of the tile), col = k_tile + // m_tile = out_col + if (k_tile < K_tiles) { + const size_t in_row = out_col * kTileDim; // m_tile * 16 + const uint8_t scale = input[in_row * input_stride + k_tile]; + output[out_row * output_stride + out_col] = scale; + } else { + output[out_row * output_stride + out_col] = 0; + } } -void nvfp4_scale_transpose(const Tensor input, Tensor output, - size_t M_tiles, size_t K_tiles, +void nvfp4_scale_transpose(const Tensor 
input, Tensor output, size_t M_tiles, size_t K_tiles, cudaStream_t stream) { - NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); - NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 scale transpose output must be uint8 (E4M3)."); - - const auto in_shape = input.shape(); - const auto out_shape = output.shape(); - NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); - NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); - - const size_t input_stride = in_shape[1]; // K_tiles - const size_t output_stride = out_shape[1]; // M_tiles - const size_t K_padded = out_shape[0]; - - if (M_tiles == 0 || K_tiles == 0 || K_padded == 0) return; - - constexpr int kBlockDim = 16; - dim3 block(kBlockDim, kBlockDim); - dim3 grid((M_tiles + kBlockDim - 1) / kBlockDim, - (K_padded + kBlockDim - 1) / kBlockDim); - - nvfp4_scale_transpose_kernel<<>>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output.data.dptr), - M_tiles, K_tiles, input_stride, output_stride, K_padded); - NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); + NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 scale transpose output must be uint8 (E4M3)."); + + const auto in_shape = input.shape(); + const auto out_shape = output.shape(); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); + + const size_t input_stride = in_shape[1]; // K_tiles + const size_t output_stride = out_shape[1]; // M_tiles + const size_t K_padded = out_shape[0]; + + if (M_tiles == 0 || K_tiles == 0 || K_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((M_tiles + kBlockDim - 1) / kBlockDim, (K_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_scale_transpose_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), M_tiles, K_tiles, input_stride, output_stride, + K_padded); + NVTE_CHECK_CUDA(cudaGetLastError()); } /* @@ -569,7 +560,7 @@ void nvfp4_scale_transpose(const Tensor input, Tensor output, * NVFP4 SCALE EXPANSION KERNEL * * Expands tile-level scales to row-level scales and converts to FP8 E4M3, used in partial cast. 
- * + * * Input (per_block_decode_scale): [tile_rows, tile_cols] in float32 * Output (target_scale): [rows_padded, tile_cols] in uint8 (E4M3) * @@ -577,51 +568,45 @@ void nvfp4_scale_transpose(const Tensor input, Tensor output, * --------------------------------------------------------------------------- */ __global__ void nvfp4_expand_scale_to_fp8_kernel( - const float* __restrict__ input, // [tile_rows, tile_cols] - uint8_t* __restrict__ output, // [rows_padded, tile_cols] - const size_t tile_rows, - const size_t tile_cols, - const size_t rows_padded, - const size_t block_len -) { - const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; - const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; - - if (out_row >= rows_padded || out_col >= tile_cols) return; - - // Determine which tile row this output row belongs to - const size_t tile_row = out_row / block_len; - - float scale_val = 0.0f; - if (tile_row < tile_rows) { - scale_val = input[tile_row * tile_cols + out_col]; - } - - // Convert float32 to FP8 E4M3 - // Clamp to FP8 E4M3 range and convert - fp8e4m3 fp8_val = static_cast(scale_val); - output[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); + const float *__restrict__ input, // [tile_rows, tile_cols] + uint8_t *__restrict__ output, // [rows_padded, tile_cols] + const size_t tile_rows, const size_t tile_cols, const size_t rows_padded, + const size_t block_len) { + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= rows_padded || out_col >= tile_cols) return; + + // Determine which tile row this output row belongs to + const size_t tile_row = out_row / block_len; + + float scale_val = 0.0f; + if (tile_row < tile_rows) { + scale_val = input[tile_row * tile_cols + out_col]; + } + + // Convert float32 to FP8 E4M3 + // Clamp to FP8 E4M3 range and convert + fp8e4m3 fp8_val = static_cast(scale_val); + output[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); } -void nvfp4_expand_scale_to_fp8(const Tensor input, Tensor output, - size_t tile_rows, size_t tile_cols, - size_t rows_padded, size_t block_len, +void nvfp4_expand_scale_to_fp8(const Tensor input, Tensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, cudaStream_t stream) { - NVTE_CHECK(input.dtype() == DType::kFloat32, "Scale input must be float32."); - NVTE_CHECK(output.dtype() == DType::kByte, "Scale output must be uint8 (E4M3)."); - - if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; - - constexpr int kBlockDim = 16; - dim3 block(kBlockDim, kBlockDim); - dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, - (rows_padded + kBlockDim - 1) / kBlockDim); - - nvfp4_expand_scale_to_fp8_kernel<<>>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output.data.dptr), - tile_rows, tile_cols, rows_padded, block_len); - NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(input.dtype() == DType::kFloat32, "Scale input must be float32."); + NVTE_CHECK(output.dtype() == DType::kByte, "Scale output must be uint8 (E4M3)."); + + if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, (rows_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_expand_scale_to_fp8_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), tile_rows, tile_cols, rows_padded, block_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } /* @@ 
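The shape bookkeeping above ([tile_rows, tile_cols] float32 in, [rows_padded, tile_cols] E4M3-stored-as-uint8 out, each tile row broadcast over block_len output rows) can be restated as a short PyTorch reference. This is an illustrative sketch only, assuming a PyTorch build that provides torch.float8_e4m3fn; the function name is hypothetical and the expansion kernel itself is the code above.

import torch

def expand_tile_scales_reference(per_block_scale: torch.Tensor,
                                 rows_padded: int, block_len: int = 16) -> torch.Tensor:
    # per_block_scale: [tile_rows, tile_cols] float32
    # returns: [rows_padded, tile_cols] uint8 holding E4M3 bit patterns
    tile_rows, tile_cols = per_block_scale.shape
    expanded = per_block_scale.repeat_interleave(block_len, dim=0)  # one tile row -> block_len rows
    if expanded.shape[0] < rows_padded:
        # rows past the last tile get a zero scale, matching the kernel's out-of-range branch
        pad = torch.zeros(rows_padded - expanded.shape[0], tile_cols,
                          dtype=per_block_scale.dtype, device=per_block_scale.device)
        expanded = torch.cat([expanded, pad], dim=0)
    else:
        expanded = expanded[:rows_padded]
    return expanded.to(torch.float8_e4m3fn).view(torch.uint8)  # reinterpret E4M3 bits as bytes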
-642,90 +627,86 @@ void nvfp4_expand_scale_to_fp8(const Tensor input, Tensor output, * --------------------------------------------------------------------------- */ __global__ void nvfp4_compute_per_block_scale_kernel( - const float* __restrict__ block_amax, // [tile_rows, tile_cols] - float* __restrict__ scale, // [tile_rows, tile_cols] - const float* __restrict__ global_amax_ptr, // Pointer to single float value (avoids D2H) - const size_t numel -) { - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= numel) return; - - constexpr float fp4_max = 6.0f; - constexpr float fp8_max = 448.0f; - constexpr float flt_max = 3.402823466e+38f; - constexpr float tiny = 1.17549435e-38f; // FLT_MIN - - // Read global_amax from device memory (avoids D2H transfer) - float global_amax = *global_amax_ptr; - - // Compute global encode scale: S_enc = (fp8_max * fp4_max) / global_amax - float safe_global_amax = fmaxf(global_amax, tiny); - float global_scale = (global_amax > 0.0f) ? - fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; - - // Compute per-block decode scale: S_dec_b = block_amax / fp4_max * S_enc - float amax_val = block_amax[idx]; - float result = fminf((amax_val / fp4_max) * global_scale, flt_max); - scale[idx] = result; + const float *__restrict__ block_amax, // [tile_rows, tile_cols] + float *__restrict__ scale, // [tile_rows, tile_cols] + const float *__restrict__ global_amax_ptr, // Pointer to single float value (avoids D2H) + const size_t numel) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= numel) return; + + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + // Read global_amax from device memory (avoids D2H transfer) + float global_amax = *global_amax_ptr; + + // Compute global encode scale: S_enc = (fp8_max * fp4_max) / global_amax + float safe_global_amax = fmaxf(global_amax, tiny); + float global_scale = + (global_amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; + + // Compute per-block decode scale: S_dec_b = block_amax / fp4_max * S_enc + float amax_val = block_amax[idx]; + float result = fminf((amax_val / fp4_max) * global_scale, flt_max); + scale[idx] = result; } // Simple kernel to compute global encode scale from global amax __global__ void nvfp4_compute_global_scale_kernel( - const float* __restrict__ global_amax, // [num_params] - float* __restrict__ global_scale, // [num_params] - const size_t num_params -) { - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_params) return; - - constexpr float fp4_max = 6.0f; - constexpr float fp8_max = 448.0f; - constexpr float flt_max = 3.402823466e+38f; - constexpr float tiny = 1.17549435e-38f; // FLT_MIN - - float amax = global_amax[idx]; - float safe_amax = fmaxf(amax, tiny); - float scale = (amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_amax, flt_max) : 1.0f; - global_scale[idx] = scale; + const float *__restrict__ global_amax, // [num_params] + float *__restrict__ global_scale, // [num_params] + const size_t num_params) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_params) return; + + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + float amax = global_amax[idx]; + float safe_amax = fmaxf(amax, tiny); + float scale = (amax > 0.0f) ? 
fminf((fp8_max * fp4_max) / safe_amax, flt_max) : 1.0f; + global_scale[idx] = scale; } -void nvfp4_compute_per_block_scale(const Tensor block_amax, Tensor scale, - const Tensor global_amax, cudaStream_t stream) { - NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); - NVTE_CHECK(scale.dtype() == DType::kFloat32, "Scale must be float32."); - NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); - NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); - - size_t numel = block_amax.numel(); - if (numel == 0) return; - - constexpr int kBlockSize = 256; - int grid_size = (numel + kBlockSize - 1) / kBlockSize; - - nvfp4_compute_per_block_scale_kernel<<>>( - reinterpret_cast(block_amax.data.dptr), - reinterpret_cast(scale.data.dptr), - reinterpret_cast(global_amax.data.dptr), - numel); - NVTE_CHECK_CUDA(cudaGetLastError()); +void nvfp4_compute_per_block_scale(const Tensor block_amax, Tensor scale, const Tensor global_amax, + cudaStream_t stream) { + NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); + NVTE_CHECK(scale.dtype() == DType::kFloat32, "Scale must be float32."); + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + + size_t numel = block_amax.numel(); + if (numel == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (numel + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_per_block_scale_kernel<<>>( + reinterpret_cast(block_amax.data.dptr), + reinterpret_cast(scale.data.dptr), + reinterpret_cast(global_amax.data.dptr), numel); + NVTE_CHECK_CUDA(cudaGetLastError()); } void nvfp4_compute_global_scale(const Tensor global_amax, Tensor global_scale, cudaStream_t stream) { - NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); - NVTE_CHECK(global_scale.dtype() == DType::kFloat32, "Global scale must be float32."); - - size_t num_params = global_amax.numel(); - if (num_params == 0) return; - - constexpr int kBlockSize = 256; - int grid_size = (num_params + kBlockSize - 1) / kBlockSize; - - nvfp4_compute_global_scale_kernel<<>>( - reinterpret_cast(global_amax.data.dptr), - reinterpret_cast(global_scale.data.dptr), - num_params); - NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_scale.dtype() == DType::kFloat32, "Global scale must be float32."); + + size_t num_params = global_amax.numel(); + if (num_params == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (num_params + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_global_scale_kernel<<>>( + reinterpret_cast(global_amax.data.dptr), + reinterpret_cast(global_scale.data.dptr), num_params); + NVTE_CHECK_CUDA(cudaGetLastError()); } /* @@ -747,65 +728,60 @@ void nvfp4_compute_global_scale(const Tensor global_amax, Tensor global_scale, * nvfp4_expand_scale_to_fp8 as separate calls, plus the amax copy). 
* --------------------------------------------------------------------------- */ - __global__ void nvfp4_fused_scale_kernel( - const float* __restrict__ block_amax, // [tile_rows, tile_cols] - const float* __restrict__ global_amax, // [1] - float* __restrict__ per_block_scale, // [tile_rows, tile_cols] - for partial_cast - uint8_t* __restrict__ target_scale, // [rows_padded, tile_cols] - float* __restrict__ target_amax, // [1] - const size_t tile_rows, - const size_t tile_cols, - const size_t rows_padded, - const size_t block_len -) { +__global__ void nvfp4_fused_scale_kernel( + const float *__restrict__ block_amax, // [tile_rows, tile_cols] + const float *__restrict__ global_amax, // [1] + float *__restrict__ per_block_scale, // [tile_rows, tile_cols] - for partial_cast + uint8_t *__restrict__ target_scale, // [rows_padded, tile_cols] + float *__restrict__ target_amax, // [1] + const size_t tile_rows, const size_t tile_cols, const size_t rows_padded, + const size_t block_len) { const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; - + // Read global amax once per thread (broadcast) const float g_amax = *global_amax; - + // Thread (0,0) copies global_amax to target_amax if (out_row == 0 && out_col == 0) { - *target_amax = g_amax; + *target_amax = g_amax; } - + if (out_row >= rows_padded || out_col >= tile_cols) return; - + // Determine which tile row this output row belongs to const size_t tile_row = out_row / block_len; - + // Compute the scale value constexpr float fp4_max = 6.0f; constexpr float fp8_max = 448.0f; constexpr float flt_max = 3.402823466e+38f; constexpr float tiny = 1.17549435e-38f; - + float scale_val = 0.0f; if (tile_row < tile_rows) { - float safe_global_amax = fmaxf(g_amax, tiny); - float global_scale = (g_amax > 0.0f) ? - fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; - - // Read block amax and compute per-block decode scale - float amax_val = block_amax[tile_row * tile_cols + out_col]; - scale_val = fminf((amax_val / fp4_max) * global_scale, flt_max); - - // Write per-block scale (only once per tile, when out_row % block_len == 0) - if (out_row % block_len == 0) { - per_block_scale[tile_row * tile_cols + out_col] = scale_val; - } + float safe_global_amax = fmaxf(g_amax, tiny); + float global_scale = + (g_amax > 0.0f) ? 
fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; + + // Read block amax and compute per-block decode scale + float amax_val = block_amax[tile_row * tile_cols + out_col]; + scale_val = fminf((amax_val / fp4_max) * global_scale, flt_max); + + // Write per-block scale (only once per tile, when out_row % block_len == 0) + if (out_row % block_len == 0) { + per_block_scale[tile_row * tile_cols + out_col] = scale_val; + } } - + // Convert float32 to FP8 E4M3 and write expanded scale fp8e4m3 fp8_val = static_cast(scale_val); - target_scale[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); + target_scale[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); } -void nvfp4_fused_scale(const Tensor block_amax, const Tensor global_amax, - Tensor per_block_scale, Tensor target_scale, Tensor target_amax, - size_t tile_rows, size_t tile_cols, - size_t rows_padded, size_t block_len, - cudaStream_t stream) { +void nvfp4_fused_scale(const Tensor block_amax, const Tensor global_amax, Tensor per_block_scale, + Tensor target_scale, Tensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream) { NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); NVTE_CHECK(per_block_scale.dtype() == DType::kFloat32, "Per-block scale must be float32."); @@ -813,63 +789,58 @@ void nvfp4_fused_scale(const Tensor block_amax, const Tensor global_amax, NVTE_CHECK(target_amax.dtype() == DType::kFloat32, "Target amax must be float32."); NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); - + if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; - + constexpr int kBlockDim = 16; dim3 block(kBlockDim, kBlockDim); - dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, - (rows_padded + kBlockDim - 1) / kBlockDim); - + dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, (rows_padded + kBlockDim - 1) / kBlockDim); + nvfp4_fused_scale_kernel<<>>( - reinterpret_cast(block_amax.data.dptr), - reinterpret_cast(global_amax.data.dptr), - reinterpret_cast(per_block_scale.data.dptr), - reinterpret_cast(target_scale.data.dptr), - reinterpret_cast(target_amax.data.dptr), - tile_rows, tile_cols, rows_padded, block_len); + reinterpret_cast(block_amax.data.dptr), + reinterpret_cast(global_amax.data.dptr), + reinterpret_cast(per_block_scale.data.dptr), + reinterpret_cast(target_scale.data.dptr), + reinterpret_cast(target_amax.data.dptr), tile_rows, tile_cols, rows_padded, + block_len); NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace nvfp4_recipe } // namespace transformer_engine -void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, - size_t tile_rows, size_t tile_cols, - size_t rows_padded, size_t block_len, +void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, cudaStream_t stream) { - NVTE_API_CALL(nvte_nvfp4_expand_scale_to_fp8); - using namespace transformer_engine; - nvfp4_recipe::nvfp4_expand_scale_to_fp8(*convertNVTETensorCheck(input), - *convertNVTETensorCheck(output), - tile_rows, tile_cols, rows_padded, block_len, stream); + NVTE_API_CALL(nvte_nvfp4_expand_scale_to_fp8); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_expand_scale_to_fp8(*convertNVTETensorCheck(input), + 
*convertNVTETensorCheck(output), tile_rows, tile_cols, + rows_padded, block_len, stream); } void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, const NVTETensor global_amax, cudaStream_t stream) { - NVTE_API_CALL(nvte_nvfp4_compute_per_block_scale); - using namespace transformer_engine; - nvfp4_recipe::nvfp4_compute_per_block_scale(*convertNVTETensorCheck(block_amax), - *convertNVTETensorCheck(scale), - *convertNVTETensorCheck(global_amax), - stream); + NVTE_API_CALL(nvte_nvfp4_compute_per_block_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_compute_per_block_scale(*convertNVTETensorCheck(block_amax), + *convertNVTETensorCheck(scale), + *convertNVTETensorCheck(global_amax), stream); } void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale, cudaStream_t stream) { - NVTE_API_CALL(nvte_nvfp4_compute_global_scale); - using namespace transformer_engine; - nvfp4_recipe::nvfp4_compute_global_scale(*convertNVTETensorCheck(global_amax), - *convertNVTETensorCheck(global_scale), - stream); + NVTE_API_CALL(nvte_nvfp4_compute_global_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_compute_global_scale(*convertNVTETensorCheck(global_amax), + *convertNVTETensorCheck(global_scale), stream); } -void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, - size_t M_tiles, size_t K_tiles, cudaStream_t stream) { - NVTE_API_CALL(nvte_nvfp4_scale_transpose); - using namespace transformer_engine; - nvfp4_recipe::nvfp4_scale_transpose(*convertNVTETensorCheck(input), - *convertNVTETensorCheck(output), - M_tiles, K_tiles, stream); +void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles, + size_t K_tiles, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_scale_transpose); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_scale_transpose(*convertNVTETensorCheck(input), + *convertNVTETensorCheck(output), M_tiles, K_tiles, stream); } void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -879,15 +850,15 @@ void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_ stream); } -void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, - size_t w, size_t amax_stride_h, - size_t amax_stride_w, size_t start_offset, - size_t block_len, cudaStream_t stream) { +void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, + cudaStream_t stream) { NVTE_API_CALL(nvte_nvfp4_2d_compute_partial_amax); using namespace transformer_engine; - nvfp4_recipe::nvfp4_2d_compute_partial_amax( - *convertNVTETensorCheck(inp), *convertNVTETensorCheck(amax), h, w, amax_stride_h, - amax_stride_w, start_offset, block_len, stream); + nvfp4_recipe::nvfp4_2d_compute_partial_amax(*convertNVTETensorCheck(inp), + *convertNVTETensorCheck(amax), h, w, amax_stride_h, + amax_stride_w, start_offset, block_len, stream); } void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, @@ -896,10 +867,10 @@ void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTE size_t block_len, cudaStream_t stream) { NVTE_API_CALL(nvte_nvfp4_2d_partial_cast); using namespace transformer_engine; - nvfp4_recipe::nvfp4_2d_partial_cast( - *convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), - 
*convertNVTETensorCheck(global_scale), h, w, scale_stride_h, scale_stride_w, start_offset, - block_len, stream); + nvfp4_recipe::nvfp4_2d_partial_cast(*convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), + *convertNVTETensorCheck(scale), + *convertNVTETensorCheck(global_scale), h, w, scale_stride_h, + scale_stride_w, start_offset, block_len, stream); } void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, @@ -930,17 +901,12 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, NVTETensor per_block_scale, NVTETensor target_scale, - NVTETensor target_amax, - size_t tile_rows, size_t tile_cols, - size_t rows_padded, size_t block_len, - cudaStream_t stream) { - NVTE_API_CALL(nvte_nvfp4_fused_scale); - using namespace transformer_engine; - nvfp4_recipe::nvfp4_fused_scale(*convertNVTETensorCheck(block_amax), - *convertNVTETensorCheck(global_amax), - *convertNVTETensorCheck(per_block_scale), - *convertNVTETensorCheck(target_scale), - *convertNVTETensorCheck(target_amax), - tile_rows, tile_cols, rows_padded, block_len, - stream); + NVTETensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_fused_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_fused_scale( + *convertNVTETensorCheck(block_amax), *convertNVTETensorCheck(global_amax), + *convertNVTETensorCheck(per_block_scale), *convertNVTETensorCheck(target_scale), + *convertNVTETensorCheck(target_amax), tile_rows, tile_cols, rows_padded, block_len, stream); } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 488b3984d3..29674ea9ef 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -161,45 +161,31 @@ at::Tensor nvfp4_transpose(at::Tensor input, std::optional output = void nvfp4_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, int64_t K_tiles); -void nvfp4_multi_tensor_create_columnwise( - std::vector rowwise_data_list, - std::vector columnwise_data_list, - std::vector rowwise_scale_inv_list, - std::vector columnwise_scale_inv_list, - std::vector M_list, - std::vector K_list); +void nvfp4_multi_tensor_create_columnwise(std::vector rowwise_data_list, + std::vector columnwise_data_list, + std::vector rowwise_scale_inv_list, + std::vector columnwise_scale_inv_list, + std::vector M_list, std::vector K_list); void nvfp4_multi_tensor_compute_partial_amax( - std::vector master_weight_list, - std::vector partial_amax_list, - std::vector global_amax_list, - std::vector h_list, - std::vector w_list, - std::vector start_offset_list, - int64_t block_len); - -void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, - int64_t tile_rows, int64_t tile_cols, - int64_t rows_padded, int64_t block_len); + std::vector master_weight_list, std::vector partial_amax_list, + std::vector global_amax_list, std::vector h_list, + std::vector w_list, std::vector start_offset_list, int64_t block_len); + +void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len); void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::Tensor global_amax); -void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, - at::Tensor per_block_scale, at::Tensor 
target_scale, - at::Tensor target_amax, - int64_t tile_rows, int64_t tile_cols, - int64_t rows_padded, int64_t block_len); +void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale, + at::Tensor target_scale, at::Tensor target_amax, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len); void nvfp4_multi_tensor_fused_scale( - std::vector block_amax_list, - std::vector global_amax_list, - std::vector per_block_scale_list, - std::vector target_scale_list, - std::vector target_amax_list, - std::vector tile_rows_list, - std::vector tile_cols_list, - std::vector rows_padded_list, - int64_t block_len); + std::vector block_amax_list, std::vector global_amax_list, + std::vector per_block_scale_list, std::vector target_scale_list, + std::vector target_amax_list, std::vector tile_rows_list, + std::vector tile_cols_list, std::vector rows_padded_list, int64_t block_len); void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale); @@ -395,15 +381,12 @@ void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tens const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, size_t block_len); -void nvfp4_multi_tensor_2d_partial_cast( - std::vector inp_list, - std::vector out_list, - std::vector scale_list, - std::vector global_scale_list, - std::vector h_list, - std::vector w_list, - std::vector start_offset_list, - int64_t block_len); +void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, std::vector w_list, + std::vector start_offset_list, int64_t block_len); void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, at::Tensor amax_colwise, int rows, int cols, size_t start_offset); diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp index 2c73c577e2..6e1eb4483c 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp @@ -9,7 +9,7 @@ namespace transformer_engine::pytorch { -void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, size_t w, +void nvfp4_2d_compute_partial_amax(const at::Tensor& tensor, at::Tensor amax, size_t h, size_t w, size_t start_offset, size_t block_len) { TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); @@ -21,13 +21,13 @@ void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, si const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor.contiguous()); TensorWrapper amax_cu = makeTransformerEngineTensor(amax); - nvte_nvfp4_2d_compute_partial_amax( - tensor_cu.data(), amax_cu.data(), h, w, amax.stride(0), amax.stride(1), start_offset, - block_len, at::cuda::getCurrentCUDAStream()); + nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, amax.stride(0), + amax.stride(1), start_offset, block_len, + at::cuda::getCurrentCUDAStream()); } -void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, - const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, +void nvfp4_2d_partial_cast(const at::Tensor& inp, py::handle out, const at::Tensor& scale, + const at::Tensor& global_scale, size_t h, size_t w, size_t 
start_offset, size_t block_len) { TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); @@ -35,30 +35,26 @@ void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tens TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, "global_scale must be a float tensor"); - TORCH_CHECK(inp.scalar_type() == at::ScalarType::Float || - inp.scalar_type() == at::ScalarType::BFloat16, - "input must be a float or bfloat16 tensor"); + TORCH_CHECK( + inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); const TensorWrapper out_cu = makeTransformerEngineTensor(out, py::none()); const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); - nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), - global_scale_cu.data(), h, w, scale.stride(0), scale.stride(1), - start_offset, block_len, + nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), global_scale_cu.data(), + h, w, scale.stride(0), scale.stride(1), start_offset, block_len, at::cuda::getCurrentCUDAStream()); } -void nvfp4_multi_tensor_2d_partial_cast( - std::vector inp_list, - std::vector out_list, - std::vector scale_list, - std::vector global_scale_list, - std::vector h_list, - std::vector w_list, - std::vector start_offset_list, - int64_t block_len) { +void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, std::vector w_list, + std::vector start_offset_list, int64_t block_len) { TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); const size_t num_tensors = inp_list.size(); @@ -89,9 +85,9 @@ void nvfp4_multi_tensor_2d_partial_cast( TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, "global_scale must be a float tensor"); - TORCH_CHECK(inp.scalar_type() == at::ScalarType::Float || - inp.scalar_type() == at::ScalarType::BFloat16, - "input must be a float or bfloat16 tensor"); + TORCH_CHECK( + inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); const TensorWrapper out_cu = makeTransformerEngineTensor(out); @@ -105,16 +101,11 @@ void nvfp4_multi_tensor_2d_partial_cast( } void nvfp4_multi_tensor_compute_partial_amax( - std::vector master_weight_list, - std::vector partial_amax_list, - std::vector global_amax_list, - std::vector h_list, - std::vector w_list, - std::vector start_offset_list, - int64_t block_len) { - + std::vector master_weight_list, std::vector partial_amax_list, + std::vector global_amax_list, std::vector h_list, + std::vector w_list, std::vector start_offset_list, int64_t block_len) { TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); - + const size_t num_tensors = master_weight_list.size(); TORCH_CHECK(partial_amax_list.size() == num_tensors, "partial_amax_list size mismatch"); TORCH_CHECK(global_amax_list.size() == 
num_tensors, "global_amax_list size mismatch"); @@ -150,22 +141,17 @@ void nvfp4_multi_tensor_compute_partial_amax( const TensorWrapper tensor_cu = makeTransformerEngineTensor(master_weight.contiguous()); TensorWrapper amax_cu = makeTransformerEngineTensor(partial_amax); - nvte_nvfp4_2d_compute_partial_amax( - tensor_cu.data(), amax_cu.data(), h, w, - partial_amax.stride(0), partial_amax.stride(1), - start_offset, static_cast(block_len), stream); + nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, + partial_amax.stride(0), partial_amax.stride(1), start_offset, + static_cast(block_len), stream); // Compute global amax auto* global_amax_ptr = global_amax.data_ptr(); TensorWrapper fake_te_output( - /*dptr=*/nullptr, tensor_cu.shape(), - DType::kFloat32, - global_amax_ptr); + /*dptr=*/nullptr, tensor_cu.shape(), DType::kFloat32, global_amax_ptr); nvte_compute_amax(tensor_cu.data(), fake_te_output.data(), stream); } } } // namespace transformer_engine::pytorch - - diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a2e7690ef1..f5c6c39234 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -263,44 +263,41 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvfp4_transpose", &transformer_engine::pytorch::nvfp4_transpose, "Transpose NVFP4 packed data with nibble repacking", py::arg("input"), py::kw_only(), py::arg("out"), py::call_guard()); - m.def("nvfp4_scale_transpose", &transformer_engine::pytorch::nvfp4_scale_transpose, - "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format", - py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"), - py::call_guard()); + m.def( + "nvfp4_scale_transpose", &transformer_engine::pytorch::nvfp4_scale_transpose, + "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format", + py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"), + py::call_guard()); m.def("nvfp4_expand_scale_to_fp8", &transformer_engine::pytorch::nvfp4_expand_scale_to_fp8, - "Expand tile-level scales to row-level scales and convert to FP8 E4M3", - py::arg("input"), py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"), - py::arg("rows_padded"), py::arg("block_len"), - py::call_guard()); - m.def("nvfp4_compute_per_block_scale", &transformer_engine::pytorch::nvfp4_compute_per_block_scale, - "Compute per-block decode scale from block amax and global amax", - py::arg("block_amax"), py::arg("scale"), py::arg("global_amax"), - py::call_guard()); + "Expand tile-level scales to row-level scales and convert to FP8 E4M3", py::arg("input"), + py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), + py::arg("block_len"), py::call_guard()); + m.def("nvfp4_compute_per_block_scale", + &transformer_engine::pytorch::nvfp4_compute_per_block_scale, + "Compute per-block decode scale from block amax and global amax", py::arg("block_amax"), + py::arg("scale"), py::arg("global_amax"), py::call_guard()); m.def("nvfp4_compute_global_scale", &transformer_engine::pytorch::nvfp4_compute_global_scale, - "Compute global encode scale from global amax", - py::arg("global_amax"), py::arg("global_scale"), - py::call_guard()); + "Compute global encode scale from global amax", py::arg("global_amax"), + py::arg("global_scale"), py::call_guard()); m.def("nvfp4_fused_scale", 
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index a2e7690ef1..f5c6c39234 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -263,44 +263,41 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("nvfp4_transpose", &transformer_engine::pytorch::nvfp4_transpose,
         "Transpose NVFP4 packed data with nibble repacking", py::arg("input"), py::kw_only(),
         py::arg("out"), py::call_guard<py::gil_scoped_release>());
-  m.def("nvfp4_scale_transpose", &transformer_engine::pytorch::nvfp4_scale_transpose,
-        "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format",
-        py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"),
-        py::call_guard<py::gil_scoped_release>());
+  m.def(
+      "nvfp4_scale_transpose", &transformer_engine::pytorch::nvfp4_scale_transpose,
+      "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format",
+      py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"),
+      py::call_guard<py::gil_scoped_release>());
   m.def("nvfp4_expand_scale_to_fp8", &transformer_engine::pytorch::nvfp4_expand_scale_to_fp8,
-        "Expand tile-level scales to row-level scales and convert to FP8 E4M3",
-        py::arg("input"), py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"),
-        py::arg("rows_padded"), py::arg("block_len"),
-        py::call_guard<py::gil_scoped_release>());
-  m.def("nvfp4_compute_per_block_scale", &transformer_engine::pytorch::nvfp4_compute_per_block_scale,
-        "Compute per-block decode scale from block amax and global amax",
-        py::arg("block_amax"), py::arg("scale"), py::arg("global_amax"),
-        py::call_guard<py::gil_scoped_release>());
+        "Expand tile-level scales to row-level scales and convert to FP8 E4M3", py::arg("input"),
+        py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"),
+        py::arg("block_len"), py::call_guard<py::gil_scoped_release>());
+  m.def("nvfp4_compute_per_block_scale",
+        &transformer_engine::pytorch::nvfp4_compute_per_block_scale,
+        "Compute per-block decode scale from block amax and global amax", py::arg("block_amax"),
+        py::arg("scale"), py::arg("global_amax"), py::call_guard<py::gil_scoped_release>());
   m.def("nvfp4_compute_global_scale", &transformer_engine::pytorch::nvfp4_compute_global_scale,
-        "Compute global encode scale from global amax",
-        py::arg("global_amax"), py::arg("global_scale"),
-        py::call_guard<py::gil_scoped_release>());
+        "Compute global encode scale from global amax", py::arg("global_amax"),
+        py::arg("global_scale"), py::call_guard<py::gil_scoped_release>());
   m.def("nvfp4_fused_scale", &transformer_engine::pytorch::nvfp4_fused_scale,
         "Fused kernel: compute per-block decode scale, copy global amax, expand to row-level FP8",
         py::arg("block_amax"), py::arg("global_amax"), py::arg("per_block_scale"),
-        py::arg("target_scale"), py::arg("target_amax"),
-        py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), py::arg("block_len"),
-        py::call_guard<py::gil_scoped_release>());
+        py::arg("target_scale"), py::arg("target_amax"), py::arg("tile_rows"), py::arg("tile_cols"),
+        py::arg("rows_padded"), py::arg("block_len"), py::call_guard<py::gil_scoped_release>());
   m.def("nvfp4_multi_tensor_fused_scale",
         &transformer_engine::pytorch::nvfp4_multi_tensor_fused_scale,
-        "Batched fused scale: compute per-block decode scale, copy global amax, expand to FP8 for multiple tensors",
+        "Batched fused scale: compute per-block decode scale, copy global amax, expand to FP8 for "
+        "multiple tensors",
         py::arg("block_amax_list"), py::arg("global_amax_list"), py::arg("per_block_scale_list"),
-        py::arg("target_scale_list"), py::arg("target_amax_list"),
-        py::arg("tile_rows_list"), py::arg("tile_cols_list"), py::arg("rows_padded_list"),
-        py::arg("block_len"),
+        py::arg("target_scale_list"), py::arg("target_amax_list"), py::arg("tile_rows_list"),
+        py::arg("tile_cols_list"), py::arg("rows_padded_list"), py::arg("block_len"),
         py::call_guard<py::gil_scoped_release>());
   m.def("nvfp4_multi_tensor_create_columnwise",
         &transformer_engine::pytorch::nvfp4_multi_tensor_create_columnwise,
         "Batched NVFP4 columnwise creation: transpose data and scales for multiple tensors",
         py::arg("rowwise_data_list"), py::arg("columnwise_data_list"),
-        py::arg("rowwise_scale_inv_list"), py::arg("columnwise_scale_inv_list"),
-        py::arg("M_list"), py::arg("K_list"),
-        py::call_guard<py::gil_scoped_release>());
+        py::arg("rowwise_scale_inv_list"), py::arg("columnwise_scale_inv_list"), py::arg("M_list"),
+        py::arg("K_list"), py::call_guard<py::gil_scoped_release>());
   m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims,
         "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
@@ -334,18 +331,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Batched compute partial and global amax from master weights for NVFP4 2D",
         py::arg("master_weight_list"), py::arg("partial_amax_list"), py::arg("global_amax_list"),
         py::arg("h_list"), py::arg("w_list"), py::arg("start_offset_list"),
-        py::arg("block_len") = 16,
-        py::call_guard<py::gil_scoped_release>());
+        py::arg("block_len") = 16, py::call_guard<py::gil_scoped_release>());
   m.def("nvfp4_2d_partial_cast", &transformer_engine::pytorch::nvfp4_2d_partial_cast,
         "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"),
         py::arg("scale"), py::arg("global_scale"), py::arg("h"), py::arg("w"),
-        py::arg("start_offset"), py::arg("block_len") = 16, py::call_guard<py::gil_scoped_release>());
+        py::arg("start_offset"), py::arg("block_len") = 16,
+        py::call_guard<py::gil_scoped_release>());
   m.def("nvfp4_multi_tensor_2d_partial_cast",
         &transformer_engine::pytorch::nvfp4_multi_tensor_2d_partial_cast,
-        "Batched partial cast from master weights for NVFP4 2D",
-        py::arg("inp_list"), py::arg("out_list"), py::arg("scale_list"),
-        py::arg("global_scale_list"), py::arg("h_list"), py::arg("w_list"),
-        py::arg("start_offset_list"), py::arg("block_len") = 16,
+        "Batched partial cast from master weights for NVFP4 2D", py::arg("inp_list"),
+        py::arg("out_list"), py::arg("scale_list"), py::arg("global_scale_list"), py::arg("h_list"),
+        py::arg("w_list"), py::arg("start_offset_list"), py::arg("block_len") = 16,
         py::call_guard<py::gil_scoped_release>());
   m.def("mxfp8_scaling_compute_partial_amax",
         &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax,
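The bindings registered above are intended to be driven from Python in the order: partial amax, amax reduction, global and per-block scales, then partial cast. A condensed sketch of that composition follows. The authoritative sequence is `_cast_master_weights_to_nvfp4_2d` in transformer_engine/pytorch/tensor/utils.py later in this series; the shapes, padding, strides, and the choice to all-reduce the per-tile amaxes here are illustrative assumptions rather than part of the patch.

# Sketch only: how the new tex bindings are meant to compose for sharded NVFP4 weights.
import torch
import torch.distributed as dist
import transformer_engine_torch as tex  # assumed name of the module built from the bindings above

def cast_shards_to_nvfp4(masters, tile_amaxes, tile_scales, outs, hs, ws, offsets, group):
    n = len(masters)
    global_amaxes = torch.zeros(n, dtype=torch.float32, device="cuda")
    global_amax_views = [global_amaxes[i : i + 1] for i in range(n)]

    # 1) Per-tile partial amax over each locally owned shard, plus a per-tensor global amax.
    tex.nvfp4_multi_tensor_compute_partial_amax(
        masters, tile_amaxes, global_amax_views, hs, ws, offsets, 16)

    # 2) Reduce so every rank sees amaxes for the whole weight, not just its own shard.
    dist.all_reduce(global_amaxes, op=dist.ReduceOp.MAX, group=group)
    for amax in tile_amaxes:
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)

    # 3) Global encode scales and per-tile decode scales, computed on device.
    global_scales = torch.empty_like(global_amaxes)
    tex.nvfp4_compute_global_scale(global_amaxes, global_scales)
    global_scale_views = [global_scales[i : i + 1] for i in range(n)]
    for amax, scale, g_amax in zip(tile_amaxes, tile_scales, global_amax_views):
        tex.nvfp4_compute_per_block_scale(amax, scale, g_amax)

    # 4) Cast only the locally owned slice of each weight into its NVFP4 storage.
    tex.nvfp4_multi_tensor_2d_partial_cast(
        masters, outs, tile_scales, global_scale_views, hs, ws, offsets, 16)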
diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
index cc79223439..dac2621802 100644
--- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
@@ -5,12 +5,11 @@
  ************************************************************************/
 
 #include
+#include
 #include
 #include
 
-#include
-
 #include "../extensions.h"
 #include "pybind.h"
 
@@ -76,9 +75,9 @@ at::Tensor nvfp4_transpose(at::Tensor input, std::optional<at::Tensor> output) {
   at::Tensor out;
   if (output.has_value()) {
     out = *output;
-    NVTE_CHECK(static_cast<size_t>(out.size(0)) == K &&
-                   static_cast<size_t>(out.size(1)) == M_packed,
-               "Output shape mismatch for NVFP4 transpose.");
+    NVTE_CHECK(
+        static_cast<size_t>(out.size(0)) == K && static_cast<size_t>(out.size(1)) == M_packed,
+        "Output shape mismatch for NVFP4 transpose.");
   } else {
     const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
     out = at::empty(output_shape, opts);
@@ -90,17 +89,16 @@ at::Tensor nvfp4_transpose(at::Tensor input, std::optional<at::Tensor> output) {
   }
 
   // Call the NVFP4 transpose kernel
-  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, K_packed},
-                                              DType::kByte);
-  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{K, M_packed},
-                                               DType::kByte);
+  auto input_cu =
+      makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, K_packed}, DType::kByte);
+  auto output_cu =
+      makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{K, M_packed}, DType::kByte);
 
   nvte_nvfp4_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
 
   return out;
 }
 
-void nvfp4_scale_transpose(at::Tensor input, at::Tensor output,
-                           int64_t M_tiles, int64_t K_tiles) {
+void nvfp4_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, int64_t K_tiles) {
   init_extension();
 
   // Input: rowwise_scale_inv [M_padded, K_tiles], uint8 (E4M3 stored as bytes)
@@ -110,21 +108,20 @@ void nvfp4_scale_transpose(at::Tensor input, at::Tensor output,
   NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input.");
   NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output.");
   NVTE_CHECK(input.scalar_type() == at::kByte, "NVFP4 scale transpose input must be uint8 (E4M3).");
-  NVTE_CHECK(output.scalar_type() == at::kByte, "NVFP4 scale transpose output must be uint8 (E4M3).");
+  NVTE_CHECK(output.scalar_type() == at::kByte,
+             "NVFP4 scale transpose output must be uint8 (E4M3).");
 
   auto input_cu = makeTransformerEngineTensor(
       input.data_ptr(), std::vector<size_t>{in_shape[0], in_shape[1]}, DType::kByte);
   auto output_cu = makeTransformerEngineTensor(
       output.data_ptr(), std::vector<size_t>{out_shape[0], out_shape[1]}, DType::kByte);
 
-  nvte_nvfp4_scale_transpose(input_cu.data(), output_cu.data(),
-                             static_cast<size_t>(M_tiles), static_cast<size_t>(K_tiles),
-                             at::cuda::getCurrentCUDAStream());
+  nvte_nvfp4_scale_transpose(input_cu.data(), output_cu.data(), static_cast<size_t>(M_tiles),
+                             static_cast<size_t>(K_tiles), at::cuda::getCurrentCUDAStream());
 }
 
-void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output,
-                               int64_t tile_rows, int64_t tile_cols,
-                               int64_t rows_padded, int64_t block_len) {
+void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows,
+                               int64_t tile_cols, int64_t rows_padded, int64_t block_len) {
   init_extension();
 
   // Input: per_block_decode_scale [tile_rows, tile_cols], float32
@@ -141,15 +138,13 @@ void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output,
   auto output_cu = makeTransformerEngineTensor(
       output.data_ptr(), std::vector<size_t>{out_shape[0], out_shape[1]}, DType::kByte);
 
-  nvte_nvfp4_expand_scale_to_fp8(input_cu.data(), output_cu.data(),
-                                 static_cast<size_t>(tile_rows),
-                                 static_cast<size_t>(tile_cols),
-                                 static_cast<size_t>(rows_padded),
-                                 static_cast<size_t>(block_len),
-                                 at::cuda::getCurrentCUDAStream());
+  nvte_nvfp4_expand_scale_to_fp8(input_cu.data(), output_cu.data(), static_cast<size_t>(tile_rows),
+                                 static_cast<size_t>(tile_cols), static_cast<size_t>(rows_padded),
+                                 static_cast<size_t>(block_len), at::cuda::getCurrentCUDAStream());
 }
 
-void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::Tensor global_amax) {
+void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale,
+                                   at::Tensor global_amax) {
   init_extension();
 
   // block_amax and scale: [tile_rows, tile_cols], float32
@@ -163,15 +158,13 @@ void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::
   auto scale_cu = makeTransformerEngineTensor(scale);
   auto global_amax_cu = makeTransformerEngineTensor(global_amax);
 
-  nvte_nvfp4_compute_per_block_scale(block_amax_cu.data(), scale_cu.data(),
-                                     global_amax_cu.data(), at::cuda::getCurrentCUDAStream());
+  nvte_nvfp4_compute_per_block_scale(block_amax_cu.data(), scale_cu.data(), global_amax_cu.data(),
+                                     at::cuda::getCurrentCUDAStream());
 }
 
-void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax,
-                       at::Tensor per_block_scale, at::Tensor target_scale,
-                       at::Tensor target_amax,
-                       int64_t tile_rows, int64_t tile_cols,
-                       int64_t rows_padded, int64_t block_len) {
+void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale,
+                       at::Tensor target_scale, at::Tensor target_amax, int64_t tile_rows,
+                       int64_t tile_cols, int64_t rows_padded, int64_t block_len) {
   init_extension();
 
   // block_amax: [tile_rows, tile_cols], float32
@@ -193,24 +186,18 @@ void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax,
   auto target_scale_cu = makeTransformerEngineTensor(target_scale);
   auto target_amax_cu = makeTransformerEngineTensor(target_amax);
 
-  nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(),
-                         per_block_scale_cu.data(), target_scale_cu.data(),
-                         target_amax_cu.data(),
+  nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), per_block_scale_cu.data(),
+                         target_scale_cu.data(), target_amax_cu.data(),
                          static_cast<size_t>(tile_rows), static_cast<size_t>(tile_cols),
                          static_cast<size_t>(rows_padded), static_cast<size_t>(block_len),
                          at::cuda::getCurrentCUDAStream());
 }
 
 void nvfp4_multi_tensor_fused_scale(
-    std::vector<at::Tensor> block_amax_list,
-    std::vector<at::Tensor> global_amax_list,
-    std::vector<at::Tensor> per_block_scale_list,
-    std::vector<at::Tensor> target_scale_list,
-    std::vector<at::Tensor> target_amax_list,
-    std::vector<int64_t> tile_rows_list,
-    std::vector<int64_t> tile_cols_list,
-    std::vector<int64_t> rows_padded_list,
-    int64_t block_len) {
+    std::vector<at::Tensor> block_amax_list, std::vector<at::Tensor> global_amax_list,
+    std::vector<at::Tensor> per_block_scale_list, std::vector<at::Tensor> target_scale_list,
+    std::vector<at::Tensor> target_amax_list, std::vector<int64_t> tile_rows_list,
+    std::vector<int64_t> tile_cols_list, std::vector<int64_t> rows_padded_list, int64_t block_len) {
   init_extension();
 
   const size_t num_tensors = block_amax_list.size();
@@ -252,11 +239,9 @@ void nvfp4_multi_tensor_fused_scale(
     auto target_scale_cu = makeTransformerEngineTensor(target_scale);
     auto target_amax_cu = makeTransformerEngineTensor(target_amax);
 
-    nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(),
-                           per_block_scale_cu.data(), target_scale_cu.data(),
-                           target_amax_cu.data(),
-                           tile_rows, tile_cols, rows_padded,
-                           static_cast<size_t>(block_len), stream);
+    nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), per_block_scale_cu.data(),
+                           target_scale_cu.data(), target_amax_cu.data(), tile_rows, tile_cols,
+                           rows_padded, static_cast<size_t>(block_len), stream);
   }
 }
 
@@ -278,7 +263,7 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
   init_extension();
 
   // Make sure input is contiguous
-  const auto &input = tensor.contiguous();
+  const auto& input = tensor.contiguous();
 
   // Allocate output tensor if needed
   if (!out) {
@@ -299,13 +284,12 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
   return std::move(*out);
 }
 
-void nvfp4_multi_tensor_create_columnwise(
-    std::vector<at::Tensor> rowwise_data_list,
-    std::vector<at::Tensor> columnwise_data_list,
-    std::vector<at::Tensor> rowwise_scale_inv_list,
-    std::vector<at::Tensor> columnwise_scale_inv_list,
-    std::vector<int64_t> M_list,
-    std::vector<int64_t> K_list) {
+void nvfp4_multi_tensor_create_columnwise(std::vector<at::Tensor> rowwise_data_list,
+                                          std::vector<at::Tensor> columnwise_data_list,
+                                          std::vector<at::Tensor> rowwise_scale_inv_list,
+                                          std::vector<at::Tensor> columnwise_scale_inv_list,
+                                          std::vector<int64_t> M_list,
+                                          std::vector<int64_t> K_list) {
   init_extension();
 
   const size_t num_tensors = rowwise_data_list.size();
@@ -355,14 +339,14 @@ void nvfp4_multi_tensor_create_columnwise(
     const auto scale_out_shape = getTensorShape(columnwise_scale_inv);
 
     auto scale_input_cu = makeTransformerEngineTensor(
-        rowwise_scale_inv.data_ptr(),
-        std::vector<size_t>{scale_in_shape[0], scale_in_shape[1]}, DType::kByte);
+        rowwise_scale_inv.data_ptr(), std::vector<size_t>{scale_in_shape[0], scale_in_shape[1]},
+        DType::kByte);
     auto scale_output_cu = makeTransformerEngineTensor(
         columnwise_scale_inv.data_ptr(),
         std::vector<size_t>{scale_out_shape[0], scale_out_shape[1]}, DType::kByte);
 
-    nvte_nvfp4_scale_transpose(scale_input_cu.data(), scale_output_cu.data(),
-                               M_tiles, K_tiles, stream);
+    nvte_nvfp4_scale_transpose(scale_input_cu.data(), scale_output_cu.data(), M_tiles, K_tiles,
+                               stream);
   }
 }
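Since the packed NVFP4 layout can be hard to review in kernel form, here is a small PyTorch-only reference of what "transpose with nibble repacking" means for an [M, K/2]-byte rowwise buffer. The nibble ordering used below (element 2i in the low nibble, element 2i+1 in the high nibble) is an assumption for illustration; the kernel's actual ordering is defined by nvte_nvfp4_transpose.

# Reference sketch only: unpack 4-bit pairs, transpose logically, repack along the new rows.
import torch

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """[R, C] uint8 bytes -> [R, 2*C] uint8 values in 0..15 (low nibble first)."""
    lo, hi = packed & 0x0F, packed >> 4
    return torch.stack((lo, hi), dim=-1).reshape(packed.shape[0], -1)

def pack_nibbles(vals: torch.Tensor) -> torch.Tensor:
    """[R, C] uint8 values in 0..15 (C even) -> [R, C/2] uint8 bytes."""
    return (vals[:, 0::2] | (vals[:, 1::2] << 4)).to(torch.uint8)

def reference_nvfp4_transpose(packed: torch.Tensor) -> torch.Tensor:
    """[M, K/2] packed bytes -> [K, M/2] packed bytes (M and K even)."""
    return pack_nibbles(unpack_nibbles(packed).t().contiguous())

packed = torch.randint(0, 256, (8, 4), dtype=torch.uint8)  # a logical 8x8 block of FP4 codes
assert reference_nvfp4_transpose(reference_nvfp4_transpose(packed)).equal(packed)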
diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
index db8bbd74c0..b7fa483766 100644
--- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -365,7 +365,7 @@ def update_usage(
                 self._columnwise_data = None
                 self._columnwise_scale_inv = None
                 self._amax_columnwise = None
-
+
     def _create_columnwise(self):
         """
         Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling.
@@ -388,7 +388,7 @@ def _create_columnwise(self):
             )
         assert len(self._rowwise_scale_inv.shape) == 2
         assert len(self._columnwise_scale_inv.shape) == 2
-
+
         # rowwise_scale_inv has shape [M_padded, K_tiles] where each tile's scale
         # is repeated 16 times (once per row in the 16x16 tile).
         # columnwise_scale_inv has shape [K_padded, M_tiles] where scales are
@@ -398,7 +398,7 @@ def _create_columnwise(self):
         M, K = logical_shape[0], logical_shape[-1]
         M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE
         K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE
-
+
         tex.nvfp4_scale_transpose(
             self._rowwise_scale_inv,
             self._columnwise_scale_inv,
@@ -410,4 +410,4 @@ def _create_columnwise(self):
         if self._amax_columnwise is None and self._amax_rowwise is not None:
             self._amax_columnwise = self._amax_rowwise.clone()
         elif self._amax_rowwise is not None:
-            self._amax_columnwise.copy_(self._amax_rowwise)
\ No newline at end of file
+            self._amax_columnwise.copy_(self._amax_rowwise)
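To make the scale layouts in `_create_columnwise` concrete: for a logical [M, K] NVFP4 weight with 16x16 tiles, the rowwise scales live in an [M_padded, K_tiles] uint8 buffer and the columnwise scales in a [K_padded, M_tiles] buffer. The sketch below assumes padding up to the next tile boundary, which is an illustration only; the real buffers are allocated by the NVFP4 quantizer and may pad differently.

# Illustrative shapes only; in TE the buffers are allocated by the quantizer, not by user code.
import torch

TILE_SIZE = 16

def nvfp4_scale_shapes(M: int, K: int):
    M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE
    K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE
    M_padded, K_padded = M_tiles * TILE_SIZE, K_tiles * TILE_SIZE
    rowwise_scale_inv = torch.empty(M_padded, K_tiles, dtype=torch.uint8)     # E4M3 bytes
    columnwise_scale_inv = torch.empty(K_padded, M_tiles, dtype=torch.uint8)  # E4M3 bytes
    return rowwise_scale_inv, columnwise_scale_inv

rw, cw = nvfp4_scale_shapes(M=257, K=4096)  # 17 row tiles, 256 column tiles
assert rw.shape == (272, 256) and cw.shape == (4096, 17)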
diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py
index 09ff61c8db..5e0348bae9 100644
--- a/transformer_engine/pytorch/tensor/utils.py
+++ b/transformer_engine/pytorch/tensor/utils.py
@@ -165,8 +165,12 @@ def cast_master_weights_to_fp8(
 
 
 def cast_master_weights_to_nvfp4(
-    model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None,
-    manual_post_all_gather_processing=False
+    model_weights,
+    master_weights,
+    start_offsets,
+    group,
+    fsdp_shard_model_weights=None,
+    manual_post_all_gather_processing=False,
 ):
     """Helper to cast master weights to NVFP4 primary weights."""
 
@@ -181,7 +185,7 @@ def cast_master_weights_to_nvfp4(
     # All NVFP4 model_weights should have the same dtype (BF16)
     if len(model_weights) > 0:
         target_dtype = model_weights[0].dtype
-
+
         # Collect non-None master_weights and their indices
         non_none_indices = []
         non_none_weights = []
@@ -191,17 +195,18 @@ def cast_master_weights_to_nvfp4(
                 non_none_indices.append(i)
                 non_none_weights.append(mw.view(-1))
                 sizes.append(mw.numel())
-
+
         if len(non_none_weights) > 0 and non_none_weights[0].dtype != target_dtype:
             # Concatenate, convert once, then split
             concatenated = torch.cat(non_none_weights)
             converted = concatenated.to(target_dtype)
             split_weights = torch.split(converted, sizes)
-
+
             # Rebuild master_weights list with converted tensors
             converted_master_weights = list(master_weights)
-            for idx, split_w, orig_mw in zip(non_none_indices, split_weights,
-                                             [master_weights[i] for i in non_none_indices]):
+            for idx, split_w, orig_mw in zip(
+                non_none_indices, split_weights, [master_weights[i] for i in non_none_indices]
+            ):
                 converted_master_weights[idx] = split_w.view(orig_mw.shape)
             master_weights = converted_master_weights
 
@@ -218,14 +223,18 @@ def cast_master_weights_to_nvfp4(
             )
         else:
             raise ValueError(
-                f"cast_master_weights_to_nvfp4 only supports NVFP4 tensors, got {type(model_weight)}"
+                "cast_master_weights_to_nvfp4 only supports NVFP4 tensors, got"
+                f" {type(model_weight)}"
             )
 
     if len(nvfp4_params) > 0:
         _cast_master_weights_to_nvfp4_2d(
-            nvfp4_params, group, use_fsdp_shard_model_weights=use_fsdp_shard_model_weights,
-            manual_post_all_gather_processing=manual_post_all_gather_processing
+            nvfp4_params,
+            group,
+            use_fsdp_shard_model_weights=use_fsdp_shard_model_weights,
+            manual_post_all_gather_processing=manual_post_all_gather_processing,
         )
 
+
 def _cast_master_weights_to_fp8_delayed_scaling(
     params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
 ):
@@ -544,6 +553,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
             master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
         )
 
+
 def _cast_master_weights_to_nvfp4_2d(
     params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
 ):
@@ -557,7 +567,7 @@ def _cast_master_weights_to_nvfp4_2d(
         group.
     use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
     """
-
+
     device = params[0][0].device
     block_len = NVFP4_BLOCK_SCALING_SIZE
 
@@ -589,9 +599,7 @@ def _cast_master_weights_to_nvfp4_2d(
     amaxes: List[torch.Tensor] = []
     scales: List[torch.Tensor] = []
     global_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device)
-    global_amax_views: List[torch.Tensor] = [
-        global_amaxes[i : i + 1] for i in range(len(params))
-    ]
+    global_amax_views: List[torch.Tensor] = [global_amaxes[i : i + 1] for i in range(len(params))]
 
     # Collect tensors for batched multi-tensor amax computation
     master_weight_list: List[torch.Tensor] = []
@@ -644,7 +652,7 @@ def _cast_master_weights_to_nvfp4_2d(
     # Use GPU kernel to compute global encode scales from global amaxes
    # This replaces multiple Python tensor operations with a single kernel
     global_scale_tensor = torch.empty_like(global_amaxes)
-
+
     tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor)
 
     global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))]
@@ -759,6 +767,7 @@ def _cast_master_weights_to_nvfp4_2d(
             block_len,
         )
 
+
 def _cast_master_weights_to_fp8_mxfp8_scaling(
     params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
 ): # pylint: disable=unused-argument
@@ -890,15 +899,15 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten
     - Float8Tensor: may need to create a transposed view to match backend GEMM.
     - Float8BlockwiseQTensor: create column-wise storage.
     - Plain pytorch tensor: noop.
-
+
     For NVFP4 tensors, uses batched multi-tensor processing to reduce CPU overhead.
     """
     if not isinstance(model_weights, list):
         model_weights = [model_weights]
-
+
     # Collect NVFP4 tensors for batched processing
     nvfp4_tensors = []
-
+
     for model_weight in model_weights:
         if isinstance(model_weight, Float8Tensor):
             # Delayed scaling and per-tensor current scaling: if backend does not support
@@ -916,7 +925,7 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten
             pass
         elif isinstance(model_weight, QuantizedTensor):
            raise ValueError(f"post_processing for {type(model_weight)} is not supported")
-
+
     # Batch process all NVFP4 tensors with multi-tensor approach
     if nvfp4_tensors:
         _nvfp4_multi_tensor_create_columnwise(nvfp4_tensors)
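The deferred-columnwise path above is meant to pair `manual_post_all_gather_processing=True` with a single batched `post_all_gather_processing` call once the optimizer's all-gather completes. A sketch of that call pattern follows; the optimizer and all-gather plumbing (`all_gather_fn`, `optimizer_step_hook`) is illustrative and not part of this patch.

# Sketch only: cast local shards, all-gather, then rebuild columnwise storage in one batch.
from transformer_engine.pytorch.tensor.utils import (
    cast_master_weights_to_nvfp4,
    post_all_gather_processing,
)

def optimizer_step_hook(model_weights, master_weights, start_offsets, group, all_gather_fn):
    # Cast only the locally owned shards; skip per-tensor post-processing for now.
    cast_master_weights_to_nvfp4(
        model_weights,
        master_weights,
        start_offsets,
        group,
        manual_post_all_gather_processing=True,
    )
    # Hypothetical helper that all-gathers the updated NVFP4 shards across `group`.
    all_gather_fn(model_weights, group)
    # One batched pass creates columnwise data/scales for every NVFP4 weight.
    post_all_gather_processing(model_weights)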
""" TILE_SIZE = 16 - + # Prepare tensor lists for batched C++ call rowwise_data_list = [] columnwise_data_list = [] @@ -936,18 +945,18 @@ def _nvfp4_multi_tensor_create_columnwise(nvfp4_tensors: List[NVFP4Tensor]): columnwise_scale_inv_list = [] M_list = [] K_list = [] - + for tensor in nvfp4_tensors: rowwise_data = tensor._rowwise_data if not rowwise_data.is_contiguous(): rowwise_data = rowwise_data.contiguous() tensor._rowwise_data = rowwise_data - + logical_shape = tensor.size() M, K = logical_shape[0], logical_shape[-1] M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE - + # Allocate columnwise_data if needed if tensor._columnwise_data is None: # Output shape: [K, M/2] packed bytes @@ -959,7 +968,7 @@ def _nvfp4_multi_tensor_create_columnwise(nvfp4_tensors: List[NVFP4Tensor]): tensor._columnwise_data = columnwise_data else: columnwise_data = tensor._columnwise_data - + # Allocate columnwise_scale_inv if needed if tensor._columnwise_scale_inv is None: assert tensor._quantizer is not None @@ -972,20 +981,20 @@ def _nvfp4_multi_tensor_create_columnwise(nvfp4_tensors: List[NVFP4Tensor]): tensor._columnwise_scale_inv = columnwise_scale_inv else: columnwise_scale_inv = tensor._columnwise_scale_inv - + rowwise_data_list.append(rowwise_data) columnwise_data_list.append(columnwise_data) rowwise_scale_inv_list.append(tensor._rowwise_scale_inv) columnwise_scale_inv_list.append(columnwise_scale_inv) M_list.append(M) K_list.append(K) - + # Copy amax if needed if tensor._amax_columnwise is None and tensor._amax_rowwise is not None: tensor._amax_columnwise = tensor._amax_rowwise.clone() elif tensor._amax_rowwise is not None: tensor._amax_columnwise.copy_(tensor._amax_rowwise) - + # Dispatch to C++ multi-tensor kernel tex.nvfp4_multi_tensor_create_columnwise( rowwise_data_list,