diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 8a434b2148..69dd55a856 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -18,6 +18,7 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, MXFP8BlockScaling, Format, Recipe, @@ -26,13 +27,19 @@ from transformer_engine.pytorch import ( is_fp8_available, is_fp8_block_scaling_available, - is_mxfp8_available, + is_nvfp4_available, QuantizedTensor, Float8Tensor, Float8BlockwiseQTensor, + NVFP4Tensor, + is_mxfp8_available, MXFP8Tensor, ) -from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 +from transformer_engine.pytorch.tensor.utils import ( + quantize_master_weights, + cast_master_weights_to_fp8, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data @@ -67,6 +74,12 @@ def _get_raw_data(quantized_tensor, colwise=False): quantized_tensor._rowwise_data.dtype == torch.uint8 ), "Float8BlockwiseQTensor _rowwise_data must be uint8" return quantized_tensor._rowwise_data + elif isinstance(quantized_tensor, NVFP4Tensor): + assert hasattr(quantized_tensor, "_rowwise_data"), "NVFP4Tensor missing _rowwise_data" + assert ( + quantized_tensor._rowwise_data.dtype == torch.uint8 + ), "NVFP4Tensor _rowwise_data must be uint8" + return quantized_tensor._rowwise_data elif isinstance(quantized_tensor, MXFP8Tensor): if colwise: assert hasattr( @@ -135,22 +148,49 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.offsets = [0] for weight in self.weights: self.offsets.append(self.offsets[-1] + weight.numel()) - # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may # not be the end range of the last weight. if self.offsets[-1] % self.world_size != 0: self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size + self.weights_are_nvfp4 = isinstance(self.weights[0], NVFP4Tensor) + + # Storage offsets operate on the packed representation (e.g., NVFP4 uint8 data). + self.storage_offsets = None + self.storage_sizes = None + self.storage_total = None + if self.weights_are_nvfp4: + self.storage_offsets = [0] + self.storage_sizes = [] + for weight in self.weights: + storage_size = _get_raw_data(weight).view(-1).numel() + self.storage_sizes.append(storage_size) + self.storage_offsets.append(self.storage_offsets[-1] + storage_size) + if self.storage_offsets[-1] % self.world_size != 0: + self.storage_offsets[-1] += ( + self.world_size - self.storage_offsets[-1] % self.world_size + ) + self.storage_total = self.storage_offsets[-1] + self.master_weights = [] # The start offset of the master weight in the weight self.start_offsets = [] # The overlapping area of the weight and this rank's local buffer self.overlapping_areas = [] + # Storage equivalents (only populated for NVFP4 tensors). + self.storage_start_offsets = [None] * len(self.weights) + self.storage_overlapping_areas = [None] * len(self.weights) # The start and end of this rank's local buffer in the global buffer rank_start = self.offsets[-1] // self.world_size * self.rank rank_end = rank_start + self.offsets[-1] // self.world_size + # Needed for NVFP4 tensors which packs two values per byte. 
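+        # Illustrative sketch (hypothetical sizes, and assuming the packed rowwise
+        # data has no padding): with world_size = 2 and 1024 logical elements, rank 1
+        # owns elements [512, 1024) of the flat element buffer, while its packed
+        # uint8 shard is bytes [256, 512), since each byte stores two FP4 values.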
+ storage_rank_start = None + storage_rank_end = None + if self.weights_are_nvfp4: + storage_rank_start = self.storage_total // self.world_size * self.rank + storage_rank_end = storage_rank_start + self.storage_total // self.world_size for weight, offset in zip(self.weights, self.offsets[:-1]): if offset >= rank_end or (offset + weight.numel()) <= rank_start: # This weight is not in this rank's local buffer @@ -178,6 +218,20 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.start_offsets.append(start_offset) self.overlapping_areas.append(overlapping_area) + if self.weights_are_nvfp4: + for idx, (weight, storage_offset, storage_size) in enumerate( + zip(self.weights, self.storage_offsets[:-1], self.storage_sizes) + ): + if ( + storage_offset >= storage_rank_end + or (storage_offset + storage_size) <= storage_rank_start + ): + continue + overlap_start = max(storage_rank_start, storage_offset) + overlap_end = min(storage_rank_end, storage_offset + storage_size) + self.storage_start_offsets[idx] = overlap_start - storage_offset + self.storage_overlapping_areas[idx] = (overlap_start, overlap_end) + # Create global buffer for grads reduce-scatter self.grad_buffer = torch.empty( [self.offsets[-1]], dtype=torch.float32, device=weights[0].device @@ -187,12 +241,23 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals # Create global buffer for weights all-gather if isinstance(self.weights[0], QuantizedTensor): weight_buffer_dtype = torch.uint8 + if self.weights_are_nvfp4: + weight_buffer_length = self.storage_total + buffer_rank_start = storage_rank_start + buffer_rank_end = storage_rank_end + else: + weight_buffer_length = self.offsets[-1] + buffer_rank_start = rank_start + buffer_rank_end = rank_end else: weight_buffer_dtype = weights[0].dtype + weight_buffer_length = self.offsets[-1] + buffer_rank_start = rank_start + buffer_rank_end = rank_end self.weight_buffer = torch.empty( - [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device + [weight_buffer_length], dtype=weight_buffer_dtype, device=weights[0].device ) - self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] + self.weight_buffer_slice = self.weight_buffer[buffer_rank_start:buffer_rank_end] def step(self): # ----------------------------------------------------------------------------------------- @@ -231,10 +296,19 @@ def step(self): # ----------------------------------------------------------------------------------------- # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight # ----------------------------------------------------------------------------------------- - if isinstance(self.weights[0], QuantizedTensor): - # FP8 weights case - for i in range(1, len(self.weights)): - assert isinstance(self.weights[i], QuantizedTensor) + first_weight = self.weights[0] + if isinstance(first_weight, NVFP4Tensor): + for weight in self.weights: + assert isinstance(weight, NVFP4Tensor) + quantize_master_weights( + self.weights, + self.master_weights, + self.start_offsets, + self.dp_group, + ) + elif isinstance(first_weight, (Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor)): + for weight in self.weights: + assert isinstance(weight, QuantizedTensor) cast_master_weights_to_fp8( self.weights, self.master_weights, @@ -253,20 +327,31 @@ def step(self): end = start_offset + master_weight.numel() weight.data.view(-1)[start:end].copy_(master_weight) + # 
----------------------------------------------------------------------------------------- + # Step 5: Copy the updated weights (not all weights) to the weight buffer + # ----------------------------------------------------------------------------------------- colwise_list = [False] if isinstance(self.weights[0], MXFP8Tensor): colwise_list.append(True) for colwise in colwise_list: - # ------------------------------------------------------------------------------------- - # Step 5: Copy the updated weights (not all weights) to the weight buffer - # ------------------------------------------------------------------------------------- for i in range(len(self.weights)): master_weight = self.master_weights[i] if master_weight is None: continue start_offset = self.start_offsets[i] - if isinstance(self.weights[i], QuantizedTensor): + if isinstance(self.weights[i], NVFP4Tensor): + storage_start = self.storage_start_offsets[i] + storage_overlap = self.storage_overlapping_areas[i] + if storage_start is None or storage_overlap is None: + continue + weight = _get_raw_data(self.weights[i]).view(-1) + storage_len = storage_overlap[1] - storage_overlap[0] + weight_slice = weight[storage_start : storage_start + storage_len] + overlapping_start, overlapping_end = storage_overlap + self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + continue + elif isinstance(self.weights[i], QuantizedTensor): weight = _get_raw_data(self.weights[i], colwise) else: weight = self.weights[i] @@ -284,12 +369,22 @@ def step(self): # ------------------------------------------------------------------------------------- # Step 7: Copy the gathered weights from weight buffer to the actual weights # ------------------------------------------------------------------------------------- - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - if isinstance(weight, QuantizedTensor): - weight = _get_raw_data(weight, colwise) - weight.view(-1).data.copy_(self.weight_buffer[start:end]) + if self.weights_are_nvfp4: + # NVFP4: use storage offsets (packs 2 values per byte) + for weight, storage_offset, storage_size in zip( + self.weights, self.storage_offsets[:-1], self.storage_sizes + ): + start = storage_offset + end = storage_offset + storage_size + raw_data = _get_raw_data(weight) + raw_data.view(-1).data.copy_(self.weight_buffer[start:end]) + else: + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + if isinstance(weight, QuantizedTensor): + weight = _get_raw_data(weight, colwise) + weight.view(-1).data.copy_(self.weight_buffer[start:end]) if self.manual_post_all_gather_processing: quantized_weights = [ @@ -464,8 +559,23 @@ def step(self): # Update the master weight using gradient descent master_weight -= grad * self.lr - # Step 3: Cast master weights to FP8 or BF16 precision - if isinstance(self.weights[0], QuantizedTensor): + # Step 3: Cast master weights to quantized or BF16 precision + first_weight = self.weights[0] + if isinstance(first_weight, NVFP4Tensor): + local_weights = [] + for local_weight in self.local_weights: + if local_weight is None: + local_weights.append(None) + continue + local_weights.append(local_weight) + quantize_master_weights( + self.weights, + self.master_weights, + [idx[0] for idx in self.weight_indices], + self.dp_group, + local_weights, + ) + elif isinstance(first_weight, QuantizedTensor): local_weights = [] for i, local_weight in enumerate(self.local_weights): if 
self.flatten_columnwise is not None: @@ -730,6 +840,90 @@ def _test_fsdp_cast_master_weights_to_fp8( ), f"Loss mismatch at rank {rank}, step {i} for {quantization} (FSDP)" +def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processing): + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + # Disable stochastic rounding for deterministic gradients + nvfp4_recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) + + with te.quantized_model_init( + enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True + ): + model_nvfp4 = nn.Sequential( + te.Linear(128, 256 + 64, **linear_kwargs), + te.Linear(256 + 64, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + # Create model with bf16 weights + model = nn.Sequential( + te.Linear(128, 256 + 64, **linear_kwargs), + te.Linear(256 + 64, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + high_precision_init_val = w_nvfp4.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + w_nvfp4.main_grad = torch.zeros_like(w_nvfp4, dtype=torch.float32, device="cuda") + w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") + + optimizer_nvfp4 = MiniZero_1( + [w for w in model_nvfp4.parameters()], 10.0, dp_group, manual_post_all_gather_processing + ) + optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + + for i in range(500): + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + w_nvfp4.main_grad.zero_() + w.main_grad.zero_() + + inputs = [ + torch.randn(2048, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + x = inputs[rank] + + with te.autocast( + enabled=True, + recipe=nvfp4_recipe, + amax_reduction_group=mock_group, + ): + y_nvfp4 = model_nvfp4(x) + + with te.autocast( + enabled=True, + recipe=nvfp4_recipe, + amax_reduction_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + target = targets[rank] + loss_nvfp4 = nn.MSELoss()(y_nvfp4, target) + loss = nn.MSELoss()(y, target) + + loss_nvfp4.backward() + loss.backward() + + optimizer.step() + optimizer_nvfp4.step() + + torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) + + def run_parallel_tests() -> None: """Run parallel tests""" @@ -762,13 +956,17 @@ def run_parallel_tests() -> None: quantizations.append("mxfp8") manual_post_all_gather_processings = [False, True] - + print("starting mini optimizer test") _test_mini_optimizer(dp_group) - + print("starting cast master weights to fp8 test") for quantization in quantizations: for post_ag_processing in manual_post_all_gather_processings: _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + nvfp4_available, _ = is_nvfp4_available(return_reason=True) + if nvfp4_available: + print("starting cast master weights to nvfp4 test") + _test_cast_master_weights_to_nvfp4(dp_group, False) 
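+        # Sketch (assumption, mirroring the FP8 sweep above): a fuller NVFP4 sweep over
+        # manual post-all-gather processing would look like
+        #     for post_ag_processing in manual_post_all_gather_processings:
+        #         _test_cast_master_weights_to_nvfp4(dp_group, post_ag_processing)
+        # Only the post_ag_processing=False case is exercised here.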
dist.destroy_process_group() @@ -803,5 +1001,246 @@ def main() -> None: run_parallel_tests() +@pytest.mark.skipif(not torch.cuda.is_available(), reason="NVFP4 transpose test requires CUDA.") +def test_nvfp4_transpose_kernel() -> None: + """Test that nvfp4_transpose kernel produces bitwise identical results to reference.""" + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + torch.manual_seed(1234) + device = torch.device("cuda") + shape = (2048, 5120) + master_weight = torch.randn(shape, dtype=torch.float32, device=device) + + print("\n=== Testing NVFP4 transpose kernel ===") + + # Create reference with both rowwise and columnwise data + quantizer_with_colwise = NVFP4Quantizer( + rowwise=True, columnwise=True, with_2d_quantization=True + ) + reference_tensor = quantizer_with_colwise(master_weight.to(torch.bfloat16)) + assert reference_tensor._columnwise_data is not None, "Reference should have columnwise data" + assert ( + reference_tensor._columnwise_scale_inv is not None + ), "Reference should have columnwise scale_inv" + reference_columnwise_data = reference_tensor._columnwise_data.detach().clone() + reference_columnwise_scale_inv = reference_tensor._columnwise_scale_inv.detach().clone() + reference_columnwise_amax = ( + reference_tensor._amax_columnwise.detach().clone() + if reference_tensor._amax_columnwise is not None + else None + ) + + # Create tensor with only rowwise data, then call _create_columnwise() + quantizer_rowwise_only = NVFP4Quantizer( + rowwise=True, columnwise=False, with_2d_quantization=True + ) + test_tensor = quantizer_rowwise_only(master_weight.to(torch.bfloat16)) + assert test_tensor._columnwise_data is None, "Test tensor should not have columnwise data yet" + + # Now call _create_columnwise() which uses our nvfp4_transpose kernel + test_tensor.update_usage(rowwise_usage=True, columnwise_usage=True) + assert ( + test_tensor._columnwise_data is not None + ), "Test tensor should have columnwise data after _create_columnwise()" + assert ( + test_tensor._columnwise_scale_inv is not None + ), "Test tensor should have columnwise scale_inv after _create_columnwise()" + + # Compare columnwise data - should be bitwise identical + torch.testing.assert_close( + test_tensor._columnwise_data, + reference_columnwise_data, + atol=0, + rtol=0, + msg="NVFP4 transpose kernel produced different columnwise data than reference!", + ) + + torch.testing.assert_close( + test_tensor._columnwise_scale_inv, + reference_columnwise_scale_inv, + atol=0, + rtol=0, + msg="NVFP4 _create_columnwise produced different columnwise scale_inv than reference!", + ) + + torch.testing.assert_close( + test_tensor._amax_columnwise, + reference_columnwise_amax, + atol=0, + rtol=0, + msg="NVFP4 _create_columnwise produced different columnwise amax than reference!", + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="NVFP4 partial-cast test requires CUDA.") +def test_nvfp4_partial_cast_matches_full() -> None: + """Test multi-GPU partial cast: split master weight, partial cast on each rank, all-gather, compare.""" + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": 
datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + dp_group = dist.new_group(backend="nccl") + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + torch.manual_seed(1234) + device = torch.device("cuda") + # Shape must be divisible by WORLD_SIZE for even splitting + # Also ensure dimensions are multiples of 16 for NVFP4 tiles + shape = (4096, 4096) + total_elements = shape[0] * shape[1] + assert total_elements % WORLD_SIZE == 0, "Total elements must be divisible by WORLD_SIZE" + + # Full master weight (same on all ranks due to same seed) + full_master_weight = torch.randn(shape, dtype=torch.float32, device=device) + + # Create reference using full quantization + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) + reference_tensor = quantizer(full_master_weight.to(torch.bfloat16)) + reference_data = reference_tensor._rowwise_data.detach().clone() + reference_scale = reference_tensor._rowwise_scale_inv.detach().clone() + reference_amax = reference_tensor._amax_rowwise.detach().clone() + + # Split master weight evenly across ranks + shard_size = total_elements // WORLD_SIZE + start_offset = WORLD_RANK * shard_size + end_offset = start_offset + shard_size + master_weight_shard = full_master_weight.view(-1)[start_offset:end_offset].clone() + + # Create empty NVFP4 tensor for this rank (full shape, but we'll only fill our shard) + nvfp4_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) + nvfp4_tensor._rowwise_data.zero_() + nvfp4_tensor._rowwise_scale_inv.zero_() + if nvfp4_tensor._amax_rowwise is not None: + nvfp4_tensor._amax_rowwise.zero_() + + # Partial cast on each rank's shard + quantize_master_weights( + [nvfp4_tensor], + [master_weight_shard], + [start_offset], + dp_group, + ) + + # All-gather the rowwise data (packed FP4 bytes) + # Each rank has the full tensor but only its shard is filled + # We need to all-gather the shards + rowwise_data_flat = nvfp4_tensor._rowwise_data.view(-1) + + # For NVFP4, 2 elements are packed per byte, so byte shard size is shard_size // 2 + byte_shard_size = shard_size // 2 + byte_start = WORLD_RANK * byte_shard_size + byte_end = byte_start + byte_shard_size + my_shard_bytes = rowwise_data_flat[byte_start:byte_end].contiguous() + + # Gather all shards + gathered_shards = [torch.empty_like(my_shard_bytes) for _ in range(WORLD_SIZE)] + dist.all_gather(gathered_shards, my_shard_bytes, group=dp_group) + + # Reconstruct the full rowwise data + gathered_data = torch.cat(gathered_shards, dim=0).view(reference_data.shape) + + # Compare with reference + torch.testing.assert_close( + gathered_data, + reference_data, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Gathered rowwise data does not match reference!", + ) + + # Also verify scale matches (scale should be identical on all ranks after all-reduce) + torch.testing.assert_close( + nvfp4_tensor._rowwise_scale_inv, + reference_scale, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Scale does not match reference!", + ) + + # Verify amax matches + torch.testing.assert_close( + nvfp4_tensor._amax_rowwise, + reference_amax, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Amax does not match reference!", + ) + + +def test_single_gpu_partial_cast_vs_full(): + """ + Single GPU test: compare 
quantize_master_weights (offset=0) vs quantizer(). + This isolates whether the issue is in our manual Python scale computation or elsewhere. + """ + import math + import os + from transformer_engine.pytorch.tensor import NVFP4Quantizer + from transformer_engine.pytorch.tensor.utils import quantize_master_weights + import transformer_engine_torch as tex + + torch.manual_seed(1234) + device = torch.device("cuda") + + # Test with same shape as the optimizer test + shape = (2048, 2048) + + # Create BF16 master weight + master_weight = torch.randn(shape, dtype=torch.bfloat16, device=device) + + # === Reference: Use NVFP4Quantizer directly === + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) + ref = quantizer(master_weight) + ref_data = ref._rowwise_data.clone() + ref_scale = ref._rowwise_scale_inv.clone() + ref_amax = ref._amax_rowwise.clone() + + # === Test: Use quantize_master_weights with offset=0 (full tensor) === + # Create empty NVFP4 tensor + test_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) + test_tensor._rowwise_data.zero_() + test_tensor._rowwise_scale_inv.zero_() + if test_tensor._amax_rowwise is not None: + test_tensor._amax_rowwise.zero_() + + # Create a mock distributed group for single GPU + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", init_method="env://", rank=0, world_size=1) + mock_group = dist.new_group(ranks=[0]) + + quantize_master_weights( + [test_tensor], + [master_weight.view(-1)], # Flatten as expected + [0], # offset=0 means full tensor + mock_group, + ) + + # Compare amax + amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax) + + # Compare scale + scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale) + + # Compare data + data_match = torch.equal(test_tensor._rowwise_data, ref_data) + + if __name__ == "__main__": main() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 4579c51e9f..a105a0343f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -163,7 +163,6 @@ list(APPEND transformer_engine_cuda_sources recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu - recipe/nvfp4.cu comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources @@ -182,6 +181,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu + recipe/nvfp4.cu transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index c0cec8a3b9..cad27a2992 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -309,6 +309,118 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, cudaStream_t stream); +/*! \brief Compute tile-level amax for a partial shard of a 2D tensor. + * + * For NVFP4 2D quantization with 16x16 tiles. Computes the maximum absolute + * value within each tile, but only for elements in [start_offset, start_offset + len) + * of the flattened tensor. 
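For example, a rank whose shard covers flattened elements [start_offset, start_offset + len) accumulates |x| only over the elements it owns, so a tile that lies entirely outside the shard ends up with an amax of zero.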
Used in distributed settings where each rank owns a shard. + * + * \param[in] inp Input tensor (partial shard, high-precision). + * \param[out] amax Output amax buffer [tile_rows, tile_cols], float32. + * \param[in] h Number of rows in the full 2D tensor. + * \param[in] w Number of columns in the full 2D tensor. + * \param[in] amax_stride_h Stride for amax in tile-row dimension. + * \param[in] amax_stride_w Stride for amax in tile-col dimension. + * \param[in] start_offset Starting element offset in the flattened tensor. + * \param[in] block_len Tile dimension (must be 16 for NVFP4 2D). + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, cudaStream_t stream); + +/*! \brief Cast a partial shard of a tensor to NVFP4 using 2D tile-based quantization. + * + * Quantizes elements in [start_offset, start_offset + len) of the flattened tensor + * using precomputed per-tile scales. Each 16x16 tile uses its own scale factor. + * Used in distributed settings where each rank casts its owned shard. + * + * \param[in] inp Input tensor (partial shard, high-precision). + * \param[out] out Output NVFP4 packed tensor (2 values per byte). + * \param[in] scale Per-tile scale factors [tile_rows, tile_cols], float32. + * \param[in] global_scale Global scale factor [1], float32. + * \param[in] h Number of rows in the full 2D tensor. + * \param[in] w Number of columns in the full 2D tensor. + * \param[in] scale_stride_h Stride for scale in tile-row dimension. + * \param[in] scale_stride_w Stride for scale in tile-col dimension. + * \param[in] start_offset Starting element offset in the flattened tensor. + * \param[in] block_len Tile dimension (must be 16 for NVFP4 2D). + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, + const NVTETensor global_scale, size_t h, size_t w, + size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream); + +/*! \brief Expand tile-level scales to row-level scales and convert to FP8 E4M3, used in partial cast. + * + * Each tile row's scale is repeated block_len times in the output. + * + * \param[in] input Input tensor with tile scales [tile_rows, tile_cols], float32. + * \param[out] output Output tensor with expanded scales [rows_padded, tile_cols], uint8 (E4M3). + * \param[in] tile_rows Number of tile rows. + * \param[in] tile_cols Number of tile columns. + * \param[in] rows_padded Padded row count in output. + * \param[in] block_len Block length (typically 16 for NVFP4). + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, + cudaStream_t stream); + +/*! \brief Compute per-block decode scale from block amax and global amax. + * + * Computes: + * global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * per_block_decode_scale = block_amax / fp4_max * global_scale + * + * This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh. + * + * \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32. + * \param[out] scale Output scale tensor [tile_rows, tile_cols], float32. 
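Each entry evaluates to block_amax * 448 / global_amax, since (block_amax / fp4_max) * (fp8_max * fp4_max / global_amax) reduces to block_amax * fp8_max / global_amax with fp8_max = 448, per the formulas above.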
+ * \param[in] global_amax Global amax tensor (single element), float32. Avoids D2H transfer. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, + const NVTETensor global_amax, cudaStream_t stream); + +/*! \brief Fused kernel for NVFP4 scale computation. + * + * Fuses three operations into one kernel: + * 1. Compute per-block decode scales from block amax and global amax + * 2. Copy global amax to target tensor + * 3. Expand tile-level scales to row-level and convert to FP8 E4M3 + * + * Saves 2 kernel launches per parameter. + * + * \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32. + * \param[in] global_amax Global amax tensor [1], float32. + * \param[out] per_block_scale Output per-block scale [tile_rows, tile_cols], float32 (for partial_cast). + * \param[out] target_scale Output scale tensor [rows_padded, tile_cols], uint8 (E4M3). + * \param[out] target_amax Output amax tensor [1], float32 (copy of global_amax). + * \param[in] tile_rows Number of tile rows. + * \param[in] tile_cols Number of tile columns. + * \param[in] rows_padded Total padded rows in output. + * \param[in] block_len Block length (16 for NVFP4). + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, + NVTETensor per_block_scale, NVTETensor target_scale, + NVTETensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream); + +/*! \brief Compute global encode scale from global amax. + * + * Computes: global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * If global_amax <= 0, returns 1.0. + * + * \param[in] global_amax Input global amax tensor [num_params], float32. + * \param[out] global_scale Output global scale tensor [num_params], float32. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 5f9a8fe149..659a48d97d 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -326,6 +326,32 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_in */ void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Transpose NVFP4 packed data. + * + * Unlike FP8, NVFP4 packs two 4-bit values per byte. This function correctly + * handles the nibble repacking during transpose. + * + * \param[in] input Input tensor with packed FP4 data. Shape: [M, K/2] bytes. + * \param[out] output Output tensor with transposed packed data. Shape: [K, M/2] bytes. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_data_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Transpose NVFP4 tile-level scales from rowwise to columnwise format. + * + * Takes rowwise_scale_inv where scales are stored at every 16th row (tile boundaries) + * and produces columnwise_scale_inv where scales are repeated 16 times per tile row. + * Scale values are stored as E4M3 (fp8) in uint8 tensors. + * + * \param[in] input Input tensor with rowwise scales [M_padded, K_tiles], uint8 (E4M3). 
+ * \param[out] output Output tensor with columnwise scales [K_padded, M_tiles], uint8 (E4M3). + * \param[in] M_tiles Number of tiles in M dimension. + * \param[in] K_tiles Number of tiles in K dimension. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles, + size_t K_tiles, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 682d8b53f5..36ce60eaa5 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -5,17 +5,69 @@ ************************************************************************/ #include +#include #include +#include #include "../common.h" +#include "../util/ptx.cuh" #include "../utils.cuh" namespace transformer_engine { namespace nvfp4_recipe { +/* + * --------------------------------------------------------------------------- + * NVFP4 2D PARTIAL-SHARD KERNEL DESIGN + * + * These kernels mirror the FP8 block-scaling helpers but operate on shard-local + * slices and nibble-packed FP4 rowwise buffers. One CUDA block covers a logical + * 16x16 tile (grid = ceil(W/16) x ceil(H/16), blockDim = 256 threads). + * + * 1) Partial Amax (`nvfp4_2d_compute_partial_amax_kernel`) + * - Warps sweep the tile using nested loops, accumulating local maxima only + * for elements in [start_offset, start_offset + len). + * - Shared memory reduces the 8 warp maxima; the block writes a float into + * `amax_ptr[tile_row * stride_h + tile_col * stride_w]`. + * + * Tile/warp mapping (each '#' = elements visited by that warp): + * + * +------------------+ + * |########..........| Warp 0 + * |########..........| Warp 1 + * | ... | + * |########..........| Warp 7 + * +------------------+ + * + * 2) Partial Cast (`nvfp4_2d_partial_cast_kernel`) + * - Stage the tile into shared memory (same pattern as FP8). + * - For each 4-value group, build float2 pairs and call + * `ptx::mul_cvt_fp32_to_fp4_4x`, producing packed FP4 nibbles. + * - Compute a shard-local byte index and update only the owned nibble(s) + * using read-modify-write: + * + * packed_bits = [mw3 | mw2 | mw1 | mw0] + * byte_idx = (ref_elem_idx - start_offset) >> 1 + * if elem_idx % 2 == 0: // low nibble + * byte = (byte & 0xF0) | nibble + * else: // high nibble + * byte = (byte & 0x0F) | (nibble << 4) + * + * Thread coverage inside a tile: + * + * rows: 16 columns: 16 + * Warp 0 -> rows 0-1 lanes sweep cols 0..3, 4..7, ... + * Warp 1 -> rows 2-3 (groups of 4 elements per thread) + * ... 
+ * Warp 7 -> rows 14-15 + * --------------------------------------------------------------------------- + */ + // constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); +constexpr int kTileDim = 16; +constexpr int kThreadsPerBlock = 256; // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, @@ -24,9 +76,804 @@ __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const floa *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; } +template +__global__ void __launch_bounds__(kThreadsPerBlock) + nvfp4_2d_compute_partial_amax_kernel(const IType *input, float *amax_ptr, + const size_t amax_stride_h, const size_t amax_stride_w, + const size_t h, const size_t w, const size_t start_offset, + const size_t len) { + constexpr int kThreadsPerWarp = 32; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + static_assert(kTileDim * kTileDim == kThreadsPerBlock); + + const size_t tile_col = blockIdx.x; + const size_t tile_row = blockIdx.y; + const size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + + __shared__ float smem[kNumWarps]; + float amax = 0.0f; + + size_t r = tile_row * kTileDim + threadIdx.x / kTileDim; + size_t c = tile_col * kTileDim + threadIdx.x % kTileDim; + size_t idx = r * w + c; + if (r < h && c < w && idx >= start_offset && idx < end_offset) { + amax = fabs(static_cast(input_minus_offset[idx])); + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { + float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + + if (threadIdx.x % kThreadsPerWarp == 0) { + smem[threadIdx.x / kThreadsPerWarp] = amax; + } + + __syncthreads(); + + if (threadIdx.x == 0) { + for (int i = 0; i < kNumWarps; ++i) { + float other_amax = smem[i]; + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax; + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + nvfp4_2d_partial_cast_kernel(const IType *input, uint8_t *output, const float *decode_scale_ptr, + const size_t scale_stride_h, const size_t scale_stride_w, + const float *global_scale_ptr, const size_t h, const size_t w, + const size_t start_offset, const size_t len) { + constexpr int kNumOutputElemsPerBank = 4; + constexpr int kThreadsPerWarp = 32; + constexpr int kLoopsPerRow = (kTileDim + kThreadsPerWarp - 1) / kThreadsPerWarp; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + constexpr int kRowsPerWarp = (kTileDim + kNumWarps - 1) / kNumWarps; + + __shared__ float smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; + + const int tile_w = blockIdx.x; + const int tile_h = blockIdx.y; + const size_t shard_end = start_offset + len; + const IType *input_minus_offset = input - start_offset; + + float global_encode_scale = global_scale_ptr[0]; + if (global_encode_scale <= 0.f) { + global_encode_scale = 1.f; + } + const float global_decode_scale = 1.0f / global_encode_scale; + + float tile_decode_scale = decode_scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + tile_decode_scale = static_cast(static_cast(tile_decode_scale)); + constexpr float kFp32Max = 3.402823466e+38F; + float tile_encode_val = + (tile_decode_scale > 0.f) ? 
1.0f / (tile_decode_scale * global_decode_scale) : kFp32Max; + tile_encode_val = fminf(tile_encode_val, kFp32Max); + const float2 scale_vec = make_float2(tile_encode_val, tile_encode_val); + + bool skip_store = true; + for (int i = 0; i < kRowsPerWarp; ++i) { + for (int j = 0; j < kLoopsPerRow; ++j) { + const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j; + if (h_in_smem >= kTileDim || w_in_smem >= kTileDim) { + continue; + } + const int h_in_input = tile_h * kTileDim + h_in_smem; + const int w_in_input = tile_w * kTileDim + w_in_smem; + const size_t idx_in_input = static_cast(h_in_input) * w + w_in_input; + if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && + idx_in_input < shard_end) { + smem[h_in_smem][w_in_smem] = static_cast(input_minus_offset[idx_in_input]); + skip_store = false; + } + } + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { + bool other = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); + skip_store = skip_store && other; + } + skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0); + if (skip_store) { + return; + } + + for (int i = 0; i < kRowsPerWarp; ++i) { + const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int row_in_output = tile_h * kTileDim + row_in_smem; + if (row_in_output >= h) { + continue; + } + const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank; + if (col_in_smem >= kTileDim) { + continue; + } + const int col_in_output = tile_w * kTileDim + col_in_smem; + + float vals[kNumOutputElemsPerBank]; + bool mask[kNumOutputElemsPerBank]; + size_t elem_idx[kNumOutputElemsPerBank]; + bool any_valid = false; + + for (int j = 0; j < kNumOutputElemsPerBank; ++j) { + const int col = col_in_output + j; + const bool in_width = col < w; + const size_t idx = static_cast(row_in_output) * w + col; + elem_idx[j] = idx; + const bool in_shard = in_width && idx >= start_offset && idx < shard_end; + mask[j] = in_shard; + const bool in_tile = (col_in_smem + j) < kTileDim; + const float tile_val = in_tile ? smem[row_in_smem][col_in_smem + j] : 0.0f; + vals[j] = in_shard ? tile_val : 0.0f; + any_valid |= in_shard; + } + + if (!any_valid) { + continue; + } + + const float2 in01 = make_float2(vals[0], vals[1]); + const float2 in23 = make_float2(vals[2], vals[3]); + const auto packed = + transformer_engine::ptx::mul_cvt_fp32_to_fp4_4x(in01, in23, scale_vec, 0); + const uint16_t packed_bits = reinterpret_cast(packed); + + for (int pair = 0; pair < 2; ++pair) { + const int first = pair * 2; + const int second = first + 1; + if (!mask[first] && !mask[second]) { + continue; + } + const size_t ref_idx = mask[first] ? 
elem_idx[first] : elem_idx[second]; + const size_t byte_idx = (ref_idx - start_offset) >> 1; + uint8_t byte = output[byte_idx]; + + if (mask[first]) { + const uint8_t nibble = static_cast((packed_bits >> (4 * first)) & 0xF); + if ((elem_idx[first] & 1u) == 0) { + byte = static_cast((byte & 0xF0u) | nibble); + } else { + byte = static_cast((byte & 0x0Fu) | (nibble << 4)); + } + } + + if (mask[second]) { + const uint8_t nibble = static_cast((packed_bits >> (4 * second)) & 0xF); + if ((elem_idx[second] & 1u) == 0) { + byte = static_cast((byte & 0xF0u) | nibble); + } else { + byte = static_cast((byte & 0x0Fu) | (nibble << 4)); + } + } + + output[byte_idx] = byte; + } + } +} + +void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream) { + NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); + + size_t len = inp.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + inp.dtype(), inp_dtype, + nvfp4_2d_compute_partial_amax_kernel<<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(amax.data.dptr), amax_stride_h, amax_stride_w, h, w, + start_offset, len);) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, + const Tensor global_scale, size_t h, size_t w, size_t scale_stride_h, + size_t scale_stride_w, size_t start_offset, size_t block_len, + cudaStream_t stream) { + NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); + NVTE_CHECK(out.dtype() == DType::kByte, "NVFP4 rowwise data must be uint8."); + + size_t len = inp.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + inp.dtype(), inp_dtype, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + w % kTileDim == 0, kWidthAligned, + nvfp4_2d_partial_cast_kernel + <<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), + reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, + reinterpret_cast(global_scale.data.dptr), h, w, start_offset, len);)) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 TRANSPOSE KERNEL + * + * Unlike FP8, NVFP4 packs two 4-bit values into each byte. 
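In the rowwise layout used here, the packed byte at [m, c] holds logical element [m, 2c] in its low nibble and element [m, 2c+1] in its high nibble.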
A simple byte-wise + * transpose doesn't work because the packing changes: + * - Before transpose: elements [m, 2c] and [m, 2c+1] share a byte + * - After transpose: elements [k, 2*m_packed] and [k, 2*m_packed+1] share a byte + * which were originally [2*m_packed, k] and [2*m_packed+1, k] + * --------------------------------------------------------------------------- + */ + +// Vectorized transpose kernel parameters +constexpr int TRANSPOSE_TILE_DIM = 64; // Logical FP4 elements per tile dimension +constexpr int TRANSPOSE_TILE_PACKED = 32; // TILE_DIM / 2 bytes +constexpr int TRANSPOSE_BLOCK_SIZE = 256; // threads per block + +// Shared memory: store unpacked 4-bit values as bytes for easy transpose +// Size: TILE_DIM x (TILE_DIM + 4) to avoid bank conflicts +constexpr int TRANSPOSE_SHMEM_STRIDE = TRANSPOSE_TILE_DIM + 4; + +/* + * Vectorized transpose kernel with uint2 loads/stores (256 threads) + * Tile: 64x64 logical FP4 = 64x32 packed bytes + */ +__global__ void __launch_bounds__(TRANSPOSE_BLOCK_SIZE) + nvfp4_transpose_kernel(const uint8_t *__restrict__ input, uint8_t *__restrict__ output, + const size_t M, const size_t K) { + const size_t K_packed = K / 2; + const size_t M_packed = M / 2; + + const size_t tile_m_start = blockIdx.x * TRANSPOSE_TILE_DIM; + const size_t tile_k_start = blockIdx.y * TRANSPOSE_TILE_DIM; + + __shared__ uint8_t shmem[TRANSPOSE_TILE_DIM][TRANSPOSE_SHMEM_STRIDE]; + + const int tid = threadIdx.x; + + // Phase 1: Load input tile with VECTORIZED uint2 reads + // 256 threads, each loads 8 bytes (uint2) = 2048 bytes total + // Input tile: [64 rows, 32 cols] = 2048 bytes + { + const int thread_row = tid / 4; // 64 rows, 4 threads per row + const int thread_col = (tid % 4) * 8; // 4 x 8 = 32 bytes per row + + const size_t global_m = tile_m_start + thread_row; + const size_t global_k_packed_base = tile_k_start / 2 + thread_col; + + // Load 8 bytes as uint2 + uint2 loaded = make_uint2(0, 0); + if (global_m < M && global_k_packed_base + 7 < K_packed) { + loaded = *reinterpret_cast(&input[global_m * K_packed + global_k_packed_base]); + } else if (global_m < M) { + // Boundary: scalar loads + uint8_t *bytes = reinterpret_cast(&loaded); +#pragma unroll + for (int b = 0; b < 8; ++b) { + size_t col = global_k_packed_base + b; + bytes[b] = (col < K_packed) ? 
input[global_m * K_packed + col] : 0; + } + } + + // Unpack 8 bytes -> 16 nibbles and store to shared memory + const uint8_t *bytes = reinterpret_cast(&loaded); +#pragma unroll + for (int b = 0; b < 8; ++b) { + const int k0 = thread_col * 2 + b * 2; + const int k1 = k0 + 1; + shmem[thread_row][k0] = bytes[b] & 0x0F; + shmem[thread_row][k1] = (bytes[b] >> 4) & 0x0F; + } + } + + __syncthreads(); + + // Phase 2: Write output with VECTORIZED uint2 stores + // Output tile: [64 rows, 32 cols] = 2048 bytes + { + const int thread_row = tid / 4; // output K dimension [0, 64) + const int thread_col_base = (tid % 4) * 8; // output M_packed [0, 32) in steps of 8 + + const size_t global_k = tile_k_start + thread_row; + const size_t global_m_packed_base = tile_m_start / 2 + thread_col_base; + + if (global_k >= K) return; + + // Build 8 output bytes in registers + uint8_t out_bytes[8]; + +#pragma unroll + for (int b = 0; b < 8; ++b) { + const int out_m_packed = thread_col_base + b; + + if (global_m_packed_base + b >= M_packed) { + out_bytes[b] = 0; + continue; + } + + // Two M positions that pack into this output byte + const int m0 = out_m_packed * 2; + const int m1 = out_m_packed * 2 + 1; + const int k = thread_row; + + // Read from shared memory (transposed access) + const uint8_t val0 = shmem[m0][k]; + const uint8_t val1 = shmem[m1][k]; + + out_bytes[b] = val0 | (val1 << 4); + } + + // Vectorized store as uint2 + if (global_m_packed_base + 7 < M_packed) { + *reinterpret_cast(&output[global_k * M_packed + global_m_packed_base]) = + *reinterpret_cast(out_bytes); + } else { + // Boundary: scalar stores + for (int b = 0; b < 8 && global_m_packed_base + b < M_packed; ++b) { + output[global_k * M_packed + global_m_packed_base + b] = out_bytes[b]; + } + } + } +} + +void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { + // Input has logical shape [M, K], stored as [M, K/2] bytes + // Output has logical shape [K, M], stored as [K, M/2] bytes + + NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 transpose input must be uint8."); + NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 transpose output must be uint8."); + + // Get dimensions from packed storage + // input.shape() = [M, K/2], so M = shape[0], K = shape[1] * 2 + const auto in_shape = input.shape(); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 transpose expects 2D input (packed), got ", + in_shape.size(), "D."); + const size_t M = in_shape[0]; + const size_t K_packed = in_shape[1]; + const size_t K = K_packed * 2; + + // Output should be [K, M/2] + const size_t M_packed = M / 2; + NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even."); + + const auto out_shape = output.shape(); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 transpose expects 2D output."); + NVTE_CHECK(out_shape[0] == K && out_shape[1] == M_packed, + "NVFP4 transpose output shape mismatch. 
Expected [", K, ", ", M_packed, "], got [", + out_shape[0], ", ", out_shape[1], "]."); + + if (M == 0 || K == 0) return; + + // Use vectorized kernel (faster than TMA for pure transpose) + // 128x128 tiles with 512 threads and uint4 vectorized access + dim3 block(TRANSPOSE_BLOCK_SIZE); + dim3 grid((M + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM, + (K + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM); + + nvfp4_transpose_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), M, K); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 SCALE TRANSPOSE KERNEL + * + * Transposes tile-level scales from rowwise to columnwise format. + * Scale values are stored as E4M3 (fp8) in uint8 tensors. + * + * Input (rowwise_scale_inv): [M_padded, K_tiles] where scales are stored + * at every 16th row (i.e., row 0, 16, 32, ... contain the actual scales, + * and each row i within a tile block has the same scale as row (i // 16) * 16). + * + * Output (columnwise_scale_inv): [K_padded, M_tiles] where scales are + * repeated 16 times per tile row. + * + * Mapping: + * output[k_tile * 16 + i, m_tile] = input[m_tile * 16, k_tile] + * for i in [0, 16) and valid (k_tile, m_tile) indices. + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_scale_transpose_kernel( + const uint8_t *__restrict__ input, // [M_padded, K_tiles], E4M3 stored as uint8 + uint8_t *__restrict__ output, // [K_padded, M_tiles], E4M3 stored as uint8 + const size_t M_tiles, // Number of M tiles + const size_t K_tiles, // Number of K tiles + const size_t input_stride, // K_tiles (input row stride) + const size_t output_stride, // M_tiles (output row stride) + const size_t K_padded // Output height +) { + // Each thread handles one output element + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= K_padded || out_col >= M_tiles) return; + + // Determine which tile row this belongs to + const size_t k_tile = out_row / kTileDim; + + // Read from input: row = m_tile * 16 (first row of the tile), col = k_tile + // m_tile = out_col + if (k_tile < K_tiles) { + const size_t in_row = out_col * kTileDim; // m_tile * 16 + const uint8_t scale = input[in_row * input_stride + k_tile]; + output[out_row * output_stride + out_col] = scale; + } else { + output[out_row * output_stride + out_col] = 0; + } +} + +void nvfp4_scale_transpose(const Tensor input, Tensor output, size_t M_tiles, size_t K_tiles, + cudaStream_t stream) { + NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); + NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 scale transpose output must be uint8 (E4M3)."); + + const auto in_shape = input.shape(); + const auto out_shape = output.shape(); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); + + const size_t input_stride = in_shape[1]; // K_tiles + const size_t output_stride = out_shape[1]; // M_tiles + const size_t K_padded = out_shape[0]; + + if (M_tiles == 0 || K_tiles == 0 || K_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((M_tiles + kBlockDim - 1) / kBlockDim, (K_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_scale_transpose_kernel<<>>( + reinterpret_cast(input.data.dptr), 
+ reinterpret_cast(output.data.dptr), M_tiles, K_tiles, input_stride, output_stride, + K_padded); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 SCALE EXPANSION KERNEL + * + * Expands tile-level scales to row-level scales and converts to FP8 E4M3, used in partial cast. + * + * Input (per_block_decode_scale): [tile_rows, tile_cols] in float32 + * Output (target_scale): [rows_padded, tile_cols] in uint8 (E4M3) + * + * Each tile row's scale is repeated block_len times in the output. + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_expand_scale_to_fp8_kernel( + const float *__restrict__ input, // [tile_rows, tile_cols] + uint8_t *__restrict__ output, // [rows_padded, tile_cols] + const size_t tile_rows, const size_t tile_cols, const size_t rows_padded, + const size_t block_len) { + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= rows_padded || out_col >= tile_cols) return; + + // Determine which tile row this output row belongs to + const size_t tile_row = out_row / block_len; + + float scale_val = 0.0f; + if (tile_row < tile_rows) { + scale_val = input[tile_row * tile_cols + out_col]; + } + + // Convert float32 to FP8 E4M3 + // Clamp to FP8 E4M3 range and convert + fp8e4m3 fp8_val = static_cast(scale_val); + output[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); +} + +void nvfp4_expand_scale_to_fp8(const Tensor input, Tensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, + cudaStream_t stream) { + NVTE_CHECK(input.dtype() == DType::kFloat32, "Scale input must be float32."); + NVTE_CHECK(output.dtype() == DType::kByte, "Scale output must be uint8 (E4M3)."); + + if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, (rows_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_expand_scale_to_fp8_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), tile_rows, tile_cols, rows_padded, block_len); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 COMPUTE PER-BLOCK DECODE SCALE KERNEL + * + * Computes per-block decode scale from block amax and global amax: + * global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * per_block_decode_scale = block_amax / fp4_max * global_scale + * = block_amax * 448 / global_amax + * + * This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh + * + * Input (block_amax): [tile_rows, tile_cols] in float32 + * Input (global_amax): scalar float32 (per-tensor amax after all-reduce) + * Output (scale): [tile_rows, tile_cols] in float32 + * Output (global_scale_out): scalar float32 (the computed global encode scale) + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_compute_per_block_scale_kernel( + const float *__restrict__ block_amax, // [tile_rows, tile_cols] + float *__restrict__ scale, // [tile_rows, tile_cols] + const float *__restrict__ global_amax_ptr, // Pointer to single float value (avoids D2H) + const size_t numel) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= numel) return; + + constexpr 
float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + // Read global_amax from device memory (avoids D2H transfer) + float global_amax = *global_amax_ptr; + + // Compute global encode scale: S_enc = (fp8_max * fp4_max) / global_amax + float safe_global_amax = fmaxf(global_amax, tiny); + float global_scale = + (global_amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; + + // Compute per-block decode scale: S_dec_b = block_amax / fp4_max * S_enc + float amax_val = block_amax[idx]; + float result = fminf((amax_val / fp4_max) * global_scale, flt_max); + scale[idx] = result; +} + +// Simple kernel to compute global encode scale from global amax +__global__ void nvfp4_compute_global_scale_kernel( + const float *__restrict__ global_amax, // [num_params] + float *__restrict__ global_scale, // [num_params] + const size_t num_params) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_params) return; + + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + float amax = global_amax[idx]; + float safe_amax = fmaxf(amax, tiny); + float scale = (amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_amax, flt_max) : 1.0f; + global_scale[idx] = scale; +} + +void nvfp4_compute_per_block_scale(const Tensor block_amax, Tensor scale, const Tensor global_amax, + cudaStream_t stream) { + NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); + NVTE_CHECK(scale.dtype() == DType::kFloat32, "Scale must be float32."); + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + + size_t numel = block_amax.numel(); + if (numel == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (numel + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_per_block_scale_kernel<<>>( + reinterpret_cast(block_amax.data.dptr), + reinterpret_cast(scale.data.dptr), + reinterpret_cast(global_amax.data.dptr), numel); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvfp4_compute_global_scale(const Tensor global_amax, Tensor global_scale, + cudaStream_t stream) { + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_scale.dtype() == DType::kFloat32, "Global scale must be float32."); + + size_t num_params = global_amax.numel(); + if (num_params == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (num_params + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_global_scale_kernel<<>>( + reinterpret_cast(global_amax.data.dptr), + reinterpret_cast(global_scale.data.dptr), num_params); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * FUSED NVFP4 SCALE COMPUTATION KERNEL + * + * Fuses three operations into one kernel: + * 1. nvfp4_compute_per_block_scale: compute tile-level decode scales from block amax + * 2. target_amax.copy_: copy global amax to target tensor + * 3. 
nvfp4_expand_scale_to_fp8: expand to row-level and convert to FP8 E4M3 + * + * Input (block_amax): [tile_rows, tile_cols] float32 + * Input (global_amax): [1] float32 + * Output (per_block_scale): [tile_rows, tile_cols] float32 (intermediate, for partial_cast) + * Output (target_scale): [rows_padded, tile_cols] uint8 (E4M3) + * Output (target_amax): [1] float32 (copy of global_amax) + * + * Saves 2 kernel launches per parameter (eliminates nvfp4_compute_per_block_scale and + * nvfp4_expand_scale_to_fp8 as separate calls, plus the amax copy). + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_fused_scale_kernel( + const float *__restrict__ block_amax, // [tile_rows, tile_cols] + const float *__restrict__ global_amax, // [1] + float *__restrict__ per_block_scale, // [tile_rows, tile_cols] - for partial_cast + uint8_t *__restrict__ target_scale, // [rows_padded, tile_cols] + float *__restrict__ target_amax, // [1] + const size_t tile_rows, const size_t tile_cols, const size_t rows_padded, + const size_t block_len) { + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + // Read global amax once per thread (broadcast) + const float g_amax = *global_amax; + + // Thread (0,0) copies global_amax to target_amax + if (out_row == 0 && out_col == 0) { + *target_amax = g_amax; + } + + if (out_row >= rows_padded || out_col >= tile_cols) return; + + // Determine which tile row this output row belongs to + const size_t tile_row = out_row / block_len; + + // Compute the scale value + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; + + float scale_val = 0.0f; + if (tile_row < tile_rows) { + float safe_global_amax = fmaxf(g_amax, tiny); + float global_scale = + (g_amax > 0.0f) ? 
fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f;
+
+    // Read block amax and compute per-block decode scale
+    float amax_val = block_amax[tile_row * tile_cols + out_col];
+    scale_val = fminf((amax_val / fp4_max) * global_scale, flt_max);
+
+    // Write per-block scale (only once per tile, when out_row % block_len == 0)
+    if (out_row % block_len == 0) {
+      per_block_scale[tile_row * tile_cols + out_col] = scale_val;
+    }
+  }
+
+  // Convert float32 to FP8 E4M3 and write expanded scale
+  fp8e4m3 fp8_val = static_cast<fp8e4m3>(scale_val);
+  target_scale[out_row * tile_cols + out_col] = reinterpret_cast<uint8_t &>(fp8_val);
+}
+
+void nvfp4_fused_scale(const Tensor block_amax, const Tensor global_amax, Tensor per_block_scale,
+                       Tensor target_scale, Tensor target_amax, size_t tile_rows, size_t tile_cols,
+                       size_t rows_padded, size_t block_len, cudaStream_t stream) {
+  NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32.");
+  NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32.");
+  NVTE_CHECK(per_block_scale.dtype() == DType::kFloat32, "Per-block scale must be float32.");
+  NVTE_CHECK(target_scale.dtype() == DType::kByte, "Target scale must be uint8 (E4M3).");
+  NVTE_CHECK(target_amax.dtype() == DType::kFloat32, "Target amax must be float32.");
+  NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor.");
+  NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor.");
+
+  if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return;
+
+  constexpr int kBlockDim = 16;
+  dim3 block(kBlockDim, kBlockDim);
+  dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, (rows_padded + kBlockDim - 1) / kBlockDim);
+
+  nvfp4_fused_scale_kernel<<<grid, block, 0, stream>>>(
+      reinterpret_cast<const float *>(block_amax.data.dptr),
+      reinterpret_cast<const float *>(global_amax.data.dptr),
+      reinterpret_cast<float *>(per_block_scale.data.dptr),
+      reinterpret_cast<uint8_t *>(target_scale.data.dptr),
+      reinterpret_cast<float *>(target_amax.data.dptr), tile_rows, tile_cols, rows_padded,
+      block_len);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
 }  // namespace nvfp4_recipe
 }  // namespace transformer_engine
+void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows,
+                                    size_t tile_cols, size_t rows_padded, size_t block_len,
+                                    cudaStream_t stream) {
+  NVTE_API_CALL(nvte_nvfp4_expand_scale_to_fp8);
+  using namespace transformer_engine;
+  nvfp4_recipe::nvfp4_expand_scale_to_fp8(*convertNVTETensorCheck(input),
+                                          *convertNVTETensorCheck(output), tile_rows, tile_cols,
+                                          rows_padded, block_len, stream);
+}
+
+void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale,
+                                        const NVTETensor global_amax, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_nvfp4_compute_per_block_scale);
+  using namespace transformer_engine;
+  nvfp4_recipe::nvfp4_compute_per_block_scale(*convertNVTETensorCheck(block_amax),
+                                              *convertNVTETensorCheck(scale),
+                                              *convertNVTETensorCheck(global_amax), stream);
+}
+
+void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale,
+                                     cudaStream_t stream) {
+  NVTE_API_CALL(nvte_nvfp4_compute_global_scale);
+  using namespace transformer_engine;
+  nvfp4_recipe::nvfp4_compute_global_scale(*convertNVTETensorCheck(global_amax),
+                                           *convertNVTETensorCheck(global_scale), stream);
+}
+
+void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles,
+                                size_t K_tiles, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_nvfp4_scale_transpose);
+  using namespace transformer_engine;
nvfp4_recipe::nvfp4_scale_transpose(*convertNVTETensorCheck(input), + *convertNVTETensorCheck(output), M_tiles, K_tiles, stream); +} + +void nvte_nvfp4_data_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_data_transpose); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + stream); +} + +void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_2d_compute_partial_amax); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_2d_compute_partial_amax(*convertNVTETensorCheck(inp), + *convertNVTETensorCheck(amax), h, w, amax_stride_h, + amax_stride_w, start_offset, block_len, stream); +} + +void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, + const NVTETensor global_scale, size_t h, size_t w, + size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_2d_partial_cast); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_2d_partial_cast(*convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), + *convertNVTETensorCheck(scale), + *convertNVTETensorCheck(global_scale), h, w, scale_stride_h, + scale_stride_w, start_offset, block_len, stream); +} + void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, @@ -52,3 +899,15 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); NVTE_CHECK_CUDA(cudaGetLastError()); } + +void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, + NVTETensor per_block_scale, NVTETensor target_scale, + NVTETensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_fused_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_fused_scale( + *convertNVTETensorCheck(block_amax), *convertNVTETensorCheck(global_amax), + *convertNVTETensorCheck(per_block_scale), *convertNVTETensorCheck(target_scale), + *convertNVTETensorCheck(target_amax), tile_rows, tile_cols, rows_padded, block_len, stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0e91071983..3d0aa7fba4 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -157,6 +157,39 @@ std::optional> te_general_grouped_gemm( at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output = std::nullopt); +at::Tensor nvfp4_data_transpose(at::Tensor input, std::optional output = std::nullopt); + +void nvfp4_2d_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, + int64_t K_tiles); + +void nvfp4_2d_multi_tensor_transpose(std::vector rowwise_data_list, + std::vector columnwise_data_list, + std::vector rowwise_scale_inv_list, + std::vector columnwise_scale_inv_list, + std::vector M_list, std::vector K_list); + +void nvfp4_multi_tensor_compute_partial_amax( + std::vector master_weight_list, std::vector partial_amax_list, + std::vector global_amax_list, std::vector h_list, + std::vector w_list, 
std::vector start_offset_list, int64_t block_len); + +void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len); + +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::Tensor global_amax); + +void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale, + at::Tensor target_scale, at::Tensor target_amax, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len); + +void nvfp4_multi_tensor_fused_scale( + std::vector block_amax_list, std::vector global_amax_list, + std::vector per_block_scale_list, std::vector target_scale_list, + std::vector target_amax_list, std::vector tile_rows_list, + std::vector tile_cols_list, std::vector rows_padded_list, int64_t block_len); + +void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale); + at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); /*************************************************************************************************** @@ -342,6 +375,19 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const size_t h, size_t w, size_t start_offset, size_t block_len, const DType out_dtype); +void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, size_t w, + size_t start_offset, size_t block_len); + +void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, + const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, + size_t block_len); + +void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, std::vector w_list, + std::vector start_offset_list, int64_t block_len); void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, at::Tensor amax_colwise, int rows, int cols, size_t start_offset); diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp new file mode 100644 index 0000000000..6e1eb4483c --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp @@ -0,0 +1,157 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +void nvfp4_2d_compute_partial_amax(const at::Tensor& tensor, at::Tensor amax, size_t h, size_t w, + size_t start_offset, size_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); + TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor"); + TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float || + tensor.scalar_type() == at::ScalarType::BFloat16, + "tensor must be a float or bfloat16 tensor"); + + const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor.contiguous()); + TensorWrapper amax_cu = makeTransformerEngineTensor(amax); + + nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, amax.stride(0), + amax.stride(1), start_offset, block_len, + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_2d_partial_cast(const at::Tensor& inp, py::handle out, const at::Tensor& scale, + const at::Tensor& global_scale, size_t h, size_t w, size_t start_offset, + size_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); + TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); + TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); + TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, + "global_scale must be a float tensor"); + TORCH_CHECK( + inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); + + const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); + const TensorWrapper out_cu = makeTransformerEngineTensor(out, py::none()); + const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); + const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); + + nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), global_scale_cu.data(), + h, w, scale.stride(0), scale.stride(1), start_offset, block_len, + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, std::vector w_list, + std::vector start_offset_list, int64_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + + const size_t num_tensors = inp_list.size(); + TORCH_CHECK(out_list.size() == num_tensors, "out_list size mismatch"); + TORCH_CHECK(scale_list.size() == num_tensors, "scale_list size mismatch"); + TORCH_CHECK(global_scale_list.size() == num_tensors, "global_scale_list size mismatch"); + TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch"); + TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch"); + TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& inp = inp_list[i]; + const auto& out = out_list[i]; + const auto& scale = scale_list[i]; + const auto& global_scale = global_scale_list[i]; + const size_t h = static_cast(h_list[i]); + const size_t w = 
static_cast<size_t>(w_list[i]);
+    const size_t start_offset = static_cast<size_t>(start_offset_list[i]);
+
+    TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
+    TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
+    TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor");
+    TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
+                "global_scale must be a float tensor");
+    TORCH_CHECK(
+        inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16,
+        "input must be a float or bfloat16 tensor");
+
+    const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous());
+    const TensorWrapper out_cu = makeTransformerEngineTensor(out);
+    const TensorWrapper scale_cu = makeTransformerEngineTensor(scale);
+    const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale);
+
+    nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(),
+                               global_scale_cu.data(), h, w, scale.stride(0), scale.stride(1),
+                               start_offset, static_cast<size_t>(block_len), stream);
+  }
+}
+
+void nvfp4_multi_tensor_compute_partial_amax(
+    std::vector<at::Tensor> master_weight_list, std::vector<at::Tensor> partial_amax_list,
+    std::vector<at::Tensor> global_amax_list, std::vector<int64_t> h_list,
+    std::vector<int64_t> w_list, std::vector<int64_t> start_offset_list, int64_t block_len) {
+  TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D");
+
+  const size_t num_tensors = master_weight_list.size();
+  TORCH_CHECK(partial_amax_list.size() == num_tensors, "partial_amax_list size mismatch");
+  TORCH_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch");
+  TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch");
+  TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch");
+  TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch");
+
+  if (num_tensors == 0) {
+    return;
+  }
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  for (size_t i = 0; i < num_tensors; ++i) {
+    const auto& master_weight = master_weight_list[i];
+    auto& partial_amax = partial_amax_list[i];
+    auto& global_amax = global_amax_list[i];
+    const size_t h = static_cast<size_t>(h_list[i]);
+    const size_t w = static_cast<size_t>(w_list[i]);
+    const size_t start_offset = static_cast<size_t>(start_offset_list[i]);
+
+    TORCH_CHECK(partial_amax.dim() == 2, "partial_amax must be a 2D tensor");
+    TORCH_CHECK(partial_amax.scalar_type() == at::ScalarType::Float,
+                "partial_amax must be a float tensor");
+    TORCH_CHECK(master_weight.scalar_type() == at::ScalarType::Float ||
+                    master_weight.scalar_type() == at::ScalarType::BFloat16,
+                "master_weight must be a float or bfloat16 tensor");
+    TORCH_CHECK(global_amax.scalar_type() == at::ScalarType::Float,
+                "global_amax must be a float tensor");
+    TORCH_CHECK(global_amax.numel() == 1, "global_amax must have exactly one element");
+
+    // Compute partial amax (per-block amax)
+    const TensorWrapper tensor_cu = makeTransformerEngineTensor(master_weight.contiguous());
+    TensorWrapper amax_cu = makeTransformerEngineTensor(partial_amax);
+
+    nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w,
+                                       partial_amax.stride(0), partial_amax.stride(1), start_offset,
+                                       static_cast<size_t>(block_len), stream);
+
+    // Compute global amax
+    auto* global_amax_ptr = global_amax.data_ptr<float>();
+    TensorWrapper fake_te_output(
+        /*dptr=*/nullptr, tensor_cu.shape(), DType::kFloat32, global_amax_ptr);
+
+    nvte_compute_amax(tensor_cu.data(), fake_te_output.data(), stream);
+  }
+}
+
+}  // namespace
transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 14f32c7b93..5675397644 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -260,6 +260,44 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("nvfp4_data_transpose", &transformer_engine::pytorch::nvfp4_data_transpose, + "Transpose NVFP4 packed data with nibble repacking", py::arg("input"), py::kw_only(), + py::arg("out"), py::call_guard()); + m.def( + "nvfp4_2d_scale_transpose", &transformer_engine::pytorch::nvfp4_2d_scale_transpose, + "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format", + py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"), + py::call_guard()); + m.def("nvfp4_expand_scale_to_fp8", &transformer_engine::pytorch::nvfp4_expand_scale_to_fp8, + "Expand tile-level scales to row-level scales and convert to FP8 E4M3", py::arg("input"), + py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), + py::arg("block_len"), py::call_guard()); + m.def("nvfp4_compute_per_block_scale", + &transformer_engine::pytorch::nvfp4_compute_per_block_scale, + "Compute per-block decode scale from block amax and global amax", py::arg("block_amax"), + py::arg("scale"), py::arg("global_amax"), py::call_guard()); + m.def("nvfp4_compute_global_scale", &transformer_engine::pytorch::nvfp4_compute_global_scale, + "Compute global encode scale from global amax", py::arg("global_amax"), + py::arg("global_scale"), py::call_guard()); + m.def("nvfp4_fused_scale", &transformer_engine::pytorch::nvfp4_fused_scale, + "Fused kernel: compute per-block decode scale, copy global amax, expand to row-level FP8", + py::arg("block_amax"), py::arg("global_amax"), py::arg("per_block_scale"), + py::arg("target_scale"), py::arg("target_amax"), py::arg("tile_rows"), py::arg("tile_cols"), + py::arg("rows_padded"), py::arg("block_len"), py::call_guard()); + m.def("nvfp4_multi_tensor_fused_scale", + &transformer_engine::pytorch::nvfp4_multi_tensor_fused_scale, + "Batched fused scale: compute per-block decode scale, copy global amax, expand to FP8 for " + "multiple tensors", + py::arg("block_amax_list"), py::arg("global_amax_list"), py::arg("per_block_scale_list"), + py::arg("target_scale_list"), py::arg("target_amax_list"), py::arg("tile_rows_list"), + py::arg("tile_cols_list"), py::arg("rows_padded_list"), py::arg("block_len"), + py::call_guard()); + m.def("nvfp4_2d_multi_tensor_transpose", + &transformer_engine::pytorch::nvfp4_2d_multi_tensor_transpose, + "Batched NVFP4 columnwise creation: transpose data and scales for multiple tensors", + py::arg("rowwise_data_list"), py::arg("columnwise_data_list"), + py::arg("rowwise_scale_inv_list"), py::arg("columnwise_scale_inv_list"), py::arg("M_list"), + py::arg("K_list"), py::call_guard()); m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), py::call_guard()); @@ -282,6 +320,29 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"), py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), 
py::arg("block_len"), py::arg("out_dtype"), py::call_guard()); + // NVFP4 2D + m.def("nvfp4_2d_compute_partial_amax", + &transformer_engine::pytorch::nvfp4_2d_compute_partial_amax, + "Compute partial amax from master weights for NVFP4 2D", py::arg("tensor"), py::arg("amax"), + py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len") = 16, + py::call_guard()); + m.def("nvfp4_multi_tensor_compute_partial_amax", + &transformer_engine::pytorch::nvfp4_multi_tensor_compute_partial_amax, + "Batched compute partial and global amax from master weights for NVFP4 2D", + py::arg("master_weight_list"), py::arg("partial_amax_list"), py::arg("global_amax_list"), + py::arg("h_list"), py::arg("w_list"), py::arg("start_offset_list"), + py::arg("block_len") = 16, py::call_guard()); + m.def("nvfp4_2d_partial_cast", &transformer_engine::pytorch::nvfp4_2d_partial_cast, + "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"), + py::arg("scale"), py::arg("global_scale"), py::arg("h"), py::arg("w"), + py::arg("start_offset"), py::arg("block_len") = 16, + py::call_guard()); + m.def("nvfp4_multi_tensor_2d_partial_cast", + &transformer_engine::pytorch::nvfp4_multi_tensor_2d_partial_cast, + "Batched partial cast from master weights for NVFP4 2D", py::arg("inp_list"), + py::arg("out_list"), py::arg("scale_list"), py::arg("global_scale_list"), py::arg("h_list"), + py::arg("w_list"), py::arg("start_offset_list"), py::arg("block_len") = 16, + py::call_guard()); m.def("mxfp8_scaling_compute_partial_amax", &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax, "Compute partial amax from master weights for fp8 mxfp8 scaling", py::arg("input"), diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 477d7c87e7..aaa27a104a 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -5,6 +5,8 @@ ************************************************************************/ #include +#include +#include #include #include @@ -52,11 +54,218 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { + init_extension(); + + // Input is packed FP4: logical [M, K] stored as [M, K/2] bytes + // Output is packed FP4: logical [K, M] stored as [K, M/2] bytes + const auto shape = getTensorShape(input); + NVTE_CHECK(shape.size() == 2, "NVFP4 transpose expects 2D input (packed storage)."); + + const size_t M = shape[0]; + const size_t K_packed = shape[1]; + const size_t K = K_packed * 2; // logical K + const size_t M_packed = M / 2; + + NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even."); + + // Output shape: [K, M/2] + std::vector output_shape = {static_cast(K), static_cast(M_packed)}; + + // Output tensor + at::Tensor out; + if (output.has_value()) { + out = *output; + NVTE_CHECK( + static_cast(out.size(0)) == K && static_cast(out.size(1)) == M_packed, + "Output shape mismatch for NVFP4 transpose."); + } else { + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + out = at::empty(output_shape, opts); + } + + // Return immediately if tensor is empty + if (M == 0 || K == 0) { + return out; + } + + // Call the NVFP4 transpose kernel + auto input_cu = + makeTransformerEngineTensor(input.data_ptr(), std::vector{M, K_packed}, DType::kByte); + auto output_cu = + makeTransformerEngineTensor(out.data_ptr(), std::vector{K, M_packed}, DType::kByte); + 
nvte_nvfp4_data_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return out; +} + +void nvfp4_2d_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, + int64_t K_tiles) { + init_extension(); + + // Input: rowwise_scale_inv [M_padded, K_tiles], uint8 (E4M3 stored as bytes) + // Output: columnwise_scale_inv [K_padded, M_tiles], uint8 (E4M3 stored as bytes) + const auto in_shape = getTensorShape(input); + const auto out_shape = getTensorShape(output); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); + NVTE_CHECK(input.scalar_type() == at::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); + NVTE_CHECK(output.scalar_type() == at::kByte, + "NVFP4 scale transpose output must be uint8 (E4M3)."); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kByte); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), std::vector{out_shape[0], out_shape[1]}, DType::kByte); + + nvte_nvfp4_scale_transpose(input_cu.data(), output_cu.data(), static_cast(M_tiles), + static_cast(K_tiles), at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len) { + init_extension(); + + // Input: per_block_decode_scale [tile_rows, tile_cols], float32 + // Output: target_scale [rows_padded, tile_cols], uint8 (E4M3) + const auto in_shape = getTensorShape(input); + const auto out_shape = getTensorShape(output); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 expand scale expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 expand scale expects 2D output."); + NVTE_CHECK(input.scalar_type() == at::kFloat, "NVFP4 expand scale input must be float32."); + NVTE_CHECK(output.scalar_type() == at::kByte, "NVFP4 expand scale output must be uint8 (E4M3)."); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kFloat32); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), std::vector{out_shape[0], out_shape[1]}, DType::kByte); + + nvte_nvfp4_expand_scale_to_fp8(input_cu.data(), output_cu.data(), static_cast(tile_rows), + static_cast(tile_cols), static_cast(rows_padded), + static_cast(block_len), at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, + at::Tensor global_amax) { + init_extension(); + + // block_amax and scale: [tile_rows, tile_cols], float32 + // global_amax: single element tensor, float32 (avoids D2H transfer) + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(scale.scalar_type() == at::kFloat, "Scale must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto scale_cu = makeTransformerEngineTensor(scale); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + + nvte_nvfp4_compute_per_block_scale(block_amax_cu.data(), scale_cu.data(), global_amax_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale, + at::Tensor target_scale, at::Tensor target_amax, 
int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len) { + init_extension(); + + // block_amax: [tile_rows, tile_cols], float32 + // global_amax: [1], float32 + // per_block_scale: [tile_rows, tile_cols], float32 (for partial_cast) + // target_scale: [rows_padded, tile_cols], uint8 (E4M3) + // target_amax: [1], float32 + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); + NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); + auto target_scale_cu = makeTransformerEngineTensor(target_scale); + auto target_amax_cu = makeTransformerEngineTensor(target_amax); + + nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), per_block_scale_cu.data(), + target_scale_cu.data(), target_amax_cu.data(), + static_cast(tile_rows), static_cast(tile_cols), + static_cast(rows_padded), static_cast(block_len), + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_multi_tensor_fused_scale( + std::vector block_amax_list, std::vector global_amax_list, + std::vector per_block_scale_list, std::vector target_scale_list, + std::vector target_amax_list, std::vector tile_rows_list, + std::vector tile_cols_list, std::vector rows_padded_list, int64_t block_len) { + init_extension(); + + const size_t num_tensors = block_amax_list.size(); + NVTE_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch"); + NVTE_CHECK(per_block_scale_list.size() == num_tensors, "per_block_scale_list size mismatch"); + NVTE_CHECK(target_scale_list.size() == num_tensors, "target_scale_list size mismatch"); + NVTE_CHECK(target_amax_list.size() == num_tensors, "target_amax_list size mismatch"); + NVTE_CHECK(tile_rows_list.size() == num_tensors, "tile_rows_list size mismatch"); + NVTE_CHECK(tile_cols_list.size() == num_tensors, "tile_cols_list size mismatch"); + NVTE_CHECK(rows_padded_list.size() == num_tensors, "rows_padded_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& block_amax = block_amax_list[i]; + const auto& global_amax = global_amax_list[i]; + auto& per_block_scale = per_block_scale_list[i]; + auto& target_scale = target_scale_list[i]; + auto& target_amax = target_amax_list[i]; + const size_t tile_rows = static_cast(tile_rows_list[i]); + const size_t tile_cols = static_cast(tile_cols_list[i]); + const size_t rows_padded = static_cast(rows_padded_list[i]); + + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); + 
NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); + auto target_scale_cu = makeTransformerEngineTensor(target_scale); + auto target_amax_cu = makeTransformerEngineTensor(target_amax); + + nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), per_block_scale_cu.data(), + target_scale_cu.data(), target_amax_cu.data(), tile_rows, tile_cols, + rows_padded, static_cast(block_len), stream); + } +} + +void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale) { + init_extension(); + + // global_amax and global_scale: [num_params], float32 + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(global_scale.scalar_type() == at::kFloat, "Global scale must be float32."); + + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto global_scale_cu = makeTransformerEngineTensor(global_scale); + + nvte_nvfp4_compute_global_scale(global_amax_cu.data(), global_scale_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { init_extension(); // Make sure input is contiguous - const auto &input = tensor.contiguous(); + const auto& input = tensor.contiguous(); // Allocate output tensor if needed if (!out) { @@ -77,5 +286,70 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { return std::move(*out); } +void nvfp4_2d_multi_tensor_transpose(std::vector rowwise_data_list, + std::vector columnwise_data_list, + std::vector rowwise_scale_inv_list, + std::vector columnwise_scale_inv_list, + std::vector M_list, std::vector K_list) { + init_extension(); + + const size_t num_tensors = rowwise_data_list.size(); + NVTE_CHECK(columnwise_data_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(rowwise_scale_inv_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(columnwise_scale_inv_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(M_list.size() == num_tensors, "M_list size mismatch"); + NVTE_CHECK(K_list.size() == num_tensors, "K_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + // Process each tensor - the main benefit is reduced Python overhead + // by doing the iteration in C++ rather than Python + constexpr size_t TILE_SIZE = 16; + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& rowwise_data = rowwise_data_list[i]; + auto& columnwise_data = columnwise_data_list[i]; + const auto& rowwise_scale_inv = rowwise_scale_inv_list[i]; + auto& columnwise_scale_inv = columnwise_scale_inv_list[i]; + const int64_t M = M_list[i]; + const int64_t K = K_list[i]; + + // Transpose data: [M, K/2] -> [K, M/2] + const auto data_shape = getTensorShape(rowwise_data); + NVTE_CHECK(data_shape.size() == 2, "NVFP4 data must be 2D."); + const size_t M_packed = static_cast(M) / 2; + const size_t K_packed = data_shape[1]; + + auto input_cu = makeTransformerEngineTensor( + rowwise_data.data_ptr(), std::vector{static_cast(M), K_packed}, + DType::kByte); + auto output_cu = makeTransformerEngineTensor( + columnwise_data.data_ptr(), 
std::vector{static_cast(K), M_packed}, + DType::kByte); + nvte_nvfp4_data_transpose(input_cu.data(), output_cu.data(), stream); + + // Transpose scales + const size_t M_tiles = (static_cast(M) + TILE_SIZE - 1) / TILE_SIZE; + const size_t K_tiles = (static_cast(K) + TILE_SIZE - 1) / TILE_SIZE; + + const auto scale_in_shape = getTensorShape(rowwise_scale_inv); + const auto scale_out_shape = getTensorShape(columnwise_scale_inv); + + auto scale_input_cu = makeTransformerEngineTensor( + rowwise_scale_inv.data_ptr(), std::vector{scale_in_shape[0], scale_in_shape[1]}, + DType::kByte); + auto scale_output_cu = makeTransformerEngineTensor( + columnwise_scale_inv.data_ptr(), + std::vector{scale_out_shape[0], scale_out_shape[1]}, DType::kByte); + + nvte_nvfp4_scale_transpose(scale_input_cu.data(), scale_output_cu.data(), M_tiles, K_tiles, + stream); + } +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 8be23d0c19..daf14ab305 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -316,6 +316,17 @@ def update_usage( if columnwise_usage is None: columnwise_usage = self._columnwise_data is not None + # If both rowwise and columnwise are requested, create columnwise from rowwise if needed + if rowwise_usage and columnwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._amax_rowwise is not None + ), "Cannot update to rowwise and columnwise usage because rowwise data is None." + if self._columnwise_data is None or self._columnwise_scale_inv is None: + self._create_columnwise() + return + # Update row-scaled data if rowwise_usage: if self._rowwise_data is None: @@ -356,3 +367,51 @@ def update_usage( self._columnwise_data = None self._columnwise_scale_inv = None self._amax_columnwise = None + + def _create_columnwise(self): + """ + Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling. + """ + assert ( + self._quantizer is not None and self._quantizer.with_2d_quantization + ), "Cannot create columnwise data without 2D quantization enabled." + rowwise_data = self._rowwise_data + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + # NVFP4 requires a specialized transpose that handles nibble repacking + self._columnwise_data = tex.nvfp4_data_transpose(rowwise_data, out=self._columnwise_data) + if self._columnwise_scale_inv is None: + assert self._quantizer is not None + # Use logical shape (self.size()), not packed byte shape (rowwise_data.shape) + # NVFP4 packs 2 elements per byte, so rowwise_data.shape[-1] is K/2 + logical_shape = self.size() + columnwise_scale_inv_shape = self._quantizer.get_scale_shape(logical_shape, True) + self._columnwise_scale_inv = torch.empty( + columnwise_scale_inv_shape, + dtype=self._rowwise_scale_inv.dtype, + device=self._rowwise_scale_inv.device, + ) + assert len(self._rowwise_scale_inv.shape) == 2 + assert len(self._columnwise_scale_inv.shape) == 2 + + # rowwise_scale_inv has shape [M_padded, K_tiles] where each tile's scale + # is repeated 16 times (once per row in the 16x16 tile). + # columnwise_scale_inv has shape [K_padded, M_tiles] where scales are + # repeated 16 times per tile row. 
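+        # For example (ignoring any row padding), a 4096 x 8192 weight has
+        # M_tiles = 256 and K_tiles = 512, i.e. rowwise scales of shape [4096, 512]
+        # and columnwise scales of shape [8192, 256].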
+ TILE_SIZE = 16 + logical_shape = self.size() + M, K = logical_shape[0], logical_shape[-1] + M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE + K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE + + tex.nvfp4_2d_scale_transpose( + self._rowwise_scale_inv, + self._columnwise_scale_inv, + M_tiles, + K_tiles, + ) + + # Also set columnwise amax (same as rowwise since it's just transposed data) + if self._amax_columnwise is None: + self._amax_columnwise = torch.empty_like(self._amax_rowwise) + self._amax_columnwise.copy_(self._amax_rowwise) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 05e2d22e9c..31699458d2 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -2,9 +2,10 @@ # # See LICENSE for license information. -"""Helper functions for using fp8 tensors as weights""" +"""Helper functions for using fp8/nvfp4 tensors as weights""" from typing import Optional, Union, List +import math import torch import transformer_engine_torch as tex @@ -16,10 +17,12 @@ from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier from ..utils import is_non_tn_fp8_gemm_supported +from ..constants import NVFP4_BLOCK_SCALING_SIZE def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -45,13 +48,19 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): new_raw_data.detach().copy_(old_raw_data) tensor._rowwise_data = new_raw_data del old_raw_data + elif isinstance(tensor, NVFP4Tensor): + old_rowwise = tensor._rowwise_data + assert old_rowwise.dtype == new_raw_data.dtype, "The data types of raw data don't match" + new_raw_data.detach().copy_(old_rowwise) + tensor._rowwise_data = new_raw_data + del old_rowwise elif isinstance(tensor, MXFP8Tensor): raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet") else: raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet") -def cast_master_weights_to_fp8( +def quantize_master_weights( model_weights, master_weights, start_offsets, @@ -59,15 +68,15 @@ def cast_master_weights_to_fp8( fsdp_shard_model_weights=None, manual_post_all_gather_processing=False, ): - r"""Helper function to cast master weights to FP8 primary weights. + r"""Helper function to cast master weights to quantized (FP8/NVFP4) primary weights. This is intended for use with ZeRO/FSDP. Each rank has a shard of the master weights (possibly empty) and a full copy of the model - weights. + weights. Supports FP8 (delayed, current, blockwise, MXFP8) and NVFP4 quantization. Parameters ---------- - model_weights : list of FP8 weights. + model_weights : list of quantized weights (FP8 or NVFP4). master_weights : list of master weights. Typically they are FP32 weights. start_offsets : list of integers, the starting index of the master weight in the model weight. 
master_weight may be smaller than model_weight because it could be distributed @@ -90,6 +99,7 @@ def cast_master_weights_to_fp8( current_scaling_params = [] blockwise_scaling_params = [] mxfp8_scaling_params = [] + nvfp4_params = [] if fsdp_shard_model_weights is None: use_fsdp_shard_model_weights = False @@ -97,6 +107,46 @@ def cast_master_weights_to_fp8( else: use_fsdp_shard_model_weights = True + # Batch convert master_weights to model dtype for NVFP4 (single kernel instead of N kernels) + # Check if there are any NVFP4 weights + has_nvfp4 = any( + isinstance(w._get_quantizer(), NVFP4Quantizer) + for w in model_weights + if hasattr(w, "_get_quantizer") + ) + if has_nvfp4 and len(model_weights) > 0: + # Find target dtype from first NVFP4 weight + target_dtype = None + for w in model_weights: + if hasattr(w, "_get_quantizer") and isinstance(w._get_quantizer(), NVFP4Quantizer): + target_dtype = w.dtype + break + + if target_dtype is not None: + # Collect non-None master_weights and their indices + non_none_indices = [] + non_none_weights = [] + sizes = [] + for i, mw in enumerate(master_weights): + if mw is not None: + non_none_indices.append(i) + non_none_weights.append(mw.view(-1)) + sizes.append(mw.numel()) + + if len(non_none_weights) > 0 and non_none_weights[0].dtype != target_dtype: + # Concatenate, convert once, then split + concatenated = torch.cat(non_none_weights) + converted = concatenated.to(target_dtype) + split_weights = torch.split(converted, sizes) + + # Rebuild master_weights list with converted tensors + converted_master_weights = list(master_weights) + for idx, split_w, orig_mw in zip( + non_none_indices, split_weights, [master_weights[i] for i in non_none_indices] + ): + converted_master_weights[idx] = split_w.view(orig_mw.shape) + master_weights = converted_master_weights + for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( model_weights, master_weights, start_offsets, fsdp_shard_model_weights ): @@ -115,34 +165,42 @@ def cast_master_weights_to_fp8( if hasattr(model_weight, "clear_high_precision_init_val"): model_weight.clear_high_precision_init_val() - if master_weight is not None: - # When not using fp8_primary_weights, the master_weight (fp32) is first cast to - # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when - # fp8_primary_weights is enabled, we still keep this logic to keep numerical - # consistency. So here we cast the master_weight to model_weight.dtype. 
- master_weight = master_weight.to(model_weight.dtype) - quantizer = model_weight._get_quantizer() - if isinstance(quantizer, Float8Quantizer): - delayed_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8CurrentScalingQuantizer): - current_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8BlockQuantizer): - blockwise_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, MXFP8Quantizer): - mxfp8_scaling_params.append( + + if isinstance(quantizer, NVFP4Quantizer): + # NVFP4: master_weight dtype conversion already done above + nvfp4_params.append( (model_weight, master_weight, start_offset, fsdp_shard_model_weight) ) else: - raise ValueError( - f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet" - ) + # FP8: convert master_weight to model dtype + if master_weight is not None: + # When not using fp8_primary_weights, the master_weight (fp32) is first cast to + # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when + # fp8_primary_weights is enabled, we still keep this logic to keep numerical + # consistency. So here we cast the master_weight to model_weight.dtype. + master_weight = master_weight.to(model_weight.dtype) + + if isinstance(quantizer, Float8Quantizer): + delayed_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + current_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8BlockQuantizer): + blockwise_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, MXFP8Quantizer): + mxfp8_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + else: + raise ValueError( + f"quantize_master_weights for {type(quantizer)} is not supported yet" + ) extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing] if len(delayed_scaling_params) > 0: @@ -153,6 +211,32 @@ def cast_master_weights_to_fp8( _cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args) if len(mxfp8_scaling_params) > 0: _cast_master_weights_to_fp8_mxfp8_scaling(mxfp8_scaling_params, *extra_args) + if len(nvfp4_params) > 0: + _cast_master_weights_to_nvfp4_2d(nvfp4_params, *extra_args) + + +def cast_master_weights_to_fp8( + model_weights, + master_weights, + start_offsets, + group, + fsdp_shard_model_weights=None, + manual_post_all_gather_processing=False, +): + r"""Helper function to cast master weights to FP8 primary weights. + + .. deprecated:: + Use :func:`quantize_master_weights` instead. + + """ + quantize_master_weights( + model_weights, + master_weights, + start_offsets, + group, + fsdp_shard_model_weights, + manual_post_all_gather_processing, + ) def _cast_master_weights_to_fp8_delayed_scaling( @@ -474,6 +558,220 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ) +def _cast_master_weights_to_nvfp4_2d( + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False +): + r"""Helper function to cast master weights to NVFP4 2D quantized weights. 
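+
+    The cast runs in three batched steps: (1) compute per-block (16x16 tile) amaxes and a
+    per-tensor amax from each rank's master-weight shard, (2) all-reduce both over ``group``
+    and derive the global encode scale and per-block decode scales, and (3) partially cast
+    the owned shard into the packed NVFP4 rowwise data of the model weight.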
+ + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. + """ + + device = params[0][0].device + block_len = NVFP4_BLOCK_SCALING_SIZE + + cu_amax_sizes = [0] + tile_shapes: List[tuple[int, int]] = [] + row_sizes: List[int] = [] + tile_widths: List[int] = [] + scale_targets: List[torch.Tensor] = [] + amax_targets: List[Optional[torch.Tensor]] = [] + for model_weight, _, _, _ in params: + quantizer = model_weight._get_quantizer() + assert isinstance(quantizer, NVFP4Quantizer) + assert quantizer.with_2d_quantization, "NVFP4 2D quantization must be enabled." + assert len(model_weight.shape) == 2 + h, w = model_weight.shape + tile_h = (h + block_len - 1) // block_len + tile_w = (w + block_len - 1) // block_len + tile_shapes.append((tile_h, tile_w)) + row_sizes.append(h) + tile_widths.append(tile_w) + scale_targets.append(model_weight._rowwise_scale_inv) + amax_targets.append(model_weight._amax_rowwise) + num_amaxes = tile_h * tile_w + cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes) + + packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) + packed_scales = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) + + amaxes: List[torch.Tensor] = [] + scales: List[torch.Tensor] = [] + global_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device) + global_amax_views: List[torch.Tensor] = [global_amaxes[i : i + 1] for i in range(len(params))] + + # Collect tensors for batched multi-tensor amax computation + master_weight_list: List[torch.Tensor] = [] + partial_amax_list: List[torch.Tensor] = [] + global_amax_list: List[torch.Tensor] = [] + h_list: List[int] = [] + w_list: List[int] = [] + start_offset_list: List[int] = [] + + for i, (model_weight, master_weight, start_offset, _) in enumerate(params): + scale_shape = tile_shapes[i] + amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) + scale = packed_scales[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) + global_amax_view = global_amax_views[i] + + assert model_weight._rowwise_scale_inv is not None + + amaxes.append(amax) + scales.append(scale) + + if master_weight is not None and master_weight.numel() > 0: + assert len(model_weight.shape) == 2 + h, w = model_weight.shape + # Collect for batched processing + master_weight_list.append(master_weight) + partial_amax_list.append(amax) + global_amax_list.append(global_amax_view) + h_list.append(h) + w_list.append(w) + start_offset_list.append(start_offset) + + # Batched multi-tensor call for partial and global amax computation + if master_weight_list: + tex.nvfp4_multi_tensor_compute_partial_amax( + master_weight_list, + partial_amax_list, + global_amax_list, + h_list, + w_list, + start_offset_list, + block_len, + ) + + if packed_amaxes.numel() > 0: + torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + if global_amaxes.numel() > 0: + torch.distributed.all_reduce(global_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + # Use GPU kernel to compute global encode scales from global amaxes + # This replaces multiple Python tensor operations with a single kernel + global_scale_tensor = torch.empty_like(global_amaxes) + + 
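+    # As computed by nvfp4_compute_global_scale_kernel:
+    #   S_enc = (fp8_max * fp4_max) / global_amax = 2688 / global_amax (clamped to FLT_MAX),
+    #   with a fallback to 1.0 when the amax is zero.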
tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) + global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] + + # Collect tensors for batched fused scale kernel + fused_scale_block_amax_list: List[torch.Tensor] = [] + fused_scale_global_amax_list: List[torch.Tensor] = [] + fused_scale_per_block_scale_list: List[torch.Tensor] = [] + fused_scale_target_scale_list: List[torch.Tensor] = [] + fused_scale_target_amax_list: List[torch.Tensor] = [] + fused_scale_tile_rows_list: List[int] = [] + fused_scale_tile_cols_list: List[int] = [] + fused_scale_rows_padded_list: List[int] = [] + + # Collect tensors for batched partial cast kernel + partial_cast_inp_list: List[torch.Tensor] = [] + partial_cast_out_list: List[torch.Tensor] = [] + partial_cast_scale_list: List[torch.Tensor] = [] + partial_cast_global_scale_list: List[torch.Tensor] = [] + partial_cast_h_list: List[int] = [] + partial_cast_w_list: List[int] = [] + partial_cast_start_offset_list: List[int] = [] + + # First pass: collect all tensors and update usage + zipped_meta = zip( + tile_shapes, + row_sizes, + tile_widths, + scale_targets, + amax_targets, + params, + amaxes, + scales, + global_scale_views, + ) + for idx, ( + tile_shape, + rows, + tile_col_cnt, + target_scale, + target_amax, + (model_weight, master_weight, start_offset, model_weight_fragment), + block_amax, + per_block_decode_scale, + global_scale, + ) in enumerate(zipped_meta): + + if not manual_post_all_gather_processing: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + + tile_rows = tile_shape[0] + rows_padded = target_scale.shape[0] + global_amax_view = global_amaxes[idx : idx + 1] + + # Collect for fused scale kernel (only if target_amax is not None) + if target_amax is not None: + fused_scale_block_amax_list.append(block_amax) + fused_scale_global_amax_list.append(global_amax_view) + fused_scale_per_block_scale_list.append(per_block_decode_scale) + fused_scale_target_scale_list.append(target_scale) + fused_scale_target_amax_list.append(target_amax) + fused_scale_tile_rows_list.append(tile_rows) + fused_scale_tile_cols_list.append(tile_col_cnt) + fused_scale_rows_padded_list.append(rows_padded) + + # Collect for partial cast kernel (only for layers owned by this rank) + if master_weight is not None and master_weight.numel() > 0: + end_offset = start_offset + master_weight.numel() + if not use_fsdp_shard_model_weights: + rowwise_bytes = model_weight._rowwise_data.view(-1) + byte_start = start_offset // 2 + byte_end = (end_offset + 1) // 2 + model_weight_fragment = rowwise_bytes[byte_start:byte_end] + assert len(model_weight.shape) == 2 + h, w = model_weight.shape + + partial_cast_inp_list.append(master_weight) + partial_cast_out_list.append(model_weight_fragment) + partial_cast_scale_list.append(per_block_decode_scale) + partial_cast_global_scale_list.append(global_scale) + partial_cast_h_list.append(h) + partial_cast_w_list.append(w) + partial_cast_start_offset_list.append(start_offset) + + # Batched multi-tensor call for fused scale + if fused_scale_block_amax_list: + tex.nvfp4_multi_tensor_fused_scale( + fused_scale_block_amax_list, + fused_scale_global_amax_list, + fused_scale_per_block_scale_list, + fused_scale_target_scale_list, + 
fused_scale_target_amax_list, + fused_scale_tile_rows_list, + fused_scale_tile_cols_list, + fused_scale_rows_padded_list, + block_len, + ) + + # Batched multi-tensor call for partial cast + if partial_cast_inp_list: + tex.nvfp4_multi_tensor_2d_partial_cast( + partial_cast_inp_list, + partial_cast_out_list, + partial_cast_scale_list, + partial_cast_global_scale_list, + partial_cast_h_list, + partial_cast_w_list, + partial_cast_start_offset_list, + block_len, + ) + + def _cast_master_weights_to_fp8_mxfp8_scaling( params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False ): # pylint: disable=unused-argument @@ -605,9 +903,15 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten - Float8Tensor: may need to create a transposed view to match backend GEMM. - Float8BlockwiseQTensor: create column-wise storage. - Plain pytorch tensor: noop. + + For NVFP4 tensors, uses batched multi-tensor processing to reduce CPU overhead. """ if not isinstance(model_weights, list): model_weights = [model_weights] + + # Collect NVFP4 tensors for batched processing + nvfp4_tensors = [] + for model_weight in model_weights: if isinstance(model_weight, Float8Tensor): # Delayed scaling and per-tensor current scaling: if backend does not support @@ -617,12 +921,94 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten elif isinstance(model_weight, Float8BlockwiseQTensor): # Blockwise scaling: create column-wise storage. model_weight._create_columnwise() + elif isinstance(model_weight, NVFP4Tensor): + # Collect for batched processing + nvfp4_tensors.append(model_weight) elif isinstance(model_weight, MXFP8Tensor): # MXFP8 scaling: no need to do anything. pass elif isinstance(model_weight, QuantizedTensor): raise ValueError(f"post_processing for {type(model_weight)} is not supported") + # Batch process all NVFP4 tensors with multi-tensor approach + if nvfp4_tensors: + _nvfp4_2d_multi_tensor_transpose(nvfp4_tensors) + + +def _nvfp4_2d_multi_tensor_transpose(nvfp4_tensors: List[NVFP4Tensor]): + """ + Batched columnwise creation for multiple NVFP4 tensors. + Reduces CPU overhead by collecting all tensor metadata and dispatching to C++. 
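+    Columnwise data and scale buffers are allocated on demand, and the rowwise amax is copied
+    to the columnwise amax, before a single C++ call transposes data and scales for all tensors.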
+ """ + TILE_SIZE = 16 + + # Prepare tensor lists for batched C++ call + rowwise_data_list = [] + columnwise_data_list = [] + rowwise_scale_inv_list = [] + columnwise_scale_inv_list = [] + M_list = [] + K_list = [] + + for tensor in nvfp4_tensors: + rowwise_data = tensor._rowwise_data + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + tensor._rowwise_data = rowwise_data + + logical_shape = tensor.size() + M, K = logical_shape[0], logical_shape[-1] + M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE + K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE + + # Allocate columnwise_data if needed + if tensor._columnwise_data is None: + # Output shape: [K, M/2] packed bytes + columnwise_data = torch.empty( + (K, M // 2), + dtype=torch.uint8, + device=rowwise_data.device, + ) + tensor._columnwise_data = columnwise_data + else: + columnwise_data = tensor._columnwise_data + + # Allocate columnwise_scale_inv if needed + if tensor._columnwise_scale_inv is None: + assert tensor._quantizer is not None + columnwise_scale_inv_shape = tensor._quantizer.get_scale_shape(logical_shape, True) + columnwise_scale_inv = torch.empty( + columnwise_scale_inv_shape, + dtype=tensor._rowwise_scale_inv.dtype, + device=tensor._rowwise_scale_inv.device, + ) + tensor._columnwise_scale_inv = columnwise_scale_inv + else: + columnwise_scale_inv = tensor._columnwise_scale_inv + + rowwise_data_list.append(rowwise_data) + columnwise_data_list.append(columnwise_data) + rowwise_scale_inv_list.append(tensor._rowwise_scale_inv) + columnwise_scale_inv_list.append(columnwise_scale_inv) + M_list.append(M) + K_list.append(K) + + # Copy amax if needed + if tensor._amax_columnwise is None and tensor._amax_rowwise is not None: + tensor._amax_columnwise = tensor._amax_rowwise.clone() + elif tensor._amax_rowwise is not None: + tensor._amax_columnwise.copy_(tensor._amax_rowwise) + + # Dispatch to C++ multi-tensor kernel + tex.nvfp4_2d_multi_tensor_transpose( + rowwise_data_list, + columnwise_data_list, + rowwise_scale_inv_list, + columnwise_scale_inv_list, + M_list, + K_list, + ) + def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: """Check if an object is custom.