485 changes: 462 additions & 23 deletions tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -163,7 +163,6 @@ list(APPEND transformer_engine_cuda_sources
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
- recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)

list(APPEND transformer_engine_cuda_arch_specific_sources
@@ -182,6 +181,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
+ recipe/nvfp4.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)

112 changes: 112 additions & 0 deletions transformer_engine/common/include/transformer_engine/recipe.h
@@ -309,6 +309,118 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);

/*! \brief Compute tile-level amax for a partial shard of a 2D tensor.
*
* For NVFP4 2D quantization with 16x16 tiles. Computes the maximum absolute
* value within each tile, but only over the shard's elements, i.e. the range
* [start_offset, start_offset + N) of the flattened tensor, where N is the number of
* elements in inp. Used in distributed settings where each rank owns a shard.
*
* \param[in] inp Input tensor (partial shard, high-precision).
* \param[out] amax Output amax buffer [tile_rows, tile_cols], float32.
* \param[in] h Number of rows in the full 2D tensor.
* \param[in] w Number of columns in the full 2D tensor.
* \param[in] amax_stride_h Stride for amax in tile-row dimension.
* \param[in] amax_stride_w Stride for amax in tile-col dimension.
* \param[in] start_offset Starting element offset in the flattened tensor.
* \param[in] block_len Tile dimension (must be 16 for NVFP4 2D).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w,
size_t amax_stride_h, size_t amax_stride_w,
size_t start_offset, size_t block_len, cudaStream_t stream);
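
A minimal CPU reference for the semantics implied by the comment above (editorial sketch, not part of this diff; assumes the shard length is inp's element count and that amax is zero-initialized by the caller; names are illustrative):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reduce |x| of each owned element into the amax entry of its 16x16 tile.
void nvfp4_2d_partial_amax_ref(const std::vector<float>& shard, std::vector<float>& amax,
                               size_t h, size_t w, size_t amax_stride_h, size_t amax_stride_w,
                               size_t start_offset, size_t block_len /* 16 for NVFP4 2D */) {
  (void)h;  // kept to mirror the API; only w is needed for the index math
  for (size_t i = 0; i < shard.size(); ++i) {
    const size_t flat = start_offset + i;  // position in the flattened h x w tensor
    const size_t row = flat / w, col = flat % w;
    float& a = amax[(row / block_len) * amax_stride_h + (col / block_len) * amax_stride_w];
    a = std::max(a, std::fabs(shard[i]));
  }
}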

/*! \brief Cast a partial shard of a tensor to NVFP4 using 2D tile-based quantization.
*
* Quantizes the shard's elements, i.e. the range [start_offset, start_offset + N) of the
* flattened tensor, where N is the number of elements in inp, using precomputed per-tile
* scales. Each 16x16 tile uses its own scale factor.
* Used in distributed settings where each rank casts its owned shard.
*
* \param[in] inp Input tensor (partial shard, high-precision).
* \param[out] out Output NVFP4 packed tensor (2 values per byte).
* \param[in] scale Per-tile scale factors [tile_rows, tile_cols], float32.
* \param[in] global_scale Global scale factor [1], float32.
* \param[in] h Number of rows in the full 2D tensor.
* \param[in] w Number of columns in the full 2D tensor.
* \param[in] scale_stride_h Stride for scale in tile-row dimension.
* \param[in] scale_stride_w Stride for scale in tile-col dimension.
* \param[in] start_offset Starting element offset in the flattened tensor.
* \param[in] block_len Tile dimension (must be 16 for NVFP4 2D).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale,
const NVTETensor global_scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream);
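
A reference sketch of the cast path described above (editorial, not part of this diff). The per-element encode factor global_scale / per-tile decode scale follows from the decode-scale formula documented below; the low-nibble-first packing and the even alignment of start_offset are assumptions, and the helper names are illustrative:

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Round to the nearest FP4 (E2M1) code: sign bit plus 3-bit index into the magnitude grid.
static uint8_t to_e2m1(float x) {
  static const float grid[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
  const uint8_t sign = (x < 0.f) ? 0x8 : 0x0;
  const float mag = std::fabs(x);
  uint8_t best = 0;
  for (uint8_t i = 1; i < 8; ++i)
    if (std::fabs(grid[i] - mag) < std::fabs(grid[best] - mag)) best = i;
  return sign | best;
}

// out holds the full packed tensor (h*w/2 bytes); only the shard's bytes are written.
void nvfp4_2d_partial_cast_ref(const std::vector<float>& shard, std::vector<uint8_t>& out,
                               const std::vector<float>& tile_scale, float global_scale,
                               size_t h, size_t w, size_t scale_stride_h, size_t scale_stride_w,
                               size_t start_offset, size_t block_len) {
  (void)h;
  for (size_t i = 0; i < shard.size(); ++i) {
    const size_t flat = start_offset + i;
    const size_t row = flat / w, col = flat % w;
    const float s = tile_scale[(row / block_len) * scale_stride_h +
                               (col / block_len) * scale_stride_w];
    const uint8_t q = to_e2m1(shard[i] * global_scale / s);  // per-tile encode
    uint8_t& byte = out[flat / 2];                           // two FP4 values per byte
    byte = (flat % 2 == 0) ? ((byte & 0xF0) | q) : ((byte & 0x0F) | uint8_t(q << 4));
  }
}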

/*! \brief Expand tile-level scales to row-level scales and convert to FP8 E4M3, used in partial cast.
*
* Each tile row's scale is repeated block_len times in the output.
*
* \param[in] input Input tensor with tile scales [tile_rows, tile_cols], float32.
* \param[out] output Output tensor with expanded scales [rows_padded, tile_cols], uint8 (E4M3).
* \param[in] tile_rows Number of tile rows.
* \param[in] tile_cols Number of tile columns.
* \param[in] rows_padded Padded row count in output.
* \param[in] block_len Block length (typically 16 for NVFP4).
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows,
size_t tile_cols, size_t rows_padded, size_t block_len,
cudaStream_t stream);

/*! \brief Compute per-block decode scale from block amax and global amax.
*
* Computes:
* global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax
* per_block_decode_scale = block_amax / fp4_max * global_scale
*
* This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh.
*
* \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32.
* \param[out] scale Output scale tensor [tile_rows, tile_cols], float32.
* \param[in] global_amax Global amax tensor (single element), float32. Avoids D2H transfer.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale,
const NVTETensor global_amax, cudaStream_t stream);
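
A worked instance of the two formulas, using fp8_max = 448 (E4M3) and fp4_max = 6 (E2M1), so fp8_max * fp4_max = 2688 (editorial example; the amax values are made up):

const float fp8_max = 448.f, fp4_max = 6.f;
const float global_amax = 12.f;                                             // example value
const float global_scale = (fp8_max * fp4_max) / global_amax;               // 2688 / 12 = 224
const float block_amax = 3.f;                                               // example value
const float per_block_decode_scale = block_amax / fp4_max * global_scale;   // 3 / 6 * 224 = 112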

/*! \brief Fused kernel for NVFP4 scale computation.
*
* Fuses three operations into one kernel:
* 1. Compute per-block decode scales from block amax and global amax
* 2. Copy global amax to target tensor
* 3. Expand tile-level scales to row-level and convert to FP8 E4M3
*
* Saves 2 kernel launches per parameter.
*
* \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32.
* \param[in] global_amax Global amax tensor [1], float32.
* \param[out] per_block_scale Output per-block scale [tile_rows, tile_cols], float32 (for partial_cast).
* \param[out] target_scale Output scale tensor [rows_padded, tile_cols], uint8 (E4M3).
* \param[out] target_amax Output amax tensor [1], float32 (copy of global_amax).
* \param[in] tile_rows Number of tile rows.
* \param[in] tile_cols Number of tile columns.
* \param[in] rows_padded Total padded rows in output.
* \param[in] block_len Block length (16 for NVFP4).
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax,
NVTETensor per_block_scale, NVTETensor target_scale,
NVTETensor target_amax, size_t tile_rows, size_t tile_cols,
size_t rows_padded, size_t block_len, cudaStream_t stream);
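
Functionally, the fusion is equivalent to the three steps run back to back, as in the editorial sketch below (the real kernel encodes target_scale as FP8 E4M3 uint8 and does everything in one launch; names are illustrative):

#include <cstddef>
#include <vector>

void nvfp4_fused_scale_ref(const std::vector<float>& block_amax, float global_amax,
                           std::vector<float>& per_block_scale, std::vector<float>& target_scale,
                           float& target_amax, size_t tile_rows, size_t tile_cols,
                           size_t rows_padded, size_t block_len) {
  const float global_scale = (448.f * 6.f) / global_amax;   // step 1: per-block decode scales
  for (size_t i = 0; i < tile_rows * tile_cols; ++i)
    per_block_scale[i] = block_amax[i] / 6.f * global_scale;
  target_amax = global_amax;                                 // step 2: copy the global amax
  for (size_t r = 0; r < rows_padded; ++r)                   // step 3: expand to row level
    for (size_t c = 0; c < tile_cols; ++c) {
      const size_t tr = r / block_len;
      target_scale[r * tile_cols + c] =
          (tr < tile_rows) ? per_block_scale[tr * tile_cols + c] : 0.f;
    }
}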

/*! \brief Compute global encode scale from global amax.
*
* Computes: global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax
* If an entry of global_amax is <= 0, the corresponding scale is set to 1.0.
*
* \param[in] global_amax Input global amax tensor [num_params], float32.
* \param[out] global_scale Output global scale tensor [num_params], float32.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale,
cudaStream_t stream);
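
The per-parameter loop this corresponds to (editorial sketch, illustrative name):

#include <cstddef>
#include <vector>

void nvfp4_global_scale_ref(const std::vector<float>& global_amax, std::vector<float>& global_scale) {
  for (size_t p = 0; p < global_amax.size(); ++p)
    global_scale[p] = (global_amax[p] > 0.f) ? 2688.f / global_amax[p] : 1.f;  // guard zero/negative amax
}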

#ifdef __cplusplus
} // extern "C"
#endif
26 changes: 26 additions & 0 deletions transformer_engine/common/include/transformer_engine/transpose.h
@@ -326,6 +326,32 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
*/
void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Transpose NVFP4 packed data.
*
* Unlike FP8, NVFP4 packs two 4-bit values per byte. This function correctly
* handles the nibble repacking during transpose.
*
* \param[in] input Input tensor with packed FP4 data. Shape: [M, K/2] bytes.
* \param[out] output Output tensor with transposed packed data. Shape: [K, M/2] bytes.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_data_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream);
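
An editorial sketch of the nibble repacking for a logical M x K FP4 matrix (not this kernel; assumes M and K are even and that the low nibble holds the even-indexed element, which may differ from the actual packing order):

#include <cstddef>
#include <cstdint>
#include <vector>

void nvfp4_data_transpose_ref(const std::vector<uint8_t>& in, std::vector<uint8_t>& out,
                              size_t M, size_t K) {
  auto get = [&](size_t r, size_t c) -> uint8_t {    // read FP4 element (r, c) from [M, K/2] bytes
    const uint8_t byte = in[r * (K / 2) + c / 2];
    return (c % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  };
  for (size_t r = 0; r < M; ++r)
    for (size_t c = 0; c < K; ++c) {
      const uint8_t q = get(r, c);
      uint8_t& byte = out[c * (M / 2) + r / 2];      // write element (c, r) into [K, M/2] bytes
      byte = (r % 2 == 0) ? ((byte & 0xF0) | q) : ((byte & 0x0F) | uint8_t(q << 4));
    }
}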

/*! \brief Transpose NVFP4 tile-level scales from rowwise to columnwise format.
*
* Takes rowwise_scale_inv where scales are stored at every 16th row (tile boundaries)
* and produces columnwise_scale_inv where scales are repeated 16 times per tile row.
* Scale values are stored as E4M3 (fp8) in uint8 tensors.
*
* \param[in] input Input tensor with rowwise scales [M_padded, K_tiles], uint8 (E4M3).
* \param[out] output Output tensor with columnwise scales [K_padded, M_tiles], uint8 (E4M3).
* \param[in] M_tiles Number of tiles in M dimension.
* \param[in] K_tiles Number of tiles in K dimension.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles,
size_t K_tiles, cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif