NVFP4 primary weight support #2691

Open

WanZzzzzz wants to merge 7 commits into NVIDIA:main from WanZzzzzz:fp4_native_weights

Conversation

@WanZzzzzz

Description

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:

NVFP4 Partial Cast Kernel (nvfp4_2d_partial_cast)

  • Implements nibble-accurate partial updates for NVFP4 tensors in distributed settings
  • Supports two-level NVFP4 scaling: global FP32 scale + per-block FP8 E4M3 scale
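
For concreteness, here is a minimal PyTorch sketch of the two-level scaling idea described above. The block size, rounding behavior, and exact scale convention used by the kernel may differ; the helper name and constants are illustrative only.

    import torch

    FP4_E2M1_MAX = 6.0      # largest magnitude representable in FP4 E2M1
    FP8_E4M3_MAX = 448.0    # largest magnitude representable in FP8 E4M3

    def two_level_scales(x: torch.Tensor, block: int = 16):
        """Illustrative two-level scaling: one global FP32 scale plus per-block FP8 scales.

        Assumes x.numel() is a multiple of `block`.
        """
        global_amax = x.abs().max().float()
        # Global FP32 scale chosen so the largest per-block scale still fits in FP8 E4M3.
        global_scale = (FP8_E4M3_MAX * FP4_E2M1_MAX) / global_amax
        block_amax = x.view(-1, block).abs().amax(dim=1).float()
        # Per-block decode scales, stored in FP8 E4M3 alongside the packed FP4 data.
        block_scales = (block_amax * global_scale / FP4_E2M1_MAX).to(torch.float8_e4m3fn)
        return global_scale, block_scales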

NVFP4 Transpose Kernel (nvfp4_transpose)

  • Custom transpose kernel for nibble-packed NVFP4 data with shared memory optimization
  • Uses vectorized uint2 loads/stores with 64×64 tiles for efficient memory access
  • Handles nibble repacking during transpose (unlike FP8 byte transpose)
  • Enables columnwise data generation for GEMM operations after rowwise AllGather
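
To make the nibble-repacking point concrete, here is a pure-PyTorch reference of what a packed-FP4 transpose has to do (illustrative only, not the CUDA kernel in this PR; it assumes the low nibble holds the even-indexed element and that both dimensions are even):

    import torch

    def nvfp4_transpose_reference(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
        """packed: uint8 tensor of shape (rows, cols // 2), two FP4 codes per byte."""
        lo = packed & 0x0F                                          # even-indexed codes (assumed low-nibble-first)
        hi = packed >> 4                                            # odd-indexed codes
        codes = torch.stack((lo, hi), dim=-1).reshape(rows, cols)   # one 4-bit code per byte
        t = codes.t().contiguous()                                  # logical transpose
        pairs = t.reshape(cols, rows // 2, 2)                       # adjacent codes along the new rows
        return pairs[..., 0] | (pairs[..., 1] << 4)                 # repack two codes per byte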

Fused Scale Kernel (nvfp4_fused_scale)

  • Fuses per-block scale computation, global amax copy, and FP8 scale expansion into a single kernel
  • Eliminates multiple kernel launches and avoids D2H transfers by accepting tensor pointers
  • Reduces kernel launch overhead in the critical path
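
Roughly, the unfused equivalent would be three separate device operations; the fused kernel performs all three in one launch and takes tensor pointers, so no intermediate values round-trip through the host. A loose, self-contained sketch (names and formulas are illustrative):

    import torch

    FP4_E2M1_MAX = 6.0

    # Illustrative inputs; in the real path these come from the amax reductions.
    block_amax = torch.rand(128)
    global_amax = block_amax.max()
    global_scale = torch.tensor(1.0)
    global_amax_out = torch.empty_like(global_amax)

    # Unfused reference: three separate operations the fused kernel collapses into one.
    decode_scales = block_amax / FP4_E2M1_MAX                            # per-block scale computation
    global_amax_out.copy_(global_amax)                                   # global amax copy
    fp8_scales = (decode_scales * global_scale).to(torch.float8_e4m3fn)  # FP8 scale expansion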

Multi-Tensor Dispatch Pattern

  • C++-side loop dispatch for NVFP4 multi-tensor operations
  • Reduces Python–C++ transition overhead compared to per-tensor Python loops
  • Collects metadata in Python and executes batched operations in C++ wrappers
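
This is the same idea as PyTorch's foreach (multi-tensor) ops. The toy example below shows why batching the dispatch helps; the PR's actual NVFP4 entry points live in its new C++ wrappers and are not shown here.

    import torch

    tensors = [torch.randn(1024) for _ in range(64)]
    scales = [torch.full((), 2.0) for _ in range(64)]

    # Per-tensor Python loop: one Python-to-C++ transition (and one kernel) per tensor.
    for t, s in zip(tensors, scales):
        t.mul_(s)

    # Multi-tensor dispatch: a single call hands the whole list to C++, which loops
    # (or launches a multi-tensor kernel) without re-entering Python for each tensor.
    torch._foreach_mul_(tensors, scales)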

CPU Overhead Optimizations

  • Batched dtype conversion via torch.cat / torch.split
  • Replaced torch.zeros() with torch.empty() for immediately written buffers
  • Consolidated metadata collection and allocation phases
  • Optimized bucket partitioning for expert parallel buffers
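
A small self-contained sketch of the batched-conversion idea (tensor names and sizes are illustrative):

    import torch

    master_shards = [torch.randn(n, dtype=torch.float32) for n in (1024, 4096, 2048)]

    # Per-shard .to() calls would launch one conversion per tensor; instead,
    # concatenate once, convert once, and split the result back per shard.
    sizes = [t.numel() for t in master_shards]
    flat = torch.cat([t.view(-1) for t in master_shards]).to(torch.bfloat16)
    converted = torch.split(flat, sizes)

    # Buffers that are fully overwritten right afterwards can skip the zero-fill.
    packed = torch.empty(sum(sizes) // 2, dtype=torch.uint8)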

Scale Computation Improvements

  • Fixed floating-point precision mismatch between Python and CUDA
  • Uses FP32 constants consistent with CUDA arithmetic
  • Ensures bitwise-identical results between partial and full quantization paths
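
An illustration of the precision point (the numbers are made up): keeping host-side arithmetic in float32 tensors avoids the float64 round-trip that Python floats introduce, which is what can make a Python-computed scale differ from the kernel's single-precision result by one ULP.

    import torch

    amax = torch.tensor([3.1415927], dtype=torch.float32)

    # Python-float path: .item() promotes to float64, the division runs in float64,
    # and the result is rounded back to float32 only at the end.
    scale_py = torch.tensor([6.0 / amax.item()], dtype=torch.float32)

    # FP32-tensor path: the constant and the division stay in float32, matching the
    # arithmetic the CUDA kernel performs.
    scale_f32 = torch.tensor([6.0], dtype=torch.float32) / amax

    # Compare bit patterns; a one-ULP difference here is enough to break bitwise
    # reproducibility between the partial and full quantization paths.
    print(scale_py.view(torch.int32), scale_f32.view(torch.int32))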

New Public API

cast_master_weights_to_nvfp4()

  • Casts FP32 master weights to NVFP4 model weights
  • Handles global and per-block amax reduction across data parallel groups
  • Designed for low CPU overhead in distributed training loops
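
A hedged usage sketch: the argument list is inferred from the review diff later in this thread, so names, order, and defaults may not match the merged code exactly, and the import path is an assumption based on the files this PR touches.

    import torch.distributed as dist
    from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4  # assumed path

    # model_weights:  NVFP4 model weights, updated in place
    # master_weights: this rank's FP32 master-weight shards (ZeRO/FSDP)
    # start_offsets:  element offset of each local shard within its full weight
    cast_master_weights_to_nvfp4(
        model_weights,
        master_weights,
        start_offsets,
        dist.group.WORLD,                        # data-parallel group for amax reductions
        manual_post_all_gather_processing=True,  # caller handles post-all-gather processing
    )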

Testing

  • test_nvfp4_transpose_kernel: verifies correctness of the nibble-packed transpose
  • test_nvfp4_partial_cast_matches_full: multi-GPU; partial cast + all-gather equals full cast
  • test_single_gpu_partial_cast_vs_full: single-GPU; offset=0 partial cast matches the reference quantizer
  • _test_cast_master_weights_to_nvfp4: 500-iteration training loop with bitwise-identical loss

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:

https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: qiyuw <qiyuw@nvidia.com>
@WanZzzzzz mentioned this pull request Feb 19, 2026
@greptile-apps
Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR adds NVFP4 primary weight support for distributed training with ZeRO/FSDP optimizers, enabling efficient 4-bit quantization with 2D block scaling. The implementation includes custom CUDA kernels for partial cast, transpose (with nibble repacking), and fused scale computation to minimize CPU overhead in large-scale training.

Key Changes

  • NVFP4 Partial Cast Infrastructure: Implements nvfp4_2d_partial_cast_kernel for nibble-accurate updates with two-level scaling (global FP32 + per-block FP8 E4M3)
  • Custom Transpose Kernel: nvfp4_transpose_kernel handles nibble repacking during transpose using vectorized uint2 loads/stores with 64×64 tiles
  • Fused Scale Computation: nvfp4_fused_scale combines per-block scale computation, global amax copy, and FP8 scale expansion into a single kernel
  • Multi-Tensor Dispatch: C++-side loop dispatch reduces Python-C++ transition overhead for batched operations
  • CPU Optimizations: Batched dtype conversion via torch.cat/torch.split, replaced torch.zeros() with torch.empty(), consolidated metadata collection
  • Public API: cast_master_weights_to_nvfp4() (via quantize_master_weights()) handles distributed training with coordinated scaling

Issues Found

One test is incomplete: as noted in a previous review thread, test_single_gpu_partial_cast_vs_full computes amax_match, scale_match, and data_match but never asserts them, so the test always passes regardless of correctness.
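
Concretely, the fix amounts to turning the computed comparisons into assertions, along these lines:

    # Inside test_single_gpu_partial_cast_vs_full, after the comparisons are computed:
    assert amax_match, "global amax differs between partial cast and reference quantizer"
    assert scale_match, "per-block scales differ between partial cast and reference quantizer"
    assert data_match, "quantized data differs between partial cast and reference quantizer"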

Confidence Score: 4/5

  • This PR is generally safe to merge with one test fix needed
  • The implementation is well-structured with thorough kernel documentation, proper validation checks, and performance optimizations. Numeric validation passed in GPT-3 training. The duplicate imports issue and docstring issue were already addressed per previous review. However, the missing test assertions reduce confidence slightly as the test cannot catch regressions.
  • tests/pytorch/distributed/test_cast_master_weights_to_fp8.py needs assertions added to test_single_gpu_partial_cast_vs_full

Important Files Changed

  • tests/pytorch/distributed/test_cast_master_weights_to_fp8.py: test assertions missing in test_single_gpu_partial_cast_vs_full; computed match variables are never validated (noted in a previous review)
  • transformer_engine/common/recipe/nvfp4.cu: new NVFP4 kernels for partial cast, transpose, and scale computation with optimized vectorized loads/stores and fused operations
  • transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp: C++ wrappers for NVFP4 2D partial cast and multi-tensor operations with proper validation checks
  • transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py: NVFP4 tensor storage implementation with rowwise/columnwise data management and transpose support
  • transformer_engine/pytorch/tensor/utils.py: added NVFP4 support to quantize_master_weights with batched dtype conversion and multi-tensor dispatch pattern

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FP32 Master Weights Shards] --> B[Batched Dtype Conversion]
    B --> C[Multi-Tensor Partial Amax Computation]
    C --> D[Per-Block Amax]
    C --> E[Global Amax]
    D --> F[AllReduce Per-Block Amax]
    E --> G[AllReduce Global Amax]
    F --> H[Compute Global Scale]
    G --> H
    H --> I[Fused Scale Kernel]
    I --> J[Per-Block Decode Scale]
    I --> K[Expanded FP8 Scale]
    I --> L[Copy Global Amax]
    J --> M[Multi-Tensor Partial Cast]
    K --> M
    H --> M
    M --> N[NVFP4 Rowwise Data with nibble packing]
    N --> O[Optional: NVFP4 Transpose]
    O --> P[NVFP4 Columnwise Data]

Last reviewed commit: 071709d

greptile-apps[bot]

This comment was marked as resolved.

@timmoon10

This comment was marked as outdated.

start_offsets,
group,
fsdp_shard_model_weights=None,
manual_post_all_gather_processing=False,
Collaborator

We added this kwarg to the FP8 functions for backward compatibility, but there's no point keeping it for these brand-new NVFP4 APIs:

Suggested change
manual_post_all_gather_processing=False,

Author

fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535

Collaborator

@timmoon10 timmoon10 Feb 20, 2026

I see, that makes sense for now then. Let's change the default to True though since that's preferred.

I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)

y = model(x)  # Float8Tensor internally performs FP8 transpose

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)

y = model(x)  # FutureFormatTensor assumes data is interleaved

In this case, we should throw an error when the user passes manual_post_all_gather_processing=False, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.
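
A minimal sketch of the guard being described, reusing the reviewer's hypothetical FutureFormat names:

    def cast_master_weights_to_futureformat(*args, manual_post_all_gather_processing=True):
        if not manual_post_all_gather_processing:
            # The tensor cannot fix up its interleaving on its own, so refuse the
            # automatic path and make the caller (e.g. Mcore) do the post-processing.
            raise ValueError(
                "manual_post_all_gather_processing=False is not supported for this format; "
                "call fix_futureformat_interleaving() after the all-gather instead."
            )
        ...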

Comment on lines 245 to +259
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end
Collaborator

@timmoon10 timmoon10 Feb 20, 2026

Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.
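
One way to read the suggestion, sketched against the variables visible in the diff above (not code from this PR): define the storage offsets for every case, trivially equal to the element offsets for non-NVFP4 weights, so the allocation below can use them without branching on the format.

    # Upstream: storage offsets are always defined.
    if not self.weights_are_nvfp4:
        storage_total = self.offsets[-1]
        storage_rank_start, storage_rank_end = rank_start, rank_end
    else:
        storage_total = self.storage_total

    # Here: no format-specific branching left.
    weight_buffer_dtype = (
        torch.uint8 if isinstance(self.weights[0], QuantizedTensor) else weights[0].dtype
    )
    weight_buffer_length = storage_total
    buffer_rank_start, buffer_rank_end = storage_rank_start, storage_rank_end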

self.dp_group,
local_weights,
)
elif isinstance(first_weight, QuantizedTensor):
Collaborator

Nit: It's not logically correct to assume that QuantizedTensors are the same as FP8.

Suggested change
elif isinstance(first_weight, QuantizedTensor):
elif isinstance(first_weight, (Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor)):

This pattern shows up in a few other places in this test, and it predates this PR.

Author

Done.

qiyuw and others added 2 commits February 20, 2026 05:52
Contributor

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 1 comment

params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.

Contributor

The docstring says "FP8 primary weights for blockwise scaling" but should say "NVFP4 primary weights" to match the function's purpose.

Suggested change
r"""Helper function to cast master weights to NVFP4 primary weights.

Signed-off-by: qiyuw <qiyuw@nvidia.com>
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: qiyuw <qiyuw@nvidia.com>
greptile-apps[bot]

This comment was marked as resolved.

@greptile-apps

This comment was marked as resolved.

Signed-off-by: qiyuw <qiyuw@nvidia.com>
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: qiyuw <qiyuw@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

11 files reviewed, no comments

@timmoon10
Collaborator

/te-ci L1

timmoon10 previously approved these changes Feb 21, 2026

Collaborator

@timmoon10 timmoon10 left a comment

Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.

@timmoon10 timmoon10 self-requested a review February 21, 2026 00:09
@timmoon10 timmoon10 dismissed their stale review February 21, 2026 00:09

Test failures
