Closed
Changes from all commits (47 commits)
7b37680
Support for NVFP4 primary weights
Nov 6, 2025
002a3e5
conflicts
Nov 10, 2025
2c0750e
partial amax and cast
Nov 18, 2025
c2acb06
add unit tests
Nov 20, 2025
d6bd790
add global amax
Nov 21, 2025
7d59088
fix conflict
Nov 21, 2025
d607fd3
unit test v2
Nov 21, 2025
b4a0084
fix build error
Nov 22, 2025
3776ac4
fix some issues
Nov 25, 2025
7787e0f
use larger problem for cublas gemm, fix some issues
Nov 30, 2025
7fbf724
debug code
Dec 8, 2025
a2465d1
Fix partial amax kernel and Add clear transpose cache
kunlunl Dec 8, 2025
d0b0be2
Merge branch 'fp4_primary_weights' into 'fp4_primary_weights'
Dec 8, 2025
c3c6aa7
create colwise
Dec 8, 2025
3222d48
colwise usage fix
Dec 8, 2025
a6727cf
Fix partial cast numerical bug
kunlunl Dec 9, 2025
91e5ed0
Merge branch 'fp4_primary_weights' into 'fp4_primary_weights'
Dec 9, 2025
481730f
nvfp4 transpose
Dec 9, 2025
882debb
bitwise identical transpose and add multi-gpu test
Dec 9, 2025
21394ce
clean up
Dec 9, 2025
aad419c
fix create columnwise and add debugging and testing code
Dec 10, 2025
9285232
fix and debug
Dec 11, 2025
632f54b
fix loss mismatch, add debugging code
Dec 11, 2025
433b1db
clean up debugging code
Dec 11, 2025
cc1d021
minor
Dec 11, 2025
48945ed
clean up again
Dec 11, 2025
d458d5b
more restricted shapes in test
Dec 12, 2025
07a8e98
fix conflict
Dec 17, 2025
d87daf9
nvfp4 scale transpose
Dec 30, 2025
938079d
partial cast optimize
Jan 5, 2026
3ab7e2c
partial cast optimize patch
Jan 5, 2026
f2986d7
fix typo
Jan 5, 2026
0d29a19
reduce kernel launches
Jan 6, 2026
0191035
fused kernel
Jan 7, 2026
27d9465
optimize transpose kernel
Jan 8, 2026
2b45a3f
multi-tensor apply for transpose
Jan 8, 2026
0dff439
multi-tensor for amax
Jan 8, 2026
dc3d495
multi-tensor for partial cast
Jan 8, 2026
91c1f08
Merge branch 'main' into fp4_primary_weights
Jan 21, 2026
deda92e
clean up
Jan 21, 2026
36b8802
fix syntax error
Jan 27, 2026
3c4adf4
bf16 to fp4 quant
Jan 28, 2026
ed65cdf
Revert "bf16 to fp4 quant"
Feb 19, 2026
1c49687
Merge branch 'main' into fp4_primary_weights
Feb 19, 2026
f9d4e89
file permission and minor
Feb 19, 2026
371c229
remove prints
Feb 19, 2026
687c8b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2026
504 changes: 477 additions & 27 deletions tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

Large diffs are not rendered by default.

138 changes: 138 additions & 0 deletions transformer_engine/common/include/transformer_engine/recipe.h
@@ -309,6 +309,144 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);

/*! \brief Compute tile-level amax for a partial shard of a 2D tensor.
*
 * For NVFP4 2D quantization with 16x16 tiles. Computes the maximum absolute
 * value within each tile, but only over the shard's elements, i.e. flattened
 * indices in [start_offset, start_offset + n), where n is the number of
 * elements in inp. Used in distributed settings where each rank owns a
 * contiguous shard.
*
* \param[in] inp Input tensor (partial shard, high-precision).
* \param[out] amax Output amax buffer [tile_rows, tile_cols], float32.
* \param[in] h Number of rows in the full 2D tensor.
* \param[in] w Number of columns in the full 2D tensor.
* \param[in] amax_stride_h Stride for amax in tile-row dimension.
* \param[in] amax_stride_w Stride for amax in tile-col dimension.
* \param[in] start_offset Starting element offset in the flattened tensor.
* \param[in] block_len Tile dimension (must be 16 for NVFP4 2D).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w,
size_t amax_stride_h, size_t amax_stride_w,
size_t start_offset, size_t block_len, cudaStream_t stream);

/*! \brief Cast a partial shard of a tensor to NVFP4 using 2D tile-based quantization.
*
 * Quantizes the shard's elements, i.e. flattened indices in [start_offset,
 * start_offset + n) where n is the number of elements in inp, using
 * precomputed per-tile scales. Each 16x16 tile uses its own scale factor.
 * Used in distributed settings where each rank casts its owned shard.
*
* \param[in] inp Input tensor (partial shard, high-precision).
* \param[out] out Output NVFP4 packed tensor (2 values per byte).
* \param[in] scale Per-tile scale factors [tile_rows, tile_cols], float32.
* \param[in] global_scale Global scale factor [1], float32.
* \param[in] h Number of rows in the full 2D tensor.
* \param[in] w Number of columns in the full 2D tensor.
* \param[in] scale_stride_h Stride for scale in tile-row dimension.
* \param[in] scale_stride_w Stride for scale in tile-col dimension.
* \param[in] start_offset Starting element offset in the flattened tensor.
* \param[in] block_len Tile dimension (must be 16 for NVFP4 2D).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale,
const NVTETensor global_scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream);
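
/* Illustrative sketch of the FP4 encode this cast performs (assumptions:
 * round-to-nearest onto the E2M1 value grid and low-nibble-first packing,
 * e.g. out_byte = lo | (hi << 4); the kernel's exact rounding and packing
 * may differ). The input x is pre-divided by its tile's decode scale. */
static inline unsigned char nvfp4_encode_ref(float x) {
  static const float grid[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
  unsigned char sign = x < 0.f ? 0x8 : 0x0;
  float a = x < 0.f ? -x : x;
  unsigned char best = 0;
  for (unsigned char i = 1; i < 8; ++i) {  /* nearest representable E2M1 value */
    float d0 = grid[best] > a ? grid[best] - a : a - grid[best];
    float d1 = grid[i] > a ? grid[i] - a : a - grid[i];
    if (d1 < d0) best = i;
  }
  return sign | best;  /* 4-bit code; two codes pack into one output byte */
}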

/*! \brief Transpose NVFP4 packed data.
*
* Unlike FP8, NVFP4 packs two 4-bit values per byte. This function correctly
* handles the nibble repacking during transpose.
*
* \param[in] input Input tensor with packed FP4 data. Shape: [M, K/2] bytes.
* \param[out] output Output tensor with transposed packed data. Shape: [K, M/2] bytes.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream);
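
/* Illustrative reference for the nibble repacking described above: value
 * (r, c) is unpacked from byte r*(K/2) + c/2 and re-packed into byte
 * c*(M/2) + r/2 of the output (low-nibble-first is an assumption; assumes
 * even M and K and a zero-initialized output). */
static inline void nvfp4_transpose_ref(const unsigned char* in,
                                       unsigned char* out, size_t M, size_t K) {
  for (size_t r = 0; r < M; ++r)
    for (size_t c = 0; c < K; ++c) {
      unsigned char b = in[r * (K / 2) + c / 2];
      unsigned char nib = (c & 1) ? (b >> 4) : (b & 0xF);  /* unpack (r, c) */
      out[c * (M / 2) + r / 2] |=
          (unsigned char)((r & 1) ? (nib << 4) : nib);     /* repack at (c, r) */
    }
}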

/*! \brief Transpose NVFP4 tile-level scales from rowwise to columnwise format.
*
* Takes rowwise_scale_inv where scales are stored at every 16th row (tile boundaries)
* and produces columnwise_scale_inv where scales are repeated 16 times per tile row.
* Scale values are stored as E4M3 (fp8) in uint8 tensors.
*
* \param[in] input Input tensor with rowwise scales [M_padded, K_tiles], uint8 (E4M3).
* \param[out] output Output tensor with columnwise scales [K_padded, M_tiles], uint8 (E4M3).
* \param[in] M_tiles Number of tiles in M dimension.
* \param[in] K_tiles Number of tiles in K dimension.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles,
size_t K_tiles, cudaStream_t stream);
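
/* Illustrative reference for the layout change described above: tile (i, j)'s
 * scale is read from row i*16 of the rowwise buffer and written to all 16
 * rows of tile row j in the columnwise buffer (padding handling elided). */
static inline void nvfp4_scale_transpose_ref(const unsigned char* rowwise,
                                             unsigned char* colwise,
                                             size_t M_tiles, size_t K_tiles) {
  for (size_t i = 0; i < M_tiles; ++i)
    for (size_t j = 0; j < K_tiles; ++j) {
      unsigned char s = rowwise[(i * 16) * K_tiles + j];  /* tile (i, j) scale, E4M3 bits */
      for (size_t r = 0; r < 16; ++r)                     /* repeat 16x per tile row */
        colwise[(j * 16 + r) * M_tiles + i] = s;
    }
}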

/*! \brief Expand tile-level scales to row-level scales and convert them to FP8 E4M3; used by the partial-cast path.
*
* Each tile row's scale is repeated block_len times in the output.
*
* \param[in] input Input tensor with tile scales [tile_rows, tile_cols], float32.
* \param[out] output Output tensor with expanded scales [rows_padded, tile_cols], uint8 (E4M3).
* \param[in] tile_rows Number of tile rows.
* \param[in] tile_cols Number of tile columns.
* \param[in] rows_padded Padded row count in output.
* \param[in] block_len Block length (typically 16 for NVFP4).
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows,
size_t tile_cols, size_t rows_padded, size_t block_len,
cudaStream_t stream);
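
/* Illustrative reference for the expansion only; the real function also
 * converts each value to FP8 E4M3, which is elided here (float output, and
 * zero-padding past tile_rows is an assumption). */
static inline void nvfp4_expand_scale_ref(const float* tile_scale,
                                          float* row_scale, size_t tile_rows,
                                          size_t tile_cols, size_t rows_padded,
                                          size_t block_len) {
  for (size_t r = 0; r < rows_padded; ++r) {
    size_t tr = r / block_len;  /* owning tile row */
    for (size_t c = 0; c < tile_cols; ++c)
      row_scale[r * tile_cols + c] =
          (tr < tile_rows) ? tile_scale[tr * tile_cols + c] : 0.f;
  }
}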

/*! \brief Compute per-block decode scale from block amax and global amax.
*
* Computes:
* global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax
* per_block_decode_scale = block_amax / fp4_max * global_scale
*
* This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh.
*
* \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32.
* \param[out] scale Output scale tensor [tile_rows, tile_cols], float32.
* \param[in] global_amax Global amax tensor (single element), float32. Avoids D2H transfer.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale,
const NVTETensor global_amax, cudaStream_t stream);
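
/* Worked example of the formulas above (fp8_max = 448, fp4_max = 6, so
 * fp8_max * fp4_max = 2688): with global_amax = 10.5, global_scale =
 * 2688 / 10.5 = 256, and a tile with block_amax = 3.0 gets decode scale
 * 3.0 / 6 * 256 = 128. Scalar sketch: */
static inline float nvfp4_per_block_decode_scale_ref(float block_amax,
                                                     float global_amax) {
  const float kFp8Max = 448.f, kFp4Max = 6.f;
  float global_scale = (kFp8Max * kFp4Max) / global_amax;  /* 2688 / global_amax */
  return block_amax / kFp4Max * global_scale;
}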

/*! \brief Fused kernel for NVFP4 scale computation.
*
* Fuses three operations into one kernel:
* 1. Compute per-block decode scales from block amax and global amax
* 2. Copy global amax to target tensor
* 3. Expand tile-level scales to row-level and convert to FP8 E4M3
*
 * Saves two kernel launches per parameter tensor.
*
* \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32.
* \param[in] global_amax Global amax tensor [1], float32.
* \param[out] per_block_scale Output per-block scale [tile_rows, tile_cols], float32 (for partial_cast).
* \param[out] target_scale Output scale tensor [rows_padded, tile_cols], uint8 (E4M3).
* \param[out] target_amax Output amax tensor [1], float32 (copy of global_amax).
* \param[in] tile_rows Number of tile rows.
* \param[in] tile_cols Number of tile columns.
* \param[in] rows_padded Total padded rows in output.
* \param[in] block_len Block length (16 for NVFP4).
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax,
NVTETensor per_block_scale, NVTETensor target_scale,
NVTETensor target_amax, size_t tile_rows, size_t tile_cols,
size_t rows_padded, size_t block_len, cudaStream_t stream);
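
/* Illustrative composition of the three fused steps, reusing the reference
 * sketches above (E4M3 conversion again elided): */
static inline void nvfp4_fused_scale_ref(
    const float* block_amax, const float* global_amax, float* per_block_scale,
    float* target_scale, float* target_amax, size_t tile_rows,
    size_t tile_cols, size_t rows_padded, size_t block_len) {
  for (size_t i = 0; i < tile_rows * tile_cols; ++i)  /* 1: per-block decode scales */
    per_block_scale[i] =
        nvfp4_per_block_decode_scale_ref(block_amax[i], *global_amax);
  *target_amax = *global_amax;                        /* 2: copy global amax */
  nvfp4_expand_scale_ref(per_block_scale, target_scale, tile_rows, tile_cols,
                         rows_padded, block_len);     /* 3: expand to row-level */
}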

/*! \brief Compute global encode scale from global amax.
*
* Computes: global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax
* If global_amax <= 0, returns 1.0.
*
* \param[in] global_amax Input global amax tensor [num_params], float32.
* \param[out] global_scale Output global scale tensor [num_params], float32.
* \param[in] stream CUDA stream.
*/
void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale,
cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif