Conversation
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Greptile Summary

This PR adds NVFP4 primary weight support for distributed training with ZeRO/FSDP optimizers, enabling efficient 4-bit quantization with 2D block scaling. The implementation includes custom CUDA kernels for partial cast, transpose (with nibble repacking), and fused scale computation to minimize CPU overhead in large-scale training.

Key Changes

Issues Found

One test is incomplete, as noted in a previous review thread.

Confidence Score: 4/5

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[FP32 Master Weights Shards] --> B[Batched Dtype Conversion]
B --> C[Multi-Tensor Partial Amax Computation]
C --> D[Per-Block Amax]
C --> E[Global Amax]
D --> F[AllReduce Per-Block Amax]
E --> G[AllReduce Global Amax]
F --> H[Compute Global Scale]
G --> H
H --> I[Fused Scale Kernel]
I --> J[Per-Block Decode Scale]
I --> K[Expanded FP8 Scale]
I --> L[Copy Global Amax]
J --> M[Multi-Tensor Partial Cast]
K --> M
H --> M
M --> N[NVFP4 Rowwise Data with nibble packing]
N --> O[Optional: NVFP4 Transpose]
O --> P[NVFP4 Columnwise Data]
```

Last reviewed commit: 071709d
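To make the flow above concrete, here is a minimal, unfused PyTorch sketch of the same steps for a single shard. The helper name, the block size, and the simplified one-level scaling are illustrative assumptions only; the actual implementation batches many shards through the multi-tensor CUDA kernels added in this PR and uses TE's two-level NVFP4 scaling (an FP32 global scale plus FP8 per-block decode scales).

```python
import torch
import torch.distributed as dist

FP4_MAX = 6.0  # largest magnitude representable in FP4 E2M1


def cast_shard_to_nvfp4_sketch(master_shard: torch.Tensor, block: int = 16, group=None):
    """Unfused, single-tensor stand-in for the flow above.

    Assumes the shard length is a multiple of `block`; the real code processes
    many shards per kernel launch.
    """
    # Batched dtype conversion + partial amax computation on the local shard.
    x = master_shard.float().reshape(-1, block)
    block_amax = x.abs().amax(dim=1)         # per-block amax (local)
    global_amax = x.abs().amax().reshape(1)  # global amax (local)

    # Coordinate scaling across the data-parallel group.
    if group is not None:
        dist.all_reduce(block_amax, op=dist.ReduceOp.MAX, group=group)
        dist.all_reduce(global_amax, op=dist.ReduceOp.MAX, group=group)

    # Per-block decode scales. (The fused scale kernel additionally derives the
    # global scale and the expanded FP8 scale; that part is elided here.)
    block_scale = (block_amax / FP4_MAX).clamp_min(torch.finfo(torch.float32).tiny)

    # Partial cast: bring each block into FP4 range. Rounding onto the E2M1
    # value grid, nibble packing (two 4-bit values per byte), and the optional
    # columnwise transpose are what the custom CUDA kernels in this PR handle.
    scaled = (x / block_scale.unsqueeze(1)).clamp(-FP4_MAX, FP4_MAX)
    return scaled, block_scale, global_amax
```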
```python
    start_offsets,
    group,
    fsdp_shard_model_weights=None,
    manual_post_all_gather_processing=False,
```
We added these kwargs to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:
```python
    manual_post_all_gather_processing=False,
```
fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535
I see, that makes sense for now then. Let's change the default to True though since that's preferred.
I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:
```python
cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)
y = model(x)  # Float8Tensor internally performs FP8 transpose
```

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

```python
cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)
y = model(x)  # FutureFormatTensor assumes data is interleaved
```

In this case, we should throw an error when the user passes manual_post_all_gather_processing=False, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.
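To make that concrete, a hypothetical future cast function (the name and recipe are invented for illustration, not part of this PR or TE) could simply refuse the implicit mode:

```python
def cast_master_weights_to_futureformat(params, group, manual_post_all_gather_processing=True):
    """Hypothetical future recipe used only to illustrate the point above."""
    if not manual_post_all_gather_processing:
        # The imagined format cannot fix up its layout lazily inside the model
        # forward, so refuse the implicit mode instead of silently producing a
        # wrong layout after the all-gather.
        raise ValueError(
            "This recipe requires explicit post-all-gather processing; perform "
            "it from the caller (e.g. Megatron-Core) so it can be overlapped, "
            "and pass manual_post_all_gather_processing=True."
        )
    ...
```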
Two review threads on transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py were marked as outdated and resolved.
```python
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end
```
Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.
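For illustration, if the constructor always populated the storage-based offsets (falling back to the element offsets when no repacking is involved), the branch above could reduce to a sketch like this; the attribute names are taken from the quoted snippet, and the unified `storage_*` values are assumptions:

```python
# Sketch only: storage-based offsets are assumed to be populated for every
# weight type (equal to the element offsets when no repacking happens), so the
# NVFP4 special case disappears.
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
else:
    weight_buffer_dtype = weights[0].dtype
weight_buffer_length = self.storage_total      # == self.offsets[-1] for non-NVFP4
buffer_rank_start = storage_rank_start
buffer_rank_end = storage_rank_end
```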
```python
    self.dp_group,
    local_weights,
)
elif isinstance(first_weight, QuantizedTensor):
```
Nit: It's not logically correct to assume that QuantizedTensors are the same as FP8.
Suggested change:

```diff
-elif isinstance(first_weight, QuantizedTensor):
+elif isinstance(first_weight, (Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor)):
```
This pattern shows up in a few other places in this test, and it comes from before this PR.
Signed-off-by: qiyuw <qiyuw@nvidia.com>
for more information, see https://pre-commit.ci
```python
    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
    r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.
```
The docstring says "FP8 primary weights for blockwise scaling" but should say "NVFP4 primary weights" to match the function's purpose:
| r"""Helper function to cast master weights to NVFP4 primary weights. |
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Signed-off-by: qiyuw <qiyuw@nvidia.com>
/te-ci L1
timmoon10 left a comment:
Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.
Description
This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.
Type of change
Changes
This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:
- NVFP4 Partial Cast Kernel (`nvfp4_2d_partial_cast`)
- NVFP4 Transpose Kernel (`nvfp4_transpose`): `uint2` loads/stores with 64×64 tiles for efficient memory access
- Fused Scale Kernel (`nvfp4_fused_scale`)
- Multi-Tensor Dispatch Pattern
- CPU Overhead Optimizations: avoids `torch.cat`/`torch.split`; replaces `torch.zeros()` with `torch.empty()` for immediately written buffers
- Scale Computation Improvements
- New Public API: `cast_master_weights_to_nvfp4()` (see the usage sketch after the Testing section below)

Testing

`test_nvfp4_transpose_kernel`, `test_nvfp4_partial_cast_matches_full`, `test_single_gpu_partial_cast_vs_full`, `test_cast_master_weights_to_nvfp4`

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:
https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads
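As referenced in the Changes list above, here is a rough usage sketch of the new API inside a ZeRO-style optimizer step. The argument order and names are assumptions based on the FP8 counterpart and the diff hunks quoted earlier in this thread; check the merged signature in transformer_engine/pytorch/tensor/utils.py before relying on it.

```python
from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4


def zero_step_sketch(model_weights, master_shards, start_offsets, dp_group):
    # `model_weights`: NVFP4 model parameters; `master_shards`: this rank's FP32
    # master-weight shards after optimizer.step(); `start_offsets`: element
    # offsets of each shard within its full parameter (names assumed from the
    # quoted diff hunks).
    cast_master_weights_to_nvfp4(
        model_weights,
        master_shards,
        start_offsets,
        dp_group,
        manual_post_all_gather_processing=True,  # preferred default per the review above
    )
    # All-gather the packed NVFP4 rowwise data across `dp_group`, then perform
    # the post-all-gather processing (e.g. building the columnwise/transposed
    # copy) explicitly before the next forward pass.
```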
Checklist: