From e5070210a972fe87b29e17048eebc9ea7b1297f9 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 3 Feb 2026 05:24:01 -0800 Subject: [PATCH 01/23] perf comparison between torch, fuser, and runtime TMA local matmul ops --- CMakeLists.txt | 19 +- csrc/runtime/matmul_tma.cu | 317 +++++++++++++++++++++++++++++++++ csrc/runtime/matmul_tma.h | 22 +++ tests/cpp/test_matmul_perf.cpp | 163 +++++++++++++++++ 4 files changed, 518 insertions(+), 3 deletions(-) create mode 100644 csrc/runtime/matmul_tma.cu create mode 100644 csrc/runtime/matmul_tma.h create mode 100644 tests/cpp/test_matmul_perf.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f5dad5c90d5..3db4afc1c10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -297,6 +297,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/runtime/fusion_cache_utils.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_executor_cache.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_kernel_runtime.cpp + ${NVFUSER_SRCS_DIR}/runtime/matmul_tma.cu ${NVFUSER_SRCS_DIR}/scheduler/cache_policy_refiner.cpp ${NVFUSER_SRCS_DIR}/scheduler/cutlass.cpp ${NVFUSER_SRCS_DIR}/scheduler/heuristic.cpp @@ -399,19 +400,27 @@ endif() # "private" (not installed) static library. add_library(codegen_internal OBJECT ${NVFUSER_SRCS}) +# Special handling for CUDA files that include CUTLASS headers +# nvcc doesn't support the same flags as gcc/clang, so we need to wrap them +set_source_files_properties( + ${NVFUSER_SRCS_DIR}/runtime/matmul_tma.cu + PROPERTIES + COMPILE_OPTIONS "-Xcompiler=-Wall;-Xcompiler=-Wno-unused-function;-Xcompiler=-Werror" +) + if(NOT MSVC) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") target_compile_options(codegen_internal PRIVATE - -Wall -Wno-unused-function -Werror + $<$:-Wall -Wno-unused-function -Werror # These warnings are not treated as errors because of gcc 12.2 used in # manylinux image. consider enable this when we upgrade. # linking comment: # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266 - -Wno-error=restrict -Wno-error=stringop-overflow -Wno-error=maybe-uninitialized) + -Wno-error=restrict -Wno-error=stringop-overflow -Wno-error=maybe-uninitialized>) else() target_compile_options(codegen_internal PRIVATE - -Wall -Wno-unused-function -Werror) + $<$:-Wall -Wno-unused-function -Werror>) endif() endif() @@ -423,6 +432,9 @@ if (NVMMH_FOUND) endif() target_include_directories(codegen_internal SYSTEM PUBLIC ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include + ${NVFUSER_THIRD_PARTY_DIR}/cutlass/include + ${NVFUSER_THIRD_PARTY_DIR}/cutlass/tools/util/include + /usr/local/cuda/include/cccl PRIVATE ${CUDA_INCLUDE_DIRS} ) @@ -1066,6 +1078,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_translate_mma.cpp ${NVFUSER_ROOT}/tests/cpp/test_matmul.cpp ${NVFUSER_ROOT}/tests/cpp/test_matmul_aten_evaluation.cpp + ${NVFUSER_ROOT}/tests/cpp/test_matmul_perf.cpp # ${NVFUSER_ROOT}/tests/cpp/test_matmul_sass.cpp ${NVFUSER_ROOT}/tests/cpp/test_matmul_scheduler.cpp ${NVFUSER_ROOT}/tests/cpp/test_mma.cpp diff --git a/csrc/runtime/matmul_tma.cu b/csrc/runtime/matmul_tma.cu new file mode 100644 index 00000000000..18393b000d1 --- /dev/null +++ b/csrc/runtime/matmul_tma.cu @@ -0,0 +1,317 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include + +#include +#include + +#if defined(NVFUSER_ENABLE_CUTLASS) +#if !defined(__CUDACC_VER_MAJOR__) +#define __CUDACC_VER_MAJOR__ 13 +#define __CUDACC_VER_MINOR__ 0 +#endif +#include "cutlass/arch/config.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +#endif + +namespace nvfuser { + +namespace { + +bool hasValidTmaInputShape(const at::Tensor& a, const at::Tensor& b) { + if (!a.defined() || !b.defined()) { + return false; + } + if (!a.is_cuda() || !b.is_cuda()) { + return false; + } + if (a.dim() != 2 || b.dim() != 2) { + return false; + } + if (a.scalar_type() != b.scalar_type()) { + return false; + } + if (!(a.scalar_type() == at::ScalarType::Half || + a.scalar_type() == at::ScalarType::BFloat16)) { + return false; + } + if (!a.is_contiguous() || !b.is_contiguous()) { + return false; + } + if (a.size(1) != b.size(0)) { + return false; + } + if (a.get_device() != b.get_device()) { + return false; + } + // CUTLASS TMA mainloop requires alignment-compatible K/N extents. + constexpr int64_t kAlignment = 8; + if (a.size(1) % kAlignment != 0 || b.size(1) % kAlignment != 0) { + return false; + } + return true; +} + +#if defined(NVFUSER_ENABLE_CUTLASS) && defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +using namespace cute; + +template +struct MatmulTmaSm90 { + using ElementA = ElementT; + using ElementB = ElementT; + using ElementC = ElementT; + using ElementD = ElementT; + + using LayoutATag = cutlass::layout::RowMajor; + using LayoutBTag = cutlass::layout::RowMajor; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + + static constexpr int kAlignmentA = + 128 / cutlass::sizeof_bits::value; + static constexpr int kAlignmentB = + 128 / cutlass::sizeof_bits::value; + static constexpr int kAlignmentC = + 128 / cutlass::sizeof_bits::value; + static constexpr int kAlignmentD = + 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using MmaTileShape = Shape<_128, _128, _64>; + using ClusterShape = Shape<_1, _1, _1>; + using PerSmTileShape_MNK = Shape<_128, _128, _64>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + PerSmTileShape_MNK, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutCTag, + kAlignmentC, + ElementD, + LayoutDTag, + kAlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + kAlignmentA, + ElementB, + LayoutBTag, + kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; +}; + +template +typename MatmulTmaSm90::Gemm::Arguments buildArguments( + at::Tensor& output, + const at::Tensor& a, + const at::Tensor& b, + int64_t m, + int64_t n, + int64_t k) { + using Config = MatmulTmaSm90; + using ElementA = typename Config::ElementA; + using ElementB = typename Config::ElementB; + using ElementD = typename Config::ElementD; + using StrideA = typename Config::StrideA; + using StrideB = typename Config::StrideB; + using StrideC = typename Config::StrideC; + using StrideD = typename Config::StrideD; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, {static_cast(m), static_cast(k), 1}); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, {static_cast(k), static_cast(n), 1}); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, {static_cast(m), static_cast(n), 1}); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, {static_cast(m), static_cast(n), 1}); + + typename Config::GemmKernel::MainloopArguments mainloop_args{ + static_cast(a.data_ptr()), + stride_a, + static_cast(b.data_ptr()), + stride_b}; + + typename Config::GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + stride_c, + static_cast(output.data_ptr()), + stride_d}; + + typename Config::GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + {static_cast(m), static_cast(n), static_cast(k), 1}, + mainloop_args, + epilogue_args}; + + return args; +} + +template +void runMatmulSm90( + at::Tensor& output, + const at::Tensor& a, + const at::Tensor& b, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + using Config = MatmulTmaSm90; + typename Config::Gemm gemm; + auto args = buildArguments(output, a, b, m, n, k); + + size_t workspace_size = Config::Gemm::get_workspace_size(args); + auto workspace_options = + at::TensorOptions().dtype(at::kByte).device(a.device()); + auto workspace = + at::empty({static_cast(workspace_size)}, workspace_options); + + auto can_implement_status = gemm.can_implement(args); + NVF_CHECK( + can_implement_status == cutlass::Status::kSuccess, + "TMA GEMM cannot be implemented for the given inputs."); + + auto status = gemm.initialize(args, workspace.data_ptr(), stream); + NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM."); + + status = gemm.run( + args, + workspace.data_ptr(), + stream, + /*cuda_adapter=*/nullptr, + /*launch_with_pdl=*/true); + NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM."); +} +#else +template +void runMatmulSm90( + at::Tensor& output, + const at::Tensor& a, + const at::Tensor& b, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + NVF_THROW("CUTLASS SM90 support is required for TMA matmul."); +} +#endif // NVFUSER_ENABLE_CUTLASS && CUTLASS_ARCH_MMA_SM90_SUPPORTED + +void validateInputs(const at::Tensor& a, const at::Tensor& b) { + NVF_CHECK(a.is_cuda(), "Expected CUDA tensor for operand A."); + NVF_CHECK(b.is_cuda(), "Expected CUDA tensor for operand B."); + NVF_CHECK(a.dim() == 2, "Operand A must be rank-2."); + NVF_CHECK(b.dim() == 2, "Operand B must be rank-2."); + NVF_CHECK( + a.scalar_type() == b.scalar_type(), + "Operands A and B must have the same dtype."); + NVF_CHECK( + a.scalar_type() == at::ScalarType::Half || + a.scalar_type() == at::ScalarType::BFloat16, + "Only Half and BFloat16 are supported."); + NVF_CHECK( + a.is_contiguous() && b.is_contiguous(), + "Operands must be contiguous row-major tensors."); + NVF_CHECK( + a.size(1) == b.size(0), + "Mismatched matmul dimensions: A[K] must match B[K]."); + NVF_CHECK( + a.get_device() == b.get_device(), + "Operands must be on the same CUDA device."); + + constexpr int64_t kAlignment = 8; + NVF_CHECK( + a.size(1) % kAlignment == 0, + "K dimension must be a multiple of 8 for TMA alignment."); + NVF_CHECK( + b.size(1) % kAlignment == 0, + "N dimension must be a multiple of 8 for TMA alignment."); +} + +} // namespace + +at::Tensor matmulTma(const at::Tensor& a, const at::Tensor& b) { + validateInputs(a, b); + at::cuda::CUDAGuard device_guard{a.device()}; + auto* props = at::cuda::getDeviceProperties(a.get_device()); + NVF_CHECK( + props->major >= 9, + "TMA matmul requires SM90 (Hopper) or newer."); + + const int64_t m = a.size(0); + const int64_t n = b.size(1); + const int64_t k = a.size(1); + + auto options = + at::TensorOptions().dtype(a.scalar_type()).device(a.device()); + at::Tensor output = at::empty({m, n}, options); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + +#if defined(NVFUSER_ENABLE_CUTLASS) + if (a.scalar_type() == at::ScalarType::Half) { + runMatmulSm90(output, a, b, m, n, k, stream); + } else { + runMatmulSm90(output, a, b, m, n, k, stream); + } +#else + NVF_THROW("CUTLASS support is required for TMA matmul."); +#endif + + return output; +} + +bool canRunMatmulTma(const at::Tensor& a, const at::Tensor& b) { + if (!hasValidTmaInputShape(a, b)) { + return false; + } + +#if !defined(NVFUSER_ENABLE_CUTLASS) || \ + !defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + return false; +#else + auto* props = at::cuda::getDeviceProperties(a.get_device()); + return props->major >= 9; +#endif +} + +} // namespace nvfuser diff --git a/csrc/runtime/matmul_tma.h b/csrc/runtime/matmul_tma.h new file mode 100644 index 00000000000..d859dc55e95 --- /dev/null +++ b/csrc/runtime/matmul_tma.h @@ -0,0 +1,22 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser { + +//! Run an SM90 TMA-based matmul (A[M,K] x B[K,N]) on the current CUDA stream. +//! Returns a new output tensor with the same dtype as the inputs. +at::Tensor matmulTma(const at::Tensor& a, const at::Tensor& b); + +//! Returns true when the handcrafted TMA matmul kernel can run for the inputs. +//! This is a non-throwing capability check intended for runtime dispatch. +bool canRunMatmulTma(const at::Tensor& a, const at::Tensor& b); + +} // namespace nvfuser diff --git a/tests/cpp/test_matmul_perf.cpp b/tests/cpp/test_matmul_perf.cpp new file mode 100644 index 00000000000..c656091090a --- /dev/null +++ b/tests/cpp/test_matmul_perf.cpp @@ -0,0 +1,163 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include + +#include "exceptions.h" +#include "fusion.h" +#include "fusion_guard.h" +#include "mma_type.h" +#include "ops/all_ops.h" +#include "preseg_passes/pre_segmenter.h" +#include "runtime/fusion_executor_cache.h" +#include "runtime/matmul_tma.h" +#include "scheduler/matmul.h" +#include "tests/cpp/utils.h" + +#if defined(NVFUSER_ENABLE_CUTLASS) +#if !defined(__CUDACC_VER_MAJOR__) +#define __CUDACC_VER_MAJOR__ 13 +#define __CUDACC_VER_MINOR__ 0 +#endif +#include "cutlass/arch/config.h" +#endif + +namespace nvfuser { + +namespace { + +struct MatmulProblem { + int64_t m; + int64_t n; + int64_t k; +}; + +std::unique_ptr buildMatmulFusion(DataType dtype) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto a = makeContigTensor(2, dtype); + auto b = makeContigTensor(2, dtype); + fusion->addInput(a); + fusion->addInput(b); + + auto layout = MmaLayout::TT; + auto a_canon = canonicalizeInputToBMNK(a, layout, MmaOperand::A); + auto b_canon = canonicalizeInputToBMNK(b, layout, MmaOperand::B); + auto c = fusedMultiplySum(a_canon, b_canon, {-1}); + auto d = castOp(dtype, c); + + fusion->addOutput(d); + OptimizationPass::runPass(fusion.get()); + return fusion; +} + +template +double timeMs(int warmup_iters, int iters, Fn&& fn) { + for (int i = 0; i < warmup_iters; ++i) { + fn(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + + auto start = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < iters; ++i) { + fn(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration elapsed = end - start; + return elapsed.count() / static_cast(iters); +} + +void printResult( + const std::string& label, + const MatmulProblem& problem, + double ms_per_iter) { + const double flops = 2.0 * static_cast(problem.m) * + static_cast(problem.n) * static_cast(problem.k); + const double gflops = flops / (ms_per_iter * 1.0e6); + std::cout << label << " M=" << problem.m << " N=" << problem.n + << " K=" << problem.k << " : " << ms_per_iter << " ms, " + << gflops << " GFLOPs" << std::endl; +} + +} // namespace + +TEST(MatmulPerfTest, CompareImplementations) { + if (!deviceMajorMinorCheck(9, 0)) { + GTEST_SKIP() << "Requires SM90 (Hopper)."; + } +#if !defined(NVFUSER_ENABLE_CUTLASS) + GTEST_SKIP() << "CUTLASS support is disabled."; +#endif +#if defined(NVFUSER_ENABLE_CUTLASS) && !defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + GTEST_SKIP() << "CUTLASS SM90 support is unavailable."; +#endif + + constexpr int warmup_iters = 100; + constexpr int iters = 1000; + + std::vector problems{ + {1024, 1024, 1024}, + {2048, 2048, 2048}, + {4096, 4096, 4096}, + }; + + std::vector dtypes{ + at::ScalarType::Half, + at::ScalarType::BFloat16, + }; + + for (auto dtype : dtypes) { + for (const auto& problem : problems) { + at::cuda::CUDAGuard device_guard{0}; + auto options = + at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + auto a = at::randn({problem.m, problem.k}, options); + auto b = at::randn({problem.k, problem.n}, options); + + auto ref = at::matmul(a, b); + + auto tma_out = matmulTma(a, b); + EXPECT_TRUE(at::allclose(tma_out, ref, 1e-2, 1e-2)); + + auto fusion = buildMatmulFusion( + dtype == at::ScalarType::Half ? DataType::Half : DataType::BFloat16); + FusionExecutorCache executor_cache(std::move(fusion)); + auto fuser_out = executor_cache.runFusionWithInputs({a, b}); + auto fuser_tensor = fuser_out[0].as(); + EXPECT_TRUE(at::allclose(fuser_tensor, ref, 1e-2, 1e-2)); + + double torch_ms = timeMs(warmup_iters, iters, [&]() { + auto out = at::matmul(a, b); + (void)out; + }); + printResult("torch", problem, torch_ms); + + double fuser_ms = timeMs(warmup_iters, iters, [&]() { + auto out = executor_cache.runFusionWithInputs({a, b}); + (void)out; + }); + printResult("nvfuser", problem, fuser_ms); + + double tma_ms = timeMs(warmup_iters, iters, [&]() { + auto out = matmulTma(a, b); + (void)out; + }); + printResult("tma", problem, tma_ms); + } + } +} + +} // namespace nvfuser From 64650681c85c9b686f091783c3ec40483ca7148b Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 00:58:00 -0800 Subject: [PATCH 02/23] first fused comm/compute kernel for AG+GEMM --- CMakeLists.txt | 4 +- .../test_multidevice_fused_remote_matmul.cpp | 140 ++++++++++++++++++ ..._multidevice_fused_remote_matmul_kernel.cu | 92 ++++++++++++ 3 files changed, 235 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/test_multidevice_fused_remote_matmul.cpp create mode 100644 tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 3db4afc1c10..f7c11037a5a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -931,7 +931,7 @@ function(add_test_without_main TEST_NAME TEST_SRC ADDITIONAL_LINK) if(NOT MSVC) target_compile_options(${TEST_NAME} PRIVATE - -Wall -Wno-unused-function -Werror + $<$:-Wall -Wno-unused-function -Werror> ) endif() endfunction() @@ -1031,6 +1031,8 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_fused_remote_matmul.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp new file mode 100644 index 00000000000..4ca715b8c91 --- /dev/null +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -0,0 +1,140 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include + +#include +#include + +#include "multidevice/symmetric_tensor.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { + +double timeFusedRemoteMatmulMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + int64_t world_size, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + +using FusedRemoteMatmulTest = MultiDeviceTest; + +TEST_F(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { + if (!communicator_->is_available()) { + GTEST_SKIP() << "Communicator is unavailable."; + } + if (communicator_->size() == 1) { + GTEST_SKIP() << "Needs at least 2 devices."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + + constexpr int64_t m = 1024; + constexpr int64_t k = 1024; + constexpr int64_t n = 1024; + NVF_ERROR(m % world_size == 0, "M must be divisible by world size."); + const int64_t m_per_rank = m / world_size; + + const auto cpu_float_opts = + at::TensorOptions().dtype(at::kFloat).device(at::kCPU); + const auto gpu_half_opts = + at::TensorOptions().dtype(at::kHalf).device(communicator_->device()); + + // Every rank builds identical global inputs from the same seed. + at::manual_seed(0); + at::Tensor a_full_cpu = at::randn({m, k}, cpu_float_opts); + at::Tensor b_full_cpu = at::randn({k, n}, cpu_float_opts); + + at::Tensor a_local_half = a_full_cpu + .slice(0, my_rank * m_per_rank, (my_rank + 1) * m_per_rank) + .to(gpu_half_opts.device(), at::kHalf); + at::Tensor b_full_half = b_full_cpu.to(gpu_half_opts.device(), at::kHalf); + + at::Tensor a_local_sym = SymmetricTensor::allocate( + {m_per_rank, k}, at::ScalarType::Half, communicator_->device()); + a_local_sym.copy_(a_local_half); + SymmetricTensor symmetric_a(a_local_sym); + symmetric_a.setupRemoteHandles("fused_remote_matmul_a"); + + std::vector host_remote_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + host_remote_ptrs[rank] = + reinterpret_cast(symmetric_a.remoteTensor(rank).data_ptr()); + } + + __half** device_remote_ptrs = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaMalloc(&device_remote_ptrs, world_size * sizeof(__half*))); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + device_remote_ptrs, + host_remote_ptrs.data(), + world_size * sizeof(__half*), + cudaMemcpyHostToDevice)); + + at::Tensor c_out_half = at::zeros({m, n}, gpu_half_opts); + cudaStream_t stream = + at::cuda::getCurrentCUDAStream(static_cast(my_rank)).stream(); + + // Correctness check. + (void)timeFusedRemoteMatmulMs( + const_cast(device_remote_ptrs), + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + world_size, + m, + n, + k, + m_per_rank, + /*warmup_iters=*/3, + /*iters=*/1, + stream); + + at::Tensor c_ref_cpu = at::matmul(a_full_cpu, b_full_cpu); + at::Tensor c_out_cpu = c_out_half.cpu().to(at::kFloat); + EXPECT_TRUE(c_out_cpu.allclose(c_ref_cpu, 2e-1, 2e-1)) + << "Fused remote-pointer matmul output mismatch."; + + communicator_->barrier(); + constexpr int64_t warmup_iters = 8; + constexpr int64_t iters = 30; + const double ms_per_iter = timeFusedRemoteMatmulMs( + const_cast(device_remote_ptrs), + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + world_size, + m, + n, + k, + m_per_rank, + warmup_iters, + iters, + stream); + communicator_->barrier(); + + const double flops = 2.0 * static_cast(m) * static_cast(n) * + static_cast(k); + const double tflops = flops / (ms_per_iter * 1.0e9); + if (my_rank == 0) { + std::cout << "[perf] fused_remote_matmul M=" << m << " N=" << n + << " K=" << k << " world_size=" << world_size << " : " + << ms_per_iter << " ms/iter, " << tflops << " TFLOP/s" + << std::endl; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(device_remote_ptrs)); +} + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu new file mode 100644 index 00000000000..21814b05b3c --- /dev/null +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -0,0 +1,92 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include + +#include "cuda_utils.h" + +namespace nvfuser { + +namespace { + +__global__ void fusedRemoteMatmulKernel( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank) { + const int64_t row = blockIdx.y * blockDim.y + threadIdx.y; + const int64_t col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= m || col >= n) { + return; + } + + const int64_t owner_rank = row / m_per_rank; + const int64_t local_row = row - owner_rank * m_per_rank; + const __half* a_local = a_remote_shards[owner_rank]; + + float acc = 0.0f; + for (int64_t kk = 0; kk < k; ++kk) { + const float a = __half2float(a_local[local_row * k + kk]); + const float b = __half2float(b_full[kk * n + col]); + acc += a * b; + } + c_out[row * n + col] = __float2half(acc); +} + +} // namespace + +double timeFusedRemoteMatmulMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + int64_t world_size, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + (void)world_size; + const dim3 block(16, 16); + const dim3 grid( + static_cast((n + block.x - 1) / block.x), + static_cast((m + block.y - 1) / block.y)); + + for (int64_t i = 0; i < warmup_iters; ++i) { + fusedRemoteMatmulKernel<<>>( + a_remote_shards, b_full, c_out, m, n, k, m_per_rank); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) { + fusedRemoteMatmulKernel<<>>( + a_remote_shards, b_full, c_out, m, n, k, m_per_rank); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + + float total_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / static_cast(iters); +} + +} // namespace nvfuser From dd2e7ced4314e82185bae854babe74fb1a103875 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 01:46:49 -0800 Subject: [PATCH 03/23] add baseline torch eager nccl --- .../test_multidevice_fused_remote_matmul.cpp | 155 ++++++++++++++++-- 1 file changed, 143 insertions(+), 12 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 4ca715b8c91..7d48469b625 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -11,6 +11,7 @@ #include #include +#include "multidevice/communicator.h" #include "multidevice/symmetric_tensor.h" #include "tests/cpp/multidevice.h" @@ -29,9 +30,111 @@ double timeFusedRemoteMatmulMs( int64_t iters, cudaStream_t stream); -using FusedRemoteMatmulTest = MultiDeviceTest; +namespace { -TEST_F(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { +enum class RemoteMatmulImpl { naiveFusedKernel, baselinePytorchEager }; + +const char* implName(RemoteMatmulImpl impl) { + switch (impl) { + case RemoteMatmulImpl::naiveFusedKernel: + return "naiveFusedKernel"; + case RemoteMatmulImpl::baselinePytorchEager: + return "baselinePytorchEager"; + } + NVF_ERROR(false, "Unknown implementation enum value: ", static_cast(impl)); +} + +double timeBaselinePytorchEagerMs( + c10d::Backend* backend, + at::Tensor& a_local_half, + const at::Tensor& b_full_half, + at::Tensor& c_out_half, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + at::Tensor a_allgathered_half = at::empty( + {a_local_half.size(0) * backend->getSize(), a_local_half.size(1)}, + a_local_half.options()); + + for (int64_t i = 0; i < warmup_iters; ++i) { + backend->_allgather_base(a_allgathered_half, a_local_half)->wait(); + at::matmul_out(c_out_half, a_allgathered_half, b_full_half); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) { + backend->_allgather_base(a_allgathered_half, a_local_half)->wait(); + at::matmul_out(c_out_half, a_allgathered_half, b_full_half); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + + float total_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / static_cast(iters); +} + +double timeImplementationMs( + RemoteMatmulImpl impl, + c10d::Backend* backend, + const __half* const* device_remote_ptrs, + at::Tensor& a_local_half, + const at::Tensor& b_full_half, + at::Tensor& c_out_half, + int64_t world_size, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + switch (impl) { + case RemoteMatmulImpl::naiveFusedKernel: + return timeFusedRemoteMatmulMs( + device_remote_ptrs, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + world_size, + m, + n, + k, + m_per_rank, + warmup_iters, + iters, + stream); + case RemoteMatmulImpl::baselinePytorchEager: + NVF_CHECK( + backend != nullptr, + "baselinePytorchEager requires a valid NCCL process group backend."); + return timeBaselinePytorchEagerMs( + backend, + a_local_half, + b_full_half, + c_out_half, + warmup_iters, + iters, + stream); + } + NVF_ERROR(false, "Unsupported implementation enum: ", static_cast(impl)); +} + +} // namespace + +class FusedRemoteMatmulTest : public MultiDeviceTest, + public testing::WithParamInterface {}; + +TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { if (!communicator_->is_available()) { GTEST_SKIP() << "Communicator is unavailable."; } @@ -41,6 +144,18 @@ TEST_F(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const int64_t world_size = communicator_->size(); const int64_t my_rank = communicator_->deviceId(); + const auto impl = GetParam(); + + Team all_devices(world_size); + std::iota(all_devices.begin(), all_devices.end(), 0); + c10d::Backend* nccl_backend = nullptr; + if (impl == RemoteMatmulImpl::baselinePytorchEager) { + if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) { + GTEST_SKIP() << "NCCL backend unavailable for baselinePytorchEager."; + } + nccl_backend = + communicator_->getBackendForTeam(all_devices, CommunicatorBackend::kNccl); + } constexpr int64_t m = 1024; constexpr int64_t k = 1024; @@ -89,10 +204,13 @@ TEST_F(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { at::cuda::getCurrentCUDAStream(static_cast(my_rank)).stream(); // Correctness check. - (void)timeFusedRemoteMatmulMs( + (void)timeImplementationMs( + impl, + nccl_backend, const_cast(device_remote_ptrs), - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), + a_local_half, + b_full_half, + c_out_half, world_size, m, n, @@ -110,10 +228,13 @@ TEST_F(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { communicator_->barrier(); constexpr int64_t warmup_iters = 8; constexpr int64_t iters = 30; - const double ms_per_iter = timeFusedRemoteMatmulMs( + const double ms_per_iter = timeImplementationMs( + impl, + nccl_backend, const_cast(device_remote_ptrs), - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), + a_local_half, + b_full_half, + c_out_half, world_size, m, n, @@ -128,13 +249,23 @@ TEST_F(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { static_cast(k); const double tflops = flops / (ms_per_iter * 1.0e9); if (my_rank == 0) { - std::cout << "[perf] fused_remote_matmul M=" << m << " N=" << n - << " K=" << k << " world_size=" << world_size << " : " - << ms_per_iter << " ms/iter, " << tflops << " TFLOP/s" - << std::endl; + std::cout << "[perf] fused_remote_matmul impl=" << implName(impl) + << " M=" << m << " N=" << n << " K=" << k + << " world_size=" << world_size << " : " << ms_per_iter + << " ms/iter, " << tflops << " TFLOP/s" << std::endl; } NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(device_remote_ptrs)); } +INSTANTIATE_TEST_SUITE_P( + , + FusedRemoteMatmulTest, + testing::Values( + RemoteMatmulImpl::naiveFusedKernel, + RemoteMatmulImpl::baselinePytorchEager), + [](const testing::TestParamInfo& info) { + return implName(info.param); + }); + } // namespace nvfuser From 90366f6374ad80b6a93625f22b6dc8668cef1841 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 01:53:12 -0800 Subject: [PATCH 04/23] add baseline torch eager cuda --- .../test_multidevice_fused_remote_matmul.cpp | 151 ++++++++++++++++-- 1 file changed, 138 insertions(+), 13 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 7d48469b625..2db83eced53 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -10,8 +10,15 @@ #include #include +#include +#include "fusion.h" +#include "host_ir/container.h" +#include "ir/builder.h" +#include "multidevice/communication.h" #include "multidevice/communicator.h" +#include "multidevice/cuda_p2p.h" +#include "multidevice/ipc_handle.h" #include "multidevice/symmetric_tensor.h" #include "tests/cpp/multidevice.h" @@ -32,14 +39,20 @@ double timeFusedRemoteMatmulMs( namespace { -enum class RemoteMatmulImpl { naiveFusedKernel, baselinePytorchEager }; +enum class RemoteMatmulImpl { + naiveFusedKernel, + baselinePytorchEagerNccl, + baselinePytorchEagerCuda +}; const char* implName(RemoteMatmulImpl impl) { switch (impl) { case RemoteMatmulImpl::naiveFusedKernel: return "naiveFusedKernel"; - case RemoteMatmulImpl::baselinePytorchEager: - return "baselinePytorchEager"; + case RemoteMatmulImpl::baselinePytorchEagerNccl: + return "baselinePytorchEagerNccl"; + case RemoteMatmulImpl::baselinePytorchEagerCuda: + return "baselinePytorchEagerCuda"; } NVF_ERROR(false, "Unknown implementation enum value: ", static_cast(impl)); } @@ -84,11 +97,72 @@ double timeBaselinePytorchEagerMs( return static_cast(total_ms) / static_cast(iters); } +double timeBaselinePytorchEagerCudaMs( + Communication* communication, + SymMemForAllgather* allgather_handle, + at::Tensor& a_local_half, + at::Tensor& a_allgathered_half, + const at::Tensor& b_full_half, + at::Tensor& c_out_half, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + for (int64_t i = 0; i < warmup_iters; ++i) { + postWithCudaBackend( + communication, + a_local_half, + allgather_handle, + (CUstream)stream, + /*root=*/-1); + waitWithCudaBackend( + communication, + allgather_handle, + (CUstream)stream, + /*root=*/-1); + at::matmul_out(c_out_half, a_allgathered_half, b_full_half); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) { + postWithCudaBackend( + communication, + a_local_half, + allgather_handle, + (CUstream)stream, + /*root=*/-1); + waitWithCudaBackend( + communication, + allgather_handle, + (CUstream)stream, + /*root=*/-1); + at::matmul_out(c_out_half, a_allgathered_half, b_full_half); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + + float total_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / static_cast(iters); +} + double timeImplementationMs( RemoteMatmulImpl impl, - c10d::Backend* backend, + c10d::Backend* nccl_backend, + Communication* cuda_allgather_communication, + SymMemForAllgather* cuda_allgather_handle, const __half* const* device_remote_ptrs, at::Tensor& a_local_half, + at::Tensor& a_allgathered_half_cuda, const at::Tensor& b_full_half, at::Tensor& c_out_half, int64_t world_size, @@ -113,13 +187,27 @@ double timeImplementationMs( warmup_iters, iters, stream); - case RemoteMatmulImpl::baselinePytorchEager: + case RemoteMatmulImpl::baselinePytorchEagerNccl: NVF_CHECK( - backend != nullptr, - "baselinePytorchEager requires a valid NCCL process group backend."); + nccl_backend != nullptr, + "baselinePytorchEagerNccl requires a valid NCCL process group backend."); return timeBaselinePytorchEagerMs( - backend, + nccl_backend, + a_local_half, + b_full_half, + c_out_half, + warmup_iters, + iters, + stream); + case RemoteMatmulImpl::baselinePytorchEagerCuda: + NVF_CHECK( + cuda_allgather_communication != nullptr && cuda_allgather_handle != nullptr, + "baselinePytorchEagerCuda requires initialized CUDA allgather resources."); + return timeBaselinePytorchEagerCudaMs( + cuda_allgather_communication, + cuda_allgather_handle, a_local_half, + a_allgathered_half_cuda, b_full_half, c_out_half, warmup_iters, @@ -148,14 +236,18 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { Team all_devices(world_size); std::iota(all_devices.begin(), all_devices.end(), 0); + c10d::Backend* nccl_backend = nullptr; - if (impl == RemoteMatmulImpl::baselinePytorchEager) { + if (impl == RemoteMatmulImpl::baselinePytorchEagerNccl) { if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) { - GTEST_SKIP() << "NCCL backend unavailable for baselinePytorchEager."; + GTEST_SKIP() << "NCCL backend unavailable for baselinePytorchEagerNccl."; } nccl_backend = communicator_->getBackendForTeam(all_devices, CommunicatorBackend::kNccl); } + std::unique_ptr cuda_hic; + Communication* cuda_allgather_communication = nullptr; + std::unique_ptr cuda_allgather_handle; constexpr int64_t m = 1024; constexpr int64_t k = 1024; @@ -200,15 +292,44 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { cudaMemcpyHostToDevice)); at::Tensor c_out_half = at::zeros({m, n}, gpu_half_opts); - cudaStream_t stream = - at::cuda::getCurrentCUDAStream(static_cast(my_rank)).stream(); + at::Tensor a_allgathered_half_cuda; + c10::cuda::CUDAStream test_stream = c10::cuda::getStreamFromPool( + /*isHighPriority=*/false, + static_cast(communicator_->device().index())); + c10::cuda::CUDAStreamGuard stream_guard(test_stream); + cudaStream_t stream = test_stream.stream(); + + if (impl == RemoteMatmulImpl::baselinePytorchEagerCuda) { + cuda_hic = std::make_unique(); + FusionGuard fg(cuda_hic.get()); + auto* in_tv = makeContigTensor(2); + auto* out_tv = makeContigTensor(2); + DeviceMesh mesh = DeviceMesh::createForNumDevices(world_size); + in_tv->setDeviceMesh(mesh); + out_tv->setDeviceMesh(mesh); + cuda_allgather_communication = IrBuilder::create( + CommunicationType::Allgather, + out_tv, + in_tv, + all_devices, + /*root=*/-1, + RedOpType::UNUSED, + CommunicatorBackend::kCuda); + a_allgathered_half_cuda = SymmetricTensor::allocate( + {m, k}, at::ScalarType::Half, communicator_->device()); + cuda_allgather_handle = std::make_unique( + cuda_allgather_communication, a_allgathered_half_cuda); + } // Correctness check. (void)timeImplementationMs( impl, nccl_backend, + cuda_allgather_communication, + cuda_allgather_handle.get(), const_cast(device_remote_ptrs), a_local_half, + a_allgathered_half_cuda, b_full_half, c_out_half, world_size, @@ -231,8 +352,11 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const double ms_per_iter = timeImplementationMs( impl, nccl_backend, + cuda_allgather_communication, + cuda_allgather_handle.get(), const_cast(device_remote_ptrs), a_local_half, + a_allgathered_half_cuda, b_full_half, c_out_half, world_size, @@ -263,7 +387,8 @@ INSTANTIATE_TEST_SUITE_P( FusedRemoteMatmulTest, testing::Values( RemoteMatmulImpl::naiveFusedKernel, - RemoteMatmulImpl::baselinePytorchEager), + RemoteMatmulImpl::baselinePytorchEagerNccl, + RemoteMatmulImpl::baselinePytorchEagerCuda), [](const testing::TestParamInfo& info) { return implName(info.param); }); From dda6dd1817cd446364bf63f47e1b3775796d5952 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 02:08:19 -0800 Subject: [PATCH 05/23] centralize benchmark options, add time measurement mode and add_barrier_at_each_step --- .../test_multidevice_fused_remote_matmul.cpp | 219 ++++++++++-------- 1 file changed, 127 insertions(+), 92 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 2db83eced53..7ab29311ae4 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -12,6 +12,8 @@ #include #include +#include + #include "fusion.h" #include "host_ir/container.h" #include "ir/builder.h" @@ -45,6 +47,15 @@ enum class RemoteMatmulImpl { baselinePytorchEagerCuda }; +enum class TimeMeasurementMode { CudaEvents, CpuClock }; + +struct BenchmarkConfig { + int64_t warmup_iters; + int64_t iters; + TimeMeasurementMode time_mode; + bool barrier_at_each_iteration; +}; + const char* implName(RemoteMatmulImpl impl) { switch (impl) { case RemoteMatmulImpl::naiveFusedKernel: @@ -57,44 +68,78 @@ const char* implName(RemoteMatmulImpl impl) { NVF_ERROR(false, "Unknown implementation enum value: ", static_cast(impl)); } +template +double benchmarkLoopMs( + const BenchmarkConfig& config, + Communicator* communicator, + cudaStream_t stream, + Fn&& run_once) { + for (int64_t i = 0; i < config.warmup_iters; ++i) { + if (config.barrier_at_each_iteration) { + communicator->barrier(); + } + run_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + if (config.time_mode == TimeMeasurementMode::CudaEvents) { + float total_ms = 0.0f; + for (int64_t i = 0; i < config.iters; ++i) { + if (config.barrier_at_each_iteration) { + communicator->barrier(); + } + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + float iter_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&iter_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + total_ms += iter_ms; + } + return static_cast(total_ms) / static_cast(config.iters); + } + + double total_ms = 0.0; + for (int64_t i = 0; i < config.iters; ++i) { + if (config.barrier_at_each_iteration) { + communicator->barrier(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + auto start = std::chrono::high_resolution_clock::now(); + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + auto stop = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = stop - start; + total_ms += elapsed.count(); + } + return total_ms / static_cast(config.iters); +} + double timeBaselinePytorchEagerMs( c10d::Backend* backend, at::Tensor& a_local_half, const at::Tensor& b_full_half, at::Tensor& c_out_half, - int64_t warmup_iters, - int64_t iters, + const BenchmarkConfig& config, + Communicator* communicator, cudaStream_t stream) { at::Tensor a_allgathered_half = at::empty( {a_local_half.size(0) * backend->getSize(), a_local_half.size(1)}, a_local_half.options()); - - for (int64_t i = 0; i < warmup_iters; ++i) { + auto run_once = [&]() { backend->_allgather_base(a_allgathered_half, a_local_half)->wait(); at::matmul_out(c_out_half, a_allgathered_half, b_full_half); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); - for (int64_t i = 0; i < iters; ++i) { - backend->_allgather_base(a_allgathered_half, a_local_half)->wait(); - at::matmul_out(c_out_half, a_allgathered_half, b_full_half); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - - float total_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(iters); + }; + return benchmarkLoopMs(config, communicator, stream, run_once); } double timeBaselinePytorchEagerCudaMs( @@ -104,10 +149,10 @@ double timeBaselinePytorchEagerCudaMs( at::Tensor& a_allgathered_half, const at::Tensor& b_full_half, at::Tensor& c_out_half, - int64_t warmup_iters, - int64_t iters, + const BenchmarkConfig& config, + Communicator* communicator, cudaStream_t stream) { - for (int64_t i = 0; i < warmup_iters; ++i) { + auto run_once = [&]() { postWithCudaBackend( communication, a_local_half, @@ -120,43 +165,14 @@ double timeBaselinePytorchEagerCudaMs( (CUstream)stream, /*root=*/-1); at::matmul_out(c_out_half, a_allgathered_half, b_full_half); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); - for (int64_t i = 0; i < iters; ++i) { - postWithCudaBackend( - communication, - a_local_half, - allgather_handle, - (CUstream)stream, - /*root=*/-1); - waitWithCudaBackend( - communication, - allgather_handle, - (CUstream)stream, - /*root=*/-1); - at::matmul_out(c_out_half, a_allgathered_half, b_full_half); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - - float total_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(iters); + }; + return benchmarkLoopMs(config, communicator, stream, run_once); } double timeImplementationMs( RemoteMatmulImpl impl, + const BenchmarkConfig& config, + Communicator* communicator, c10d::Backend* nccl_backend, Communication* cuda_allgather_communication, SymMemForAllgather* cuda_allgather_handle, @@ -170,23 +186,25 @@ double timeImplementationMs( int64_t n, int64_t k, int64_t m_per_rank, - int64_t warmup_iters, - int64_t iters, cudaStream_t stream) { switch (impl) { - case RemoteMatmulImpl::naiveFusedKernel: - return timeFusedRemoteMatmulMs( - device_remote_ptrs, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - world_size, - m, - n, - k, - m_per_rank, - warmup_iters, - iters, - stream); + case RemoteMatmulImpl::naiveFusedKernel: { + auto run_once = [&]() { + timeFusedRemoteMatmulMs( + device_remote_ptrs, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + world_size, + m, + n, + k, + m_per_rank, + /*warmup_iters=*/0, + /*iters=*/1, + stream); + }; + return benchmarkLoopMs(config, communicator, stream, run_once); + } case RemoteMatmulImpl::baselinePytorchEagerNccl: NVF_CHECK( nccl_backend != nullptr, @@ -196,8 +214,8 @@ double timeImplementationMs( a_local_half, b_full_half, c_out_half, - warmup_iters, - iters, + config, + communicator, stream); case RemoteMatmulImpl::baselinePytorchEagerCuda: NVF_CHECK( @@ -210,8 +228,8 @@ double timeImplementationMs( a_allgathered_half_cuda, b_full_half, c_out_half, - warmup_iters, - iters, + config, + communicator, stream); } NVF_ERROR(false, "Unsupported implementation enum: ", static_cast(impl)); @@ -220,7 +238,19 @@ double timeImplementationMs( } // namespace class FusedRemoteMatmulTest : public MultiDeviceTest, - public testing::WithParamInterface {}; + public testing::WithParamInterface { + protected: + static constexpr BenchmarkConfig kBenchmarkConfig = { + /*warmup_iters=*/8, + /*iters=*/30, + /*time_mode=*/TimeMeasurementMode::CpuClock, + /*barrier_at_each_iteration=*/false}; + static constexpr BenchmarkConfig kCorrectnessConfig = { + /*warmup_iters=*/3, + /*iters=*/1, + /*time_mode=*/TimeMeasurementMode::CpuClock, + /*barrier_at_each_iteration=*/false}; +}; TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { if (!communicator_->is_available()) { @@ -324,6 +354,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { // Correctness check. (void)timeImplementationMs( impl, + kCorrectnessConfig, + communicator_, nccl_backend, cuda_allgather_communication, cuda_allgather_handle.get(), @@ -337,8 +369,6 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { n, k, m_per_rank, - /*warmup_iters=*/3, - /*iters=*/1, stream); at::Tensor c_ref_cpu = at::matmul(a_full_cpu, b_full_cpu); @@ -347,10 +377,10 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { << "Fused remote-pointer matmul output mismatch."; communicator_->barrier(); - constexpr int64_t warmup_iters = 8; - constexpr int64_t iters = 30; - const double ms_per_iter = timeImplementationMs( + const double local_ms_per_iter = timeImplementationMs( impl, + kBenchmarkConfig, + communicator_, nccl_backend, cuda_allgather_communication, cuda_allgather_handle.get(), @@ -364,18 +394,23 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { n, k, m_per_rank, - warmup_iters, - iters, stream); communicator_->barrier(); + at::Tensor max_time_tensor = at::tensor( + {static_cast(local_ms_per_iter)}, + at::TensorOptions().dtype(at::kFloat).device(communicator_->device())); + std::vector time_tensors = {max_time_tensor}; + communicator_->getWorld()->allreduce(time_tensors, {c10d::ReduceOp::MAX})->wait(); + const double global_ms_per_iter = static_cast(max_time_tensor.item()); + const double flops = 2.0 * static_cast(m) * static_cast(n) * static_cast(k); - const double tflops = flops / (ms_per_iter * 1.0e9); + const double tflops = flops / (global_ms_per_iter * 1.0e9); if (my_rank == 0) { std::cout << "[perf] fused_remote_matmul impl=" << implName(impl) << " M=" << m << " N=" << n << " K=" << k - << " world_size=" << world_size << " : " << ms_per_iter + << " world_size=" << world_size << " : " << global_ms_per_iter << " ms/iter, " << tflops << " TFLOP/s" << std::endl; } From f9e7985a30874887748d01ba39de0f0616d56e70 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 02:22:22 -0800 Subject: [PATCH 06/23] cleanup --- .../test_multidevice_fused_remote_matmul.cpp | 201 ++++++++++++------ ..._multidevice_fused_remote_matmul_kernel.cu | 14 +- 2 files changed, 145 insertions(+), 70 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 7ab29311ae4..946020c3661 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -41,7 +41,20 @@ double timeFusedRemoteMatmulMs( namespace { -enum class RemoteMatmulImpl { +// Implementations compared by this benchmark: +// - naiveFusedKernel: +// A rank-local, handwritten remote-pointer matmul path. A is sharded on M and +// each output row pulls its owner shard directly from symmetric remote memory. +// This models a fused comm+compute style where data movement is embedded in +// the kernel access pattern. +// - baselinePytorchEagerNccl: +// Reference eager path using PyTorch process group NCCL allgather to rebuild +// full A on every rank, then regular eager matmul_out(A_full, B_full). +// - baselinePytorchEagerCuda: +// Same eager compute structure as the NCCL baseline, but communication uses +// nvFuser's CUDA backend allgather primitives (post/wait with symmetric-memory +// handles from cuda_p2p.h) before eager matmul_out. +enum class DistributedMatmulImpl { naiveFusedKernel, baselinePytorchEagerNccl, baselinePytorchEagerCuda @@ -49,6 +62,7 @@ enum class RemoteMatmulImpl { enum class TimeMeasurementMode { CudaEvents, CpuClock }; +// Centralized benchmark knobs used by every implementation path. struct BenchmarkConfig { int64_t warmup_iters; int64_t iters; @@ -56,13 +70,22 @@ struct BenchmarkConfig { bool barrier_at_each_iteration; }; -const char* implName(RemoteMatmulImpl impl) { +// Optional runtime objects required by specific implementations. +struct BenchmarkResources { + c10d::Backend* nccl_backend = nullptr; + std::unique_ptr cuda_hic; + Communication* cuda_allgather_communication = nullptr; + std::unique_ptr cuda_allgather_handle; + at::Tensor a_allgathered_half_cuda; +}; + +const char* implName(DistributedMatmulImpl impl) { switch (impl) { - case RemoteMatmulImpl::naiveFusedKernel: + case DistributedMatmulImpl::naiveFusedKernel: return "naiveFusedKernel"; - case RemoteMatmulImpl::baselinePytorchEagerNccl: + case DistributedMatmulImpl::baselinePytorchEagerNccl: return "baselinePytorchEagerNccl"; - case RemoteMatmulImpl::baselinePytorchEagerCuda: + case DistributedMatmulImpl::baselinePytorchEagerCuda: return "baselinePytorchEagerCuda"; } NVF_ERROR(false, "Unknown implementation enum value: ", static_cast(impl)); @@ -74,6 +97,9 @@ double benchmarkLoopMs( Communicator* communicator, cudaStream_t stream, Fn&& run_once) { + NVF_CHECK(config.iters > 0, "iters must be > 0, got ", config.iters); + + // Warmup segment (not timed). for (int64_t i = 0; i < config.warmup_iters; ++i) { if (config.barrier_at_each_iteration) { communicator->barrier(); @@ -83,16 +109,19 @@ double benchmarkLoopMs( NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + // Timed segment with device-side timestamps. if (config.time_mode == TimeMeasurementMode::CudaEvents) { + // Time each iteration independently so optional barriers can remain outside + // the measured region while preserving per-iteration MAX reduction semantics. + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); float total_ms = 0.0f; for (int64_t i = 0; i < config.iters; ++i) { if (config.barrier_at_each_iteration) { communicator->barrier(); } - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); run_once(); NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); @@ -100,13 +129,14 @@ double benchmarkLoopMs( NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); float iter_ms = 0.0f; NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&iter_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); total_ms += iter_ms; } + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); return static_cast(total_ms) / static_cast(config.iters); } + // Timed segment with host-side timestamps (includes stream sync cost). double total_ms = 0.0; for (int64_t i = 0; i < config.iters; ++i) { if (config.barrier_at_each_iteration) { @@ -124,6 +154,58 @@ double benchmarkLoopMs( return total_ms / static_cast(config.iters); } +BenchmarkResources initBenchmarkResources( + DistributedMatmulImpl impl, + Communicator* communicator, + const Team& all_devices, + int64_t world_size, + int64_t m, + int64_t k) { + BenchmarkResources resources; + // NCCL eager baseline resources. + if (impl == DistributedMatmulImpl::baselinePytorchEagerNccl) { + if (!communicator->isBackendAvailable(CommunicatorBackend::kNccl)) { + return resources; + } + resources.nccl_backend = + communicator->getBackendForTeam(all_devices, CommunicatorBackend::kNccl); + } + + // CUDA backend eager baseline resources (symmetric allgather handle). + if (impl == DistributedMatmulImpl::baselinePytorchEagerCuda) { + resources.cuda_hic = std::make_unique(); + FusionGuard fg(resources.cuda_hic.get()); + auto* in_tv = makeContigTensor(2); + auto* out_tv = makeContigTensor(2); + DeviceMesh mesh = DeviceMesh::createForNumDevices(world_size); + in_tv->setDeviceMesh(mesh); + out_tv->setDeviceMesh(mesh); + resources.cuda_allgather_communication = IrBuilder::create( + CommunicationType::Allgather, + out_tv, + in_tv, + all_devices, + /*root=*/-1, + RedOpType::UNUSED, + CommunicatorBackend::kCuda); + resources.a_allgathered_half_cuda = SymmetricTensor::allocate( + {m, k}, at::ScalarType::Half, communicator->device()); + resources.cuda_allgather_handle = std::make_unique( + resources.cuda_allgather_communication, resources.a_allgathered_half_cuda); + } + return resources; +} + +double reduceMaxTimeMs(Communicator* communicator, double local_ms_per_iter) { + // Reduce per-rank timing with MAX so throughput reflects slowest rank. + at::Tensor max_time_tensor = at::tensor( + {static_cast(local_ms_per_iter)}, + at::TensorOptions().dtype(at::kFloat).device(communicator->device())); + std::vector time_tensors = {max_time_tensor}; + communicator->getWorld()->allreduce(time_tensors, {c10d::ReduceOp::MAX})->wait(); + return static_cast(max_time_tensor.item()); +} + double timeBaselinePytorchEagerMs( c10d::Backend* backend, at::Tensor& a_local_half, @@ -170,7 +252,7 @@ double timeBaselinePytorchEagerCudaMs( } double timeImplementationMs( - RemoteMatmulImpl impl, + DistributedMatmulImpl impl, const BenchmarkConfig& config, Communicator* communicator, c10d::Backend* nccl_backend, @@ -187,8 +269,9 @@ double timeImplementationMs( int64_t k, int64_t m_per_rank, cudaStream_t stream) { + // Dispatch to implementation-specific execution path. switch (impl) { - case RemoteMatmulImpl::naiveFusedKernel: { + case DistributedMatmulImpl::naiveFusedKernel: { auto run_once = [&]() { timeFusedRemoteMatmulMs( device_remote_ptrs, @@ -205,7 +288,7 @@ double timeImplementationMs( }; return benchmarkLoopMs(config, communicator, stream, run_once); } - case RemoteMatmulImpl::baselinePytorchEagerNccl: + case DistributedMatmulImpl::baselinePytorchEagerNccl: NVF_CHECK( nccl_backend != nullptr, "baselinePytorchEagerNccl requires a valid NCCL process group backend."); @@ -217,7 +300,7 @@ double timeImplementationMs( config, communicator, stream); - case RemoteMatmulImpl::baselinePytorchEagerCuda: + case DistributedMatmulImpl::baselinePytorchEagerCuda: NVF_CHECK( cuda_allgather_communication != nullptr && cuda_allgather_handle != nullptr, "baselinePytorchEagerCuda requires initialized CUDA allgather resources."); @@ -238,7 +321,8 @@ double timeImplementationMs( } // namespace class FusedRemoteMatmulTest : public MultiDeviceTest, - public testing::WithParamInterface { + public testing::WithParamInterface< + DistributedMatmulImpl> { protected: static constexpr BenchmarkConfig kBenchmarkConfig = { /*warmup_iters=*/8, @@ -252,7 +336,14 @@ class FusedRemoteMatmulTest : public MultiDeviceTest, /*barrier_at_each_iteration=*/false}; }; +// Benchmark context: +// - A is sharded on M across ranks, B is replicated. +// - We compare three execution paths under identical setup/validation: +// fused remote-pointer kernel, NCCL allgather+eager matmul, and CUDA-backend +// allgather+eager matmul. +// - Rank 0 reports throughput using MAX latency reduced across ranks. TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { + // ---------- Preconditions ---------- if (!communicator_->is_available()) { GTEST_SKIP() << "Communicator is unavailable."; } @@ -264,33 +355,23 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const int64_t my_rank = communicator_->deviceId(); const auto impl = GetParam(); + // ---------- Problem shape ---------- Team all_devices(world_size); std::iota(all_devices.begin(), all_devices.end(), 0); - c10d::Backend* nccl_backend = nullptr; - if (impl == RemoteMatmulImpl::baselinePytorchEagerNccl) { - if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) { - GTEST_SKIP() << "NCCL backend unavailable for baselinePytorchEagerNccl."; - } - nccl_backend = - communicator_->getBackendForTeam(all_devices, CommunicatorBackend::kNccl); - } - std::unique_ptr cuda_hic; - Communication* cuda_allgather_communication = nullptr; - std::unique_ptr cuda_allgather_handle; - constexpr int64_t m = 1024; constexpr int64_t k = 1024; constexpr int64_t n = 1024; NVF_ERROR(m % world_size == 0, "M must be divisible by world size."); const int64_t m_per_rank = m / world_size; + // ---------- Inputs ---------- const auto cpu_float_opts = at::TensorOptions().dtype(at::kFloat).device(at::kCPU); const auto gpu_half_opts = at::TensorOptions().dtype(at::kHalf).device(communicator_->device()); - // Every rank builds identical global inputs from the same seed. + // Deterministic inputs on every rank for fair cross-impl comparison. at::manual_seed(0); at::Tensor a_full_cpu = at::randn({m, k}, cpu_float_opts); at::Tensor b_full_cpu = at::randn({k, n}, cpu_float_opts); @@ -321,6 +402,7 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { world_size * sizeof(__half*), cudaMemcpyHostToDevice)); + // ---------- Outputs and stream ---------- at::Tensor c_out_half = at::zeros({m, n}, gpu_half_opts); at::Tensor a_allgathered_half_cuda; c10::cuda::CUDAStream test_stream = c10::cuda::getStreamFromPool( @@ -329,36 +411,23 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { c10::cuda::CUDAStreamGuard stream_guard(test_stream); cudaStream_t stream = test_stream.stream(); - if (impl == RemoteMatmulImpl::baselinePytorchEagerCuda) { - cuda_hic = std::make_unique(); - FusionGuard fg(cuda_hic.get()); - auto* in_tv = makeContigTensor(2); - auto* out_tv = makeContigTensor(2); - DeviceMesh mesh = DeviceMesh::createForNumDevices(world_size); - in_tv->setDeviceMesh(mesh); - out_tv->setDeviceMesh(mesh); - cuda_allgather_communication = IrBuilder::create( - CommunicationType::Allgather, - out_tv, - in_tv, - all_devices, - /*root=*/-1, - RedOpType::UNUSED, - CommunicatorBackend::kCuda); - a_allgathered_half_cuda = SymmetricTensor::allocate( - {m, k}, at::ScalarType::Half, communicator_->device()); - cuda_allgather_handle = std::make_unique( - cuda_allgather_communication, a_allgathered_half_cuda); + auto resources = + initBenchmarkResources(impl, communicator_, all_devices, world_size, m, k); + if (impl == DistributedMatmulImpl::baselinePytorchEagerNccl && + resources.nccl_backend == nullptr) { + GTEST_SKIP() << "NCCL backend unavailable for baselinePytorchEagerNccl."; } + a_allgathered_half_cuda = resources.a_allgathered_half_cuda; - // Correctness check. + // ---------- Correctness ---------- + // Run once before validation to execute the selected implementation path. (void)timeImplementationMs( impl, kCorrectnessConfig, communicator_, - nccl_backend, - cuda_allgather_communication, - cuda_allgather_handle.get(), + resources.nccl_backend, + resources.cuda_allgather_communication, + resources.cuda_allgather_handle.get(), const_cast(device_remote_ptrs), a_local_half, a_allgathered_half_cuda, @@ -376,14 +445,15 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { EXPECT_TRUE(c_out_cpu.allclose(c_ref_cpu, 2e-1, 2e-1)) << "Fused remote-pointer matmul output mismatch."; + // ---------- Benchmark ---------- communicator_->barrier(); const double local_ms_per_iter = timeImplementationMs( impl, kBenchmarkConfig, communicator_, - nccl_backend, - cuda_allgather_communication, - cuda_allgather_handle.get(), + resources.nccl_backend, + resources.cuda_allgather_communication, + resources.cuda_allgather_handle.get(), const_cast(device_remote_ptrs), a_local_half, a_allgathered_half_cuda, @@ -396,14 +466,11 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { m_per_rank, stream); communicator_->barrier(); + // Distributed throughput is constrained by the slowest rank. + const double global_ms_per_iter = + reduceMaxTimeMs(communicator_, local_ms_per_iter); - at::Tensor max_time_tensor = at::tensor( - {static_cast(local_ms_per_iter)}, - at::TensorOptions().dtype(at::kFloat).device(communicator_->device())); - std::vector time_tensors = {max_time_tensor}; - communicator_->getWorld()->allreduce(time_tensors, {c10d::ReduceOp::MAX})->wait(); - const double global_ms_per_iter = static_cast(max_time_tensor.item()); - + // ---------- Reporting ---------- const double flops = 2.0 * static_cast(m) * static_cast(n) * static_cast(k); const double tflops = flops / (global_ms_per_iter * 1.0e9); @@ -421,10 +488,10 @@ INSTANTIATE_TEST_SUITE_P( , FusedRemoteMatmulTest, testing::Values( - RemoteMatmulImpl::naiveFusedKernel, - RemoteMatmulImpl::baselinePytorchEagerNccl, - RemoteMatmulImpl::baselinePytorchEagerCuda), - [](const testing::TestParamInfo& info) { + DistributedMatmulImpl::naiveFusedKernel, + DistributedMatmulImpl::baselinePytorchEagerNccl, + DistributedMatmulImpl::baselinePytorchEagerCuda), + [](const testing::TestParamInfo& info) { return implName(info.param); }); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 21814b05b3c..bd6cc6d9317 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -14,6 +14,10 @@ namespace nvfuser { namespace { +// Naive fused kernel: +// - A is row-sharded across ranks (axis M) +// - each output row reads from its owner rank shard via remote pointers +// - B is replicated __global__ void fusedRemoteMatmulKernel( const __half* const* a_remote_shards, const __half* b_full, @@ -28,6 +32,7 @@ __global__ void fusedRemoteMatmulKernel( return; } + // Map global row to the rank-local row in that rank's A shard. const int64_t owner_rank = row / m_per_rank; const int64_t local_row = row - owner_rank * m_per_rank; const __half* a_local = a_remote_shards[owner_rank]; @@ -61,9 +66,13 @@ double timeFusedRemoteMatmulMs( static_cast((n + block.x - 1) / block.x), static_cast((m + block.y - 1) / block.y)); - for (int64_t i = 0; i < warmup_iters; ++i) { + auto launch_once = [&]() { fusedRemoteMatmulKernel<<>>( a_remote_shards, b_full, c_out, m, n, k, m_per_rank); + }; + + for (int64_t i = 0; i < warmup_iters; ++i) { + launch_once(); } NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); @@ -75,8 +84,7 @@ double timeFusedRemoteMatmulMs( NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); for (int64_t i = 0; i < iters; ++i) { - fusedRemoteMatmulKernel<<>>( - a_remote_shards, b_full, c_out, m, n, k, m_per_rank); + launch_once(); } NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); From 08e6bb5e10d1d681cebab40f08749339aa02a31c Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 03:06:20 -0800 Subject: [PATCH 07/23] add fused staged kernels without overlap --- .../test_multidevice_fused_remote_matmul.cpp | 165 ++++++++++++++- ..._multidevice_fused_remote_matmul_kernel.cu | 199 +++++++++++++++++- 2 files changed, 361 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 946020c3661..0adfa0e5f6d 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -35,6 +35,38 @@ double timeFusedRemoteMatmulMs( int64_t n, int64_t k, int64_t m_per_rank, + int64_t block_x, + int64_t block_y, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + +double timeSeparatedAllgatherMatmulThreadLoadMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + __half* a_gathered, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + +double timeSeparatedAllgatherMatmulMultimemMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + __half* a_gathered_multicast, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, int64_t warmup_iters, int64_t iters, cudaStream_t stream); @@ -54,14 +86,47 @@ namespace { // Same eager compute structure as the NCCL baseline, but communication uses // nvFuser's CUDA backend allgather primitives (post/wait with symmetric-memory // handles from cuda_p2p.h) before eager matmul_out. +// - stagedAllgatherComputeThreadLoad: +// Single fused kernel with explicit internal stages: +// (1) stage full A-row from remote pointers with regular thread loads, +// (2) compute matmul for that row. +// - stagedAllgatherComputeMultimem: +// Same fused-kernel staged structure, but stage-1 writes use multimem store +// instructions on a multicast pointer before compute. enum class DistributedMatmulImpl { naiveFusedKernel, baselinePytorchEagerNccl, - baselinePytorchEagerCuda + baselinePytorchEagerCuda, + stagedAllgatherComputeThreadLoad, + stagedAllgatherComputeMultimem }; enum class TimeMeasurementMode { CudaEvents, CpuClock }; +// Runtime kernel-launch knobs shared by all implementations. +enum class RuntimeParams { + NaiveBlockX, + NaiveBlockY, + StagedBlockThreads, + StagedGridBlocks +}; + +int64_t runtimeParam(RuntimeParams param) { + switch (param) { + case RuntimeParams::NaiveBlockX: + return 16; + case RuntimeParams::NaiveBlockY: + return 16; + case RuntimeParams::StagedBlockThreads: + return 256; + case RuntimeParams::StagedGridBlocks: + // <= 0 means auto-select from M in the kernel launcher. + return 0; + } + NVF_ERROR(false, "Unknown runtime parameter enum value."); + return 0; +} + // Centralized benchmark knobs used by every implementation path. struct BenchmarkConfig { int64_t warmup_iters; @@ -77,6 +142,9 @@ struct BenchmarkResources { Communication* cuda_allgather_communication = nullptr; std::unique_ptr cuda_allgather_handle; at::Tensor a_allgathered_half_cuda; + at::Tensor a_gathered_threadload; + at::Tensor a_gathered_multimem; + std::unique_ptr a_gathered_multimem_sym; }; const char* implName(DistributedMatmulImpl impl) { @@ -87,10 +155,23 @@ const char* implName(DistributedMatmulImpl impl) { return "baselinePytorchEagerNccl"; case DistributedMatmulImpl::baselinePytorchEagerCuda: return "baselinePytorchEagerCuda"; + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoad: + return "stagedAllgatherComputeThreadLoad"; + case DistributedMatmulImpl::stagedAllgatherComputeMultimem: + return "stagedAllgatherComputeMultimem"; } NVF_ERROR(false, "Unknown implementation enum value: ", static_cast(impl)); } +bool isMulticastSupported(int64_t device_id) { + int is_multicast_supported = 0; + auto result = cuDeviceGetAttribute( + &is_multicast_supported, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, + static_cast(device_id)); + return result == CUDA_SUCCESS && is_multicast_supported != 0; +} + template double benchmarkLoopMs( const BenchmarkConfig& config, @@ -193,6 +274,24 @@ BenchmarkResources initBenchmarkResources( resources.cuda_allgather_handle = std::make_unique( resources.cuda_allgather_communication, resources.a_allgathered_half_cuda); } + + if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoad) { + resources.a_gathered_threadload = at::empty( + {m, k}, + at::TensorOptions() + .dtype(at::kHalf) + .device(communicator->device()) + .layout(at::kStrided)); + } + + if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem) { + resources.a_gathered_multimem = SymmetricTensor::allocate( + {m, k}, at::ScalarType::Half, communicator->device()); + resources.a_gathered_multimem_sym = + std::make_unique(resources.a_gathered_multimem); + resources.a_gathered_multimem_sym->setupMulticast( + /*exporter_rank=*/0, "fused_remote_matmul_staged_multimem"); + } return resources; } @@ -261,6 +360,9 @@ double timeImplementationMs( const __half* const* device_remote_ptrs, at::Tensor& a_local_half, at::Tensor& a_allgathered_half_cuda, + at::Tensor& a_gathered_threadload, + at::Tensor& a_gathered_multimem, + SymmetricTensor* a_gathered_multimem_sym, const at::Tensor& b_full_half, at::Tensor& c_out_half, int64_t world_size, @@ -270,6 +372,13 @@ double timeImplementationMs( int64_t m_per_rank, cudaStream_t stream) { // Dispatch to implementation-specific execution path. + (void)a_gathered_multimem; + const int64_t naive_block_x = runtimeParam(RuntimeParams::NaiveBlockX); + const int64_t naive_block_y = runtimeParam(RuntimeParams::NaiveBlockY); + const int64_t staged_block_threads = + runtimeParam(RuntimeParams::StagedBlockThreads); + const int64_t staged_grid_blocks = + runtimeParam(RuntimeParams::StagedGridBlocks); switch (impl) { case DistributedMatmulImpl::naiveFusedKernel: { auto run_once = [&]() { @@ -282,6 +391,8 @@ double timeImplementationMs( n, k, m_per_rank, + naive_block_x, + naive_block_y, /*warmup_iters=*/0, /*iters=*/1, stream); @@ -314,6 +425,39 @@ double timeImplementationMs( config, communicator, stream); + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoad: + return timeSeparatedAllgatherMatmulThreadLoadMs( + device_remote_ptrs, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + reinterpret_cast<__half*>(a_gathered_threadload.data_ptr()), + m, + n, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); + case DistributedMatmulImpl::stagedAllgatherComputeMultimem: + NVF_CHECK( + a_gathered_multimem_sym != nullptr, + "stagedAllgatherComputeMultimem requires multicast staging tensor."); + return timeSeparatedAllgatherMatmulMultimemMs( + device_remote_ptrs, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + reinterpret_cast<__half*>(a_gathered_multimem_sym->multicastPtr()), + m, + n, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); } NVF_ERROR(false, "Unsupported implementation enum: ", static_cast(impl)); } @@ -355,6 +499,11 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const int64_t my_rank = communicator_->deviceId(); const auto impl = GetParam(); + if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem && + !isMulticastSupported(my_rank)) { + GTEST_SKIP() << "Multicast is not supported on this device."; + } + // ---------- Problem shape ---------- Team all_devices(world_size); std::iota(all_devices.begin(), all_devices.end(), 0); @@ -405,6 +554,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { // ---------- Outputs and stream ---------- at::Tensor c_out_half = at::zeros({m, n}, gpu_half_opts); at::Tensor a_allgathered_half_cuda; + at::Tensor a_gathered_threadload; + at::Tensor a_gathered_multimem; c10::cuda::CUDAStream test_stream = c10::cuda::getStreamFromPool( /*isHighPriority=*/false, static_cast(communicator_->device().index())); @@ -418,6 +569,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { GTEST_SKIP() << "NCCL backend unavailable for baselinePytorchEagerNccl."; } a_allgathered_half_cuda = resources.a_allgathered_half_cuda; + a_gathered_threadload = resources.a_gathered_threadload; + a_gathered_multimem = resources.a_gathered_multimem; // ---------- Correctness ---------- // Run once before validation to execute the selected implementation path. @@ -431,6 +584,9 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const_cast(device_remote_ptrs), a_local_half, a_allgathered_half_cuda, + a_gathered_threadload, + a_gathered_multimem, + resources.a_gathered_multimem_sym.get(), b_full_half, c_out_half, world_size, @@ -457,6 +613,9 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const_cast(device_remote_ptrs), a_local_half, a_allgathered_half_cuda, + a_gathered_threadload, + a_gathered_multimem, + resources.a_gathered_multimem_sym.get(), b_full_half, c_out_half, world_size, @@ -490,7 +649,9 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( DistributedMatmulImpl::naiveFusedKernel, DistributedMatmulImpl::baselinePytorchEagerNccl, - DistributedMatmulImpl::baselinePytorchEagerCuda), + DistributedMatmulImpl::baselinePytorchEagerCuda, + DistributedMatmulImpl::stagedAllgatherComputeThreadLoad, + DistributedMatmulImpl::stagedAllgatherComputeMultimem), [](const testing::TestParamInfo& info) { return implName(info.param); }); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index bd6cc6d9317..4fb3870e14d 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -46,6 +46,100 @@ __global__ void fusedRemoteMatmulKernel( c_out[row * n + col] = __float2half(acc); } +// Fused kernel with explicit internal stages: +// 1) Allgather stage: materialize one full A row from remote shard. +// 2) Compute stage: matmul for that row across all output columns. +__global__ void fusedStagedThreadLoadKernel( + const __half* const* a_remote_shards, + __half* a_gathered, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + const __half* b_full, + __half* c_out) { + for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { + const int64_t owner_rank = row / m_per_rank; + const int64_t local_row = row - owner_rank * m_per_rank; + const __half* a_local = a_remote_shards[owner_rank]; + + // Stage 1: gather this row into staged global buffer. + for (int64_t kk = threadIdx.x; kk < k; kk += blockDim.x) { + a_gathered[row * k + kk] = a_local[local_row * k + kk]; + } + __syncthreads(); + + // Stage 2: compute this row from staged global A. + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { + float acc = 0.0f; + for (int64_t kk = 0; kk < k; ++kk) { + acc += __half2float(a_gathered[row * k + kk]) * + __half2float(b_full[kk * n + col]); + } + c_out[row * n + col] = __float2half(acc); + } + __syncthreads(); + } +} + +// Same fused structure as above, but stage-1 writes use multimem stores. +__global__ void fusedStagedMultimemKernel( + const __half* const* a_remote_shards, + __half* a_gathered_multicast, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + const __half* b_full, + __half* c_out) { + for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { + const int64_t owner_rank = row / m_per_rank; + const int64_t local_row = row - owner_rank * m_per_rank; + const __half* a_local = a_remote_shards[owner_rank]; + __half* a_row_stage = a_gathered_multicast + row * k; + + // Stage 1: gather row into multicast staging buffer via 16-byte vectors. + constexpr int64_t vec_elems = 8; // 8 * half = 16 bytes + const int64_t n_vec = k / vec_elems; + for (int64_t vec_i = threadIdx.x; vec_i < n_vec; vec_i += blockDim.x) { + const uint4 val = + reinterpret_cast(a_local + local_row * k)[vec_i]; + char* dst_byte = reinterpret_cast(a_row_stage) + vec_i * 16; + +#if __CUDA_ARCH__ >= 900 + asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" + : + : "l"((void*)dst_byte), + "f"(__int_as_float(static_cast(val.x))), + "f"(__int_as_float(static_cast(val.y))), + "f"(__int_as_float(static_cast(val.z))), + "f"(__int_as_float(static_cast(val.w))) + : "memory"); +#else + (void)val; + // Multimem path must never run on non-Hopper architectures. + asm volatile("trap;"); +#endif + } + for (int64_t kk = n_vec * vec_elems + threadIdx.x; kk < k; + kk += blockDim.x) { + a_row_stage[kk] = a_local[local_row * k + kk]; + } + __syncthreads(); + + // Stage 2: compute from staged multicast-backed row. + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { + float acc = 0.0f; + for (int64_t kk = 0; kk < k; ++kk) { + acc += + __half2float(a_row_stage[kk]) * __half2float(b_full[kk * n + col]); + } + c_out[row * n + col] = __float2half(acc); + } + __syncthreads(); + } +} + } // namespace double timeFusedRemoteMatmulMs( @@ -57,11 +151,13 @@ double timeFusedRemoteMatmulMs( int64_t n, int64_t k, int64_t m_per_rank, + int64_t block_x, + int64_t block_y, int64_t warmup_iters, int64_t iters, cudaStream_t stream) { (void)world_size; - const dim3 block(16, 16); + const dim3 block(static_cast(block_x), static_cast(block_y)); const dim3 grid( static_cast((n + block.x - 1) / block.x), static_cast((m + block.y - 1) / block.y)); @@ -97,4 +193,105 @@ double timeFusedRemoteMatmulMs( return static_cast(total_ms) / static_cast(iters); } +double timeSeparatedAllgatherMatmulThreadLoadMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + __half* a_gathered, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + + auto launch_once = [&]() { + fusedStagedThreadLoadKernel<<>>( + a_remote_shards, a_gathered, m, n, k, m_per_rank, b_full, c_out); + }; + + for (int64_t i = 0; i < warmup_iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + + float total_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / static_cast(iters); +} + +double timeSeparatedAllgatherMatmulMultimemMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + __half* a_gathered_multicast, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + + auto launch_once = [&]() { + fusedStagedMultimemKernel<<>>( + a_remote_shards, + a_gathered_multicast, + m, + n, + k, + m_per_rank, + b_full, + c_out); + }; + + for (int64_t i = 0; i < warmup_iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + + float total_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / static_cast(iters); +} + } // namespace nvfuser From d5b78ab42d90839bf2260a6f803a6450b960f68a Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 04:26:19 -0800 Subject: [PATCH 08/23] fix race condition in multimem by using semaphores --- .../test_multidevice_fused_remote_matmul.cpp | 59 ++++++++++++++- ..._multidevice_fused_remote_matmul_kernel.cu | 72 +++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 0adfa0e5f6d..acd9ef4c14a 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -61,6 +61,10 @@ double timeSeparatedAllgatherMatmulMultimemMs( const __half* b_full, __half* c_out, __half* a_gathered_multicast, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, int64_t m, int64_t n, int64_t k, @@ -145,6 +149,8 @@ struct BenchmarkResources { at::Tensor a_gathered_threadload; at::Tensor a_gathered_multimem; std::unique_ptr a_gathered_multimem_sym; + at::Tensor stage_semaphore_multimem; + std::unique_ptr stage_semaphore_multimem_sym; }; const char* implName(DistributedMatmulImpl impl) { @@ -291,6 +297,17 @@ BenchmarkResources initBenchmarkResources( std::make_unique(resources.a_gathered_multimem); resources.a_gathered_multimem_sym->setupMulticast( /*exporter_rank=*/0, "fused_remote_matmul_staged_multimem"); + + // Per-rank semaphore rows used by the fused multimem kernel barrier. + // Shape is [writer_rank, row, vec4-int] so each writer can publish one + // epoch per row, and each reader can wait on all writers for that row. + resources.stage_semaphore_multimem = SymmetricTensor::allocate( + {world_size, m, 4}, at::ScalarType::Int, communicator->device()); + resources.stage_semaphore_multimem.zero_(); + resources.stage_semaphore_multimem_sym = + std::make_unique(resources.stage_semaphore_multimem); + resources.stage_semaphore_multimem_sym->setupRemoteHandles( + "fused_remote_matmul_stage_semaphore"); } return resources; } @@ -363,8 +380,11 @@ double timeImplementationMs( at::Tensor& a_gathered_threadload, at::Tensor& a_gathered_multimem, SymmetricTensor* a_gathered_multimem_sym, + SymmetricTensor* stage_semaphore_multimem_sym, + int32_t* const* stage_semaphore_remote_ptrs, const at::Tensor& b_full_half, at::Tensor& c_out_half, + int64_t my_rank, int64_t world_size, int64_t m, int64_t n, @@ -442,13 +462,19 @@ double timeImplementationMs( stream); case DistributedMatmulImpl::stagedAllgatherComputeMultimem: NVF_CHECK( - a_gathered_multimem_sym != nullptr, - "stagedAllgatherComputeMultimem requires multicast staging tensor."); + a_gathered_multimem_sym != nullptr && + stage_semaphore_multimem_sym != nullptr, + "stagedAllgatherComputeMultimem requires staging and semaphore tensors."); return timeSeparatedAllgatherMatmulMultimemMs( device_remote_ptrs, reinterpret_cast(b_full_half.data_ptr()), reinterpret_cast<__half*>(c_out_half.data_ptr()), reinterpret_cast<__half*>(a_gathered_multimem_sym->multicastPtr()), + stage_semaphore_remote_ptrs, + reinterpret_cast( + stage_semaphore_multimem_sym->localTensor().data_ptr()), + my_rank, + world_size, m, n, k, @@ -572,6 +598,26 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_gathered_threadload = resources.a_gathered_threadload; a_gathered_multimem = resources.a_gathered_multimem; + int32_t** device_stage_semaphore_remote_ptrs = nullptr; + if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem) { + NVF_CHECK( + resources.stage_semaphore_multimem_sym != nullptr, + "Missing staged multimem semaphore resources."); + std::vector host_stage_semaphore_remote_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + // Base pointer to that rank's local semaphore tensor in its VA space. + host_stage_semaphore_remote_ptrs[rank] = reinterpret_cast( + resources.stage_semaphore_multimem_sym->remoteTensor(rank).data_ptr()); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc( + &device_stage_semaphore_remote_ptrs, world_size * sizeof(int32_t*))); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + device_stage_semaphore_remote_ptrs, + host_stage_semaphore_remote_ptrs.data(), + world_size * sizeof(int32_t*), + cudaMemcpyHostToDevice)); + } + // ---------- Correctness ---------- // Run once before validation to execute the selected implementation path. (void)timeImplementationMs( @@ -587,8 +633,11 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_gathered_threadload, a_gathered_multimem, resources.a_gathered_multimem_sym.get(), + resources.stage_semaphore_multimem_sym.get(), + const_cast(device_stage_semaphore_remote_ptrs), b_full_half, c_out_half, + my_rank, world_size, m, n, @@ -616,8 +665,11 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_gathered_threadload, a_gathered_multimem, resources.a_gathered_multimem_sym.get(), + resources.stage_semaphore_multimem_sym.get(), + const_cast(device_stage_semaphore_remote_ptrs), b_full_half, c_out_half, + my_rank, world_size, m, n, @@ -641,6 +693,9 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { } NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(device_remote_ptrs)); + if (device_stage_semaphore_remote_ptrs != nullptr) { + NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(device_stage_semaphore_remote_ptrs)); + } } INSTANTIATE_TEST_SUITE_P( diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 4fb3870e14d..e8acc4a1c4c 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -7,6 +7,7 @@ // clang-format on #include +#include #include "cuda_utils.h" @@ -86,6 +87,11 @@ __global__ void fusedStagedThreadLoadKernel( __global__ void fusedStagedMultimemKernel( const __half* const* a_remote_shards, __half* a_gathered_multicast, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, + int32_t launch_epoch_base, int64_t m, int64_t n, int64_t k, @@ -127,6 +133,58 @@ __global__ void fusedStagedMultimemKernel( } __syncthreads(); +#if __CUDA_ARCH__ >= 900 + // Cross-device barrier between stage-1 stores and stage-2 reads. + // Each rank publishes one epoch for (my_rank,row), then waits until + // all writer ranks have published the same epoch for that row. + constexpr int64_t kSemaphoreVecWidth = 4; + constexpr int64_t kMaxPollIters = 1LL << 26; + const int32_t launch_epoch = launch_epoch_base + 1; + int32_t* my_row_local = stage_semaphore_local + + (my_rank * m + row) * kSemaphoreVecWidth; + + if (threadIdx.x == 0) { + // Publish local completion to all peers' semaphore tensors. + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + my_row_local[vec_i] = launch_epoch; + } + __threadfence_system(); + for (int64_t peer = 0; peer < world_size; ++peer) { + int32_t* peer_row_remote = stage_semaphore_remote_ptrs[peer] + + (my_rank * m + row) * kSemaphoreVecWidth; + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + peer_row_remote[vec_i] = launch_epoch; + } + } + __threadfence_system(); + } + __syncthreads(); + + if (threadIdx.x == 0) { + // Wait until every writer rank has published this row's epoch. + for (int64_t rank = 0; rank < world_size; ++rank) { + auto* rank_epoch_ptr = reinterpret_cast( + stage_semaphore_local + (rank * m + row) * kSemaphoreVecWidth); + int64_t spins = 0; + while (atomicAdd(rank_epoch_ptr, 0U) < + static_cast(launch_epoch)) { + ++spins; + if (spins > kMaxPollIters) { + asm volatile("trap;"); + } + } + } + } + __syncthreads(); +#else + (void)stage_semaphore_remote_ptrs; + (void)stage_semaphore_local; + (void)my_rank; + (void)world_size; + (void)launch_epoch_base; + asm volatile("trap;"); +#endif + // Stage 2: compute from staged multicast-backed row. for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { float acc = 0.0f; @@ -245,6 +303,10 @@ double timeSeparatedAllgatherMatmulMultimemMs( const __half* b_full, __half* c_out, __half* a_gathered_multicast, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, int64_t m, int64_t n, int64_t k, @@ -256,17 +318,27 @@ double timeSeparatedAllgatherMatmulMultimemMs( cudaStream_t stream) { const dim3 block(static_cast(block_threads)); const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + int64_t launch_epoch_base = 0; auto launch_once = [&]() { + NVF_CHECK( + launch_epoch_base < std::numeric_limits::max(), + "Multimem semaphore epoch overflow."); fusedStagedMultimemKernel<<>>( a_remote_shards, a_gathered_multicast, + stage_semaphore_remote_ptrs, + stage_semaphore_local, + my_rank, + world_size, + static_cast(launch_epoch_base), m, n, k, m_per_rank, b_full, c_out); + ++launch_epoch_base; }; for (int64_t i = 0; i < warmup_iters; ++i) { From b9240f7e46fe4b9bd112382ddc1bfed0096d4ea7 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 04:37:11 -0800 Subject: [PATCH 09/23] fix race condition in entering all kernels by using semaphores --- .../test_multidevice_fused_remote_matmul.cpp | 136 ++++++++++++- ..._multidevice_fused_remote_matmul_kernel.cu | 185 ++++++++++++++++++ 2 files changed, 320 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index acd9ef4c14a..37402033878 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -56,6 +56,27 @@ double timeSeparatedAllgatherMatmulThreadLoadMs( int64_t iters, cudaStream_t stream); +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + __half* a_gathered, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + double timeSeparatedAllgatherMatmulMultimemMs( const __half* const* a_remote_shards, const __half* b_full, @@ -94,6 +115,9 @@ namespace { // Single fused kernel with explicit internal stages: // (1) stage full A-row from remote pointers with regular thread loads, // (2) compute matmul for that row. +// - stagedAllgatherComputeThreadLoadSynchronized: +// Same thread-load staged kernel as above, but with fused cross-rank ready/done +// semaphores so remote reads are safe even when producer shards are mutable. // - stagedAllgatherComputeMultimem: // Same fused-kernel staged structure, but stage-1 writes use multimem store // instructions on a multicast pointer before compute. @@ -102,6 +126,7 @@ enum class DistributedMatmulImpl { baselinePytorchEagerNccl, baselinePytorchEagerCuda, stagedAllgatherComputeThreadLoad, + stagedAllgatherComputeThreadLoadSynchronized, stagedAllgatherComputeMultimem }; @@ -151,6 +176,10 @@ struct BenchmarkResources { std::unique_ptr a_gathered_multimem_sym; at::Tensor stage_semaphore_multimem; std::unique_ptr stage_semaphore_multimem_sym; + at::Tensor threadload_ready_semaphore; + std::unique_ptr threadload_ready_semaphore_sym; + at::Tensor threadload_done_semaphore; + std::unique_ptr threadload_done_semaphore_sym; }; const char* implName(DistributedMatmulImpl impl) { @@ -163,6 +192,8 @@ const char* implName(DistributedMatmulImpl impl) { return "baselinePytorchEagerCuda"; case DistributedMatmulImpl::stagedAllgatherComputeThreadLoad: return "stagedAllgatherComputeThreadLoad"; + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized: + return "stagedAllgatherComputeThreadLoadSynchronized"; case DistributedMatmulImpl::stagedAllgatherComputeMultimem: return "stagedAllgatherComputeMultimem"; } @@ -281,7 +312,8 @@ BenchmarkResources initBenchmarkResources( resources.cuda_allgather_communication, resources.a_allgathered_half_cuda); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoad) { + if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoad || + impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized) { resources.a_gathered_threadload = at::empty( {m, k}, at::TensorOptions() @@ -290,6 +322,25 @@ BenchmarkResources initBenchmarkResources( .layout(at::kStrided)); } + if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized) { + // Per-rank [writer_rank, row, vec4-int] semaphores for fused ready/done. + resources.threadload_ready_semaphore = SymmetricTensor::allocate( + {world_size, m, 4}, at::ScalarType::Int, communicator->device()); + resources.threadload_ready_semaphore.zero_(); + resources.threadload_ready_semaphore_sym = std::make_unique( + resources.threadload_ready_semaphore); + resources.threadload_ready_semaphore_sym->setupRemoteHandles( + "fused_remote_matmul_threadload_ready"); + + resources.threadload_done_semaphore = SymmetricTensor::allocate( + {world_size, m, 4}, at::ScalarType::Int, communicator->device()); + resources.threadload_done_semaphore.zero_(); + resources.threadload_done_semaphore_sym = std::make_unique( + resources.threadload_done_semaphore); + resources.threadload_done_semaphore_sym->setupRemoteHandles( + "fused_remote_matmul_threadload_done"); + } + if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem) { resources.a_gathered_multimem = SymmetricTensor::allocate( {m, k}, at::ScalarType::Half, communicator->device()); @@ -379,8 +430,12 @@ double timeImplementationMs( at::Tensor& a_allgathered_half_cuda, at::Tensor& a_gathered_threadload, at::Tensor& a_gathered_multimem, + SymmetricTensor* threadload_ready_semaphore_sym, + SymmetricTensor* threadload_done_semaphore_sym, SymmetricTensor* a_gathered_multimem_sym, SymmetricTensor* stage_semaphore_multimem_sym, + int32_t* const* threadload_ready_semaphore_remote_ptrs, + int32_t* const* threadload_done_semaphore_remote_ptrs, int32_t* const* stage_semaphore_remote_ptrs, const at::Tensor& b_full_half, at::Tensor& c_out_half, @@ -460,6 +515,35 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized: + NVF_CHECK( + threadload_ready_semaphore_sym != nullptr && + threadload_done_semaphore_sym != nullptr && + threadload_ready_semaphore_remote_ptrs != nullptr && + threadload_done_semaphore_remote_ptrs != nullptr, + "stagedAllgatherComputeThreadLoadSynchronized requires semaphore resources."); + return timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( + device_remote_ptrs, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + reinterpret_cast<__half*>(a_gathered_threadload.data_ptr()), + threadload_ready_semaphore_remote_ptrs, + reinterpret_cast( + threadload_ready_semaphore_sym->localTensor().data_ptr()), + threadload_done_semaphore_remote_ptrs, + reinterpret_cast( + threadload_done_semaphore_sym->localTensor().data_ptr()), + my_rank, + world_size, + m, + n, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::stagedAllgatherComputeMultimem: NVF_CHECK( a_gathered_multimem_sym != nullptr && @@ -599,6 +683,39 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_gathered_multimem = resources.a_gathered_multimem; int32_t** device_stage_semaphore_remote_ptrs = nullptr; + int32_t** device_threadload_ready_semaphore_remote_ptrs = nullptr; + int32_t** device_threadload_done_semaphore_remote_ptrs = nullptr; + if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized) { + NVF_CHECK( + resources.threadload_ready_semaphore_sym != nullptr && + resources.threadload_done_semaphore_sym != nullptr, + "Missing synchronized threadload semaphore resources."); + std::vector host_ready_remote_ptrs(world_size); + std::vector host_done_remote_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + host_ready_remote_ptrs[rank] = reinterpret_cast( + resources.threadload_ready_semaphore_sym->remoteTensor(rank).data_ptr()); + host_done_remote_ptrs[rank] = reinterpret_cast( + resources.threadload_done_semaphore_sym->remoteTensor(rank).data_ptr()); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc( + &device_threadload_ready_semaphore_remote_ptrs, + world_size * sizeof(int32_t*))); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + device_threadload_ready_semaphore_remote_ptrs, + host_ready_remote_ptrs.data(), + world_size * sizeof(int32_t*), + cudaMemcpyHostToDevice)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc( + &device_threadload_done_semaphore_remote_ptrs, + world_size * sizeof(int32_t*))); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + device_threadload_done_semaphore_remote_ptrs, + host_done_remote_ptrs.data(), + world_size * sizeof(int32_t*), + cudaMemcpyHostToDevice)); + } + if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem) { NVF_CHECK( resources.stage_semaphore_multimem_sym != nullptr, @@ -632,8 +749,12 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_allgathered_half_cuda, a_gathered_threadload, a_gathered_multimem, + resources.threadload_ready_semaphore_sym.get(), + resources.threadload_done_semaphore_sym.get(), resources.a_gathered_multimem_sym.get(), resources.stage_semaphore_multimem_sym.get(), + const_cast(device_threadload_ready_semaphore_remote_ptrs), + const_cast(device_threadload_done_semaphore_remote_ptrs), const_cast(device_stage_semaphore_remote_ptrs), b_full_half, c_out_half, @@ -664,8 +785,12 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_allgathered_half_cuda, a_gathered_threadload, a_gathered_multimem, + resources.threadload_ready_semaphore_sym.get(), + resources.threadload_done_semaphore_sym.get(), resources.a_gathered_multimem_sym.get(), resources.stage_semaphore_multimem_sym.get(), + const_cast(device_threadload_ready_semaphore_remote_ptrs), + const_cast(device_threadload_done_semaphore_remote_ptrs), const_cast(device_stage_semaphore_remote_ptrs), b_full_half, c_out_half, @@ -696,6 +821,14 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { if (device_stage_semaphore_remote_ptrs != nullptr) { NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(device_stage_semaphore_remote_ptrs)); } + if (device_threadload_ready_semaphore_remote_ptrs != nullptr) { + NVFUSER_CUDA_RT_SAFE_CALL( + cudaFree(device_threadload_ready_semaphore_remote_ptrs)); + } + if (device_threadload_done_semaphore_remote_ptrs != nullptr) { + NVFUSER_CUDA_RT_SAFE_CALL( + cudaFree(device_threadload_done_semaphore_remote_ptrs)); + } } INSTANTIATE_TEST_SUITE_P( @@ -706,6 +839,7 @@ INSTANTIATE_TEST_SUITE_P( DistributedMatmulImpl::baselinePytorchEagerNccl, DistributedMatmulImpl::baselinePytorchEagerCuda, DistributedMatmulImpl::stagedAllgatherComputeThreadLoad, + DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized, DistributedMatmulImpl::stagedAllgatherComputeMultimem), [](const testing::TestParamInfo& info) { return implName(info.param); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index e8acc4a1c4c..8e67fc1baca 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -83,6 +83,119 @@ __global__ void fusedStagedThreadLoadKernel( } } +__global__ void fusedStagedThreadLoadSynchronizedKernel( + const __half* const* a_remote_shards, + __half* a_gathered, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + int32_t launch_epoch_base, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + const __half* b_full, + __half* c_out) { + constexpr int64_t kSemaphoreVecWidth = 4; + constexpr int64_t kMaxPollIters = 1LL << 26; + const int32_t launch_epoch = launch_epoch_base + 1; + for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { + const int64_t owner_rank = row / m_per_rank; + const int64_t local_row = row - owner_rank * m_per_rank; + const __half* a_local = a_remote_shards[owner_rank]; + int32_t* my_ready_row_local = + ready_semaphore_local + (my_rank * m + row) * kSemaphoreVecWidth; + int32_t* my_done_row_local = + done_semaphore_local + (my_rank * m + row) * kSemaphoreVecWidth; + + if (threadIdx.x == 0) { + // Publish ready epoch to all ranks before remote reads. + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + my_ready_row_local[vec_i] = launch_epoch; + } + __threadfence_system(); + for (int64_t peer = 0; peer < world_size; ++peer) { + int32_t* peer_ready_row_remote = ready_semaphore_remote_ptrs[peer] + + (my_rank * m + row) * kSemaphoreVecWidth; + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + peer_ready_row_remote[vec_i] = launch_epoch; + } + } + __threadfence_system(); + } + __syncthreads(); + + if (threadIdx.x == 0) { + for (int64_t rank = 0; rank < world_size; ++rank) { + auto* rank_ready_epoch_ptr = reinterpret_cast( + ready_semaphore_local + (rank * m + row) * kSemaphoreVecWidth); + int64_t spins = 0; + while (atomicAdd(rank_ready_epoch_ptr, 0U) < + static_cast(launch_epoch)) { + ++spins; + if (spins > kMaxPollIters) { + asm volatile("trap;"); + } + } + } + } + __syncthreads(); + + // Stage 1: gather this row into staged global buffer. + for (int64_t kk = threadIdx.x; kk < k; kk += blockDim.x) { + a_gathered[row * k + kk] = a_local[local_row * k + kk]; + } + __syncthreads(); + + // Stage 2: compute this row from staged global A. + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { + float acc = 0.0f; + for (int64_t kk = 0; kk < k; ++kk) { + acc += __half2float(a_gathered[row * k + kk]) * + __half2float(b_full[kk * n + col]); + } + c_out[row * n + col] = __float2half(acc); + } + __syncthreads(); + + if (threadIdx.x == 0) { + // Publish done epoch so producers can safely mutate next iteration. + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + my_done_row_local[vec_i] = launch_epoch; + } + __threadfence_system(); + for (int64_t peer = 0; peer < world_size; ++peer) { + int32_t* peer_done_row_remote = done_semaphore_remote_ptrs[peer] + + (my_rank * m + row) * kSemaphoreVecWidth; + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + peer_done_row_remote[vec_i] = launch_epoch; + } + } + __threadfence_system(); + } + __syncthreads(); + + if (threadIdx.x == 0) { + for (int64_t rank = 0; rank < world_size; ++rank) { + auto* rank_done_epoch_ptr = reinterpret_cast( + done_semaphore_local + (rank * m + row) * kSemaphoreVecWidth); + int64_t spins = 0; + while (atomicAdd(rank_done_epoch_ptr, 0U) < + static_cast(launch_epoch)) { + ++spins; + if (spins > kMaxPollIters) { + asm volatile("trap;"); + } + } + } + } + __syncthreads(); + } +} + // Same fused structure as above, but stage-1 writes use multimem stores. __global__ void fusedStagedMultimemKernel( const __half* const* a_remote_shards, @@ -298,6 +411,78 @@ double timeSeparatedAllgatherMatmulThreadLoadMs( return static_cast(total_ms) / static_cast(iters); } +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + __half* a_gathered, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + int64_t launch_epoch_base = 0; + + auto launch_once = [&]() { + NVF_CHECK( + launch_epoch_base < std::numeric_limits::max(), + "ThreadLoad synchronized semaphore epoch overflow."); + fusedStagedThreadLoadSynchronizedKernel<<>>( + a_remote_shards, + a_gathered, + ready_semaphore_remote_ptrs, + ready_semaphore_local, + done_semaphore_remote_ptrs, + done_semaphore_local, + my_rank, + world_size, + static_cast(launch_epoch_base), + m, + n, + k, + m_per_rank, + b_full, + c_out); + ++launch_epoch_base; + }; + + for (int64_t i = 0; i < warmup_iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + + float total_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / static_cast(iters); +} + double timeSeparatedAllgatherMatmulMultimemMs( const __half* const* a_remote_shards, const __half* b_full, From 70b7ff0be030443646a7b2f75ebfeaf3c9620af6 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 04:59:39 -0800 Subject: [PATCH 10/23] refactor --- csrc/multidevice/symmetric_tensor.cpp | 23 ++ csrc/multidevice/symmetric_tensor.h | 3 + .../test_multidevice_fused_remote_matmul.cpp | 169 +++------- ..._multidevice_fused_remote_matmul_kernel.cu | 294 +++++++----------- 4 files changed, 185 insertions(+), 304 deletions(-) diff --git a/csrc/multidevice/symmetric_tensor.cpp b/csrc/multidevice/symmetric_tensor.cpp index 902ec5a3a32..a495b7e82dc 100644 --- a/csrc/multidevice/symmetric_tensor.cpp +++ b/csrc/multidevice/symmetric_tensor.cpp @@ -253,6 +253,11 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor) } SymmetricTensor::~SymmetricTensor() { + if (device_peer_ptrs_ != nullptr) { + cudaFree(device_peer_ptrs_); + device_peer_ptrs_ = nullptr; + } + #if (CUDA_VERSION >= 13000) if (is_multicast_setup_) { if (mc_base_ptr_) { @@ -389,6 +394,24 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const { .device(at::kCUDA, rank)); } +void** SymmetricTensor::devicePeerPointers() const { + NVF_CHECK(are_remote_tensors_setup_ == true, "Remote tensors not setup"); + if (device_peer_ptrs_ == nullptr) { + std::vector host_peer_ptrs(world_size_); + for (int64_t rank = 0; rank < world_size_; ++rank) { + host_peer_ptrs[rank] = reinterpret_cast(remote_ptrs_[rank]); + } + NVFUSER_CUDA_RT_SAFE_CALL( + cudaMalloc(&device_peer_ptrs_, world_size_ * sizeof(void*))); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + device_peer_ptrs_, + host_peer_ptrs.data(), + world_size_ * sizeof(void*), + cudaMemcpyHostToDevice)); + } + return device_peer_ptrs_; +} + void* SymmetricTensor::multicastPtr() const { NVF_CHECK(is_multicast_setup_, "Multicast not setup"); return mc_ptr_; diff --git a/csrc/multidevice/symmetric_tensor.h b/csrc/multidevice/symmetric_tensor.h index 5608153e0ce..46ac7da1c7c 100644 --- a/csrc/multidevice/symmetric_tensor.h +++ b/csrc/multidevice/symmetric_tensor.h @@ -51,6 +51,8 @@ class SymmetricTensor { // Setup remote access (lazy, init-once) void setupRemoteHandles(const std::string& tag = ""); at::Tensor remoteTensor(int64_t rank) const; + // Returns a device pointer table of peer pointers (void** on device). + void** devicePeerPointers() const; // Setup multicast (CUDA 13.0+, init-once) void setupMulticast(int64_t exporter_rank, const std::string& tag = ""); @@ -79,6 +81,7 @@ class SymmetricTensor { int peer_fd_{-1}; bool is_contiguous_view_setup_ = false; at::Tensor contiguous_view_; + mutable void** device_peer_ptrs_ = nullptr; }; } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 37402033878..ca85f894466 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -646,20 +646,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { SymmetricTensor symmetric_a(a_local_sym); symmetric_a.setupRemoteHandles("fused_remote_matmul_a"); - std::vector host_remote_ptrs(world_size); - for (int64_t rank = 0; rank < world_size; ++rank) { - host_remote_ptrs[rank] = - reinterpret_cast(symmetric_a.remoteTensor(rank).data_ptr()); - } - - __half** device_remote_ptrs = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL( - cudaMalloc(&device_remote_ptrs, world_size * sizeof(__half*))); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - device_remote_ptrs, - host_remote_ptrs.data(), - world_size * sizeof(__half*), - cudaMemcpyHostToDevice)); + const __half* const* device_remote_ptrs = + reinterpret_cast(symmetric_a.devicePeerPointers()); // ---------- Outputs and stream ---------- at::Tensor c_out_half = at::zeros({m, n}, gpu_half_opts); @@ -682,89 +670,65 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_gathered_threadload = resources.a_gathered_threadload; a_gathered_multimem = resources.a_gathered_multimem; - int32_t** device_stage_semaphore_remote_ptrs = nullptr; - int32_t** device_threadload_ready_semaphore_remote_ptrs = nullptr; - int32_t** device_threadload_done_semaphore_remote_ptrs = nullptr; + int32_t* const* device_stage_semaphore_remote_ptrs = nullptr; + int32_t* const* device_threadload_ready_semaphore_remote_ptrs = nullptr; + int32_t* const* device_threadload_done_semaphore_remote_ptrs = nullptr; if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized) { NVF_CHECK( resources.threadload_ready_semaphore_sym != nullptr && resources.threadload_done_semaphore_sym != nullptr, "Missing synchronized threadload semaphore resources."); - std::vector host_ready_remote_ptrs(world_size); - std::vector host_done_remote_ptrs(world_size); - for (int64_t rank = 0; rank < world_size; ++rank) { - host_ready_remote_ptrs[rank] = reinterpret_cast( - resources.threadload_ready_semaphore_sym->remoteTensor(rank).data_ptr()); - host_done_remote_ptrs[rank] = reinterpret_cast( - resources.threadload_done_semaphore_sym->remoteTensor(rank).data_ptr()); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc( - &device_threadload_ready_semaphore_remote_ptrs, - world_size * sizeof(int32_t*))); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - device_threadload_ready_semaphore_remote_ptrs, - host_ready_remote_ptrs.data(), - world_size * sizeof(int32_t*), - cudaMemcpyHostToDevice)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc( - &device_threadload_done_semaphore_remote_ptrs, - world_size * sizeof(int32_t*))); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - device_threadload_done_semaphore_remote_ptrs, - host_done_remote_ptrs.data(), - world_size * sizeof(int32_t*), - cudaMemcpyHostToDevice)); + device_threadload_ready_semaphore_remote_ptrs = + reinterpret_cast( + resources.threadload_ready_semaphore_sym->devicePeerPointers()); + device_threadload_done_semaphore_remote_ptrs = + reinterpret_cast( + resources.threadload_done_semaphore_sym->devicePeerPointers()); } if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem) { NVF_CHECK( resources.stage_semaphore_multimem_sym != nullptr, "Missing staged multimem semaphore resources."); - std::vector host_stage_semaphore_remote_ptrs(world_size); - for (int64_t rank = 0; rank < world_size; ++rank) { - // Base pointer to that rank's local semaphore tensor in its VA space. - host_stage_semaphore_remote_ptrs[rank] = reinterpret_cast( - resources.stage_semaphore_multimem_sym->remoteTensor(rank).data_ptr()); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc( - &device_stage_semaphore_remote_ptrs, world_size * sizeof(int32_t*))); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - device_stage_semaphore_remote_ptrs, - host_stage_semaphore_remote_ptrs.data(), - world_size * sizeof(int32_t*), - cudaMemcpyHostToDevice)); + device_stage_semaphore_remote_ptrs = + reinterpret_cast( + resources.stage_semaphore_multimem_sym->devicePeerPointers()); } + auto run_implementation = [&](const BenchmarkConfig& config) { + return timeImplementationMs( + impl, + config, + communicator_, + resources.nccl_backend, + resources.cuda_allgather_communication, + resources.cuda_allgather_handle.get(), + device_remote_ptrs, + a_local_half, + a_allgathered_half_cuda, + a_gathered_threadload, + a_gathered_multimem, + resources.threadload_ready_semaphore_sym.get(), + resources.threadload_done_semaphore_sym.get(), + resources.a_gathered_multimem_sym.get(), + resources.stage_semaphore_multimem_sym.get(), + device_threadload_ready_semaphore_remote_ptrs, + device_threadload_done_semaphore_remote_ptrs, + device_stage_semaphore_remote_ptrs, + b_full_half, + c_out_half, + my_rank, + world_size, + m, + n, + k, + m_per_rank, + stream); + }; + // ---------- Correctness ---------- // Run once before validation to execute the selected implementation path. - (void)timeImplementationMs( - impl, - kCorrectnessConfig, - communicator_, - resources.nccl_backend, - resources.cuda_allgather_communication, - resources.cuda_allgather_handle.get(), - const_cast(device_remote_ptrs), - a_local_half, - a_allgathered_half_cuda, - a_gathered_threadload, - a_gathered_multimem, - resources.threadload_ready_semaphore_sym.get(), - resources.threadload_done_semaphore_sym.get(), - resources.a_gathered_multimem_sym.get(), - resources.stage_semaphore_multimem_sym.get(), - const_cast(device_threadload_ready_semaphore_remote_ptrs), - const_cast(device_threadload_done_semaphore_remote_ptrs), - const_cast(device_stage_semaphore_remote_ptrs), - b_full_half, - c_out_half, - my_rank, - world_size, - m, - n, - k, - m_per_rank, - stream); + (void)run_implementation(kCorrectnessConfig); at::Tensor c_ref_cpu = at::matmul(a_full_cpu, b_full_cpu); at::Tensor c_out_cpu = c_out_half.cpu().to(at::kFloat); @@ -773,34 +737,7 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { // ---------- Benchmark ---------- communicator_->barrier(); - const double local_ms_per_iter = timeImplementationMs( - impl, - kBenchmarkConfig, - communicator_, - resources.nccl_backend, - resources.cuda_allgather_communication, - resources.cuda_allgather_handle.get(), - const_cast(device_remote_ptrs), - a_local_half, - a_allgathered_half_cuda, - a_gathered_threadload, - a_gathered_multimem, - resources.threadload_ready_semaphore_sym.get(), - resources.threadload_done_semaphore_sym.get(), - resources.a_gathered_multimem_sym.get(), - resources.stage_semaphore_multimem_sym.get(), - const_cast(device_threadload_ready_semaphore_remote_ptrs), - const_cast(device_threadload_done_semaphore_remote_ptrs), - const_cast(device_stage_semaphore_remote_ptrs), - b_full_half, - c_out_half, - my_rank, - world_size, - m, - n, - k, - m_per_rank, - stream); + const double local_ms_per_iter = run_implementation(kBenchmarkConfig); communicator_->barrier(); // Distributed throughput is constrained by the slowest rank. const double global_ms_per_iter = @@ -817,18 +754,6 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { << " ms/iter, " << tflops << " TFLOP/s" << std::endl; } - NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(device_remote_ptrs)); - if (device_stage_semaphore_remote_ptrs != nullptr) { - NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(device_stage_semaphore_remote_ptrs)); - } - if (device_threadload_ready_semaphore_remote_ptrs != nullptr) { - NVFUSER_CUDA_RT_SAFE_CALL( - cudaFree(device_threadload_ready_semaphore_remote_ptrs)); - } - if (device_threadload_done_semaphore_remote_ptrs != nullptr) { - NVFUSER_CUDA_RT_SAFE_CALL( - cudaFree(device_threadload_done_semaphore_remote_ptrs)); - } } INSTANTIATE_TEST_SUITE_P( diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 8e67fc1baca..4b8307b6375 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -15,6 +15,83 @@ namespace nvfuser { namespace { +constexpr int64_t kSemaphoreVecWidth = 4; +constexpr int64_t kMaxSemaphorePollIters = 1LL << 26; + +__device__ inline void publishEpochToAllRanks( + int32_t* const* remote_semaphore_ptrs, + int32_t* local_semaphore, + int64_t writer_rank, + int64_t row, + int64_t m, + int64_t world_size, + int32_t epoch) { + int32_t* my_row_local = + local_semaphore + (writer_rank * m + row) * kSemaphoreVecWidth; + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + my_row_local[vec_i] = epoch; + } + __threadfence_system(); + for (int64_t peer = 0; peer < world_size; ++peer) { + int32_t* peer_row_remote = + remote_semaphore_ptrs[peer] + (writer_rank * m + row) * kSemaphoreVecWidth; + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + peer_row_remote[vec_i] = epoch; + } + } + __threadfence_system(); +} + +__device__ inline void waitForEpochFromAllRanks( + int32_t* local_semaphore, + int64_t row, + int64_t m, + int64_t world_size, + int32_t epoch) { + for (int64_t rank = 0; rank < world_size; ++rank) { + auto* rank_epoch_ptr = reinterpret_cast( + local_semaphore + (rank * m + row) * kSemaphoreVecWidth); + int64_t spins = 0; + while (atomicAdd(rank_epoch_ptr, 0U) < static_cast(epoch)) { + ++spins; + if (spins > kMaxSemaphorePollIters) { + asm volatile("trap;"); + } + } + } +} + +template +double timeKernelLaunchesMs( + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream, + LaunchFn&& launch_once) { + for (int64_t i = 0; i < warmup_iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) { + launch_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + + float total_ms = 0.0f; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / static_cast(iters); +} + // Naive fused kernel: // - A is row-sharded across ranks (axis M) // - each output row reads from its owner rank shard via remote pointers @@ -99,48 +176,27 @@ __global__ void fusedStagedThreadLoadSynchronizedKernel( int64_t m_per_rank, const __half* b_full, __half* c_out) { - constexpr int64_t kSemaphoreVecWidth = 4; - constexpr int64_t kMaxPollIters = 1LL << 26; const int32_t launch_epoch = launch_epoch_base + 1; for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { const int64_t owner_rank = row / m_per_rank; const int64_t local_row = row - owner_rank * m_per_rank; const __half* a_local = a_remote_shards[owner_rank]; - int32_t* my_ready_row_local = - ready_semaphore_local + (my_rank * m + row) * kSemaphoreVecWidth; - int32_t* my_done_row_local = - done_semaphore_local + (my_rank * m + row) * kSemaphoreVecWidth; - if (threadIdx.x == 0) { // Publish ready epoch to all ranks before remote reads. - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - my_ready_row_local[vec_i] = launch_epoch; - } - __threadfence_system(); - for (int64_t peer = 0; peer < world_size; ++peer) { - int32_t* peer_ready_row_remote = ready_semaphore_remote_ptrs[peer] + - (my_rank * m + row) * kSemaphoreVecWidth; - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - peer_ready_row_remote[vec_i] = launch_epoch; - } - } - __threadfence_system(); + publishEpochToAllRanks( + ready_semaphore_remote_ptrs, + ready_semaphore_local, + my_rank, + row, + m, + world_size, + launch_epoch); } __syncthreads(); if (threadIdx.x == 0) { - for (int64_t rank = 0; rank < world_size; ++rank) { - auto* rank_ready_epoch_ptr = reinterpret_cast( - ready_semaphore_local + (rank * m + row) * kSemaphoreVecWidth); - int64_t spins = 0; - while (atomicAdd(rank_ready_epoch_ptr, 0U) < - static_cast(launch_epoch)) { - ++spins; - if (spins > kMaxPollIters) { - asm volatile("trap;"); - } - } - } + waitForEpochFromAllRanks( + ready_semaphore_local, row, m, world_size, launch_epoch); } __syncthreads(); @@ -163,34 +219,20 @@ __global__ void fusedStagedThreadLoadSynchronizedKernel( if (threadIdx.x == 0) { // Publish done epoch so producers can safely mutate next iteration. - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - my_done_row_local[vec_i] = launch_epoch; - } - __threadfence_system(); - for (int64_t peer = 0; peer < world_size; ++peer) { - int32_t* peer_done_row_remote = done_semaphore_remote_ptrs[peer] + - (my_rank * m + row) * kSemaphoreVecWidth; - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - peer_done_row_remote[vec_i] = launch_epoch; - } - } - __threadfence_system(); + publishEpochToAllRanks( + done_semaphore_remote_ptrs, + done_semaphore_local, + my_rank, + row, + m, + world_size, + launch_epoch); } __syncthreads(); if (threadIdx.x == 0) { - for (int64_t rank = 0; rank < world_size; ++rank) { - auto* rank_done_epoch_ptr = reinterpret_cast( - done_semaphore_local + (rank * m + row) * kSemaphoreVecWidth); - int64_t spins = 0; - while (atomicAdd(rank_done_epoch_ptr, 0U) < - static_cast(launch_epoch)) { - ++spins; - if (spins > kMaxPollIters) { - asm volatile("trap;"); - } - } - } + waitForEpochFromAllRanks( + done_semaphore_local, row, m, world_size, launch_epoch); } __syncthreads(); } @@ -223,12 +265,11 @@ __global__ void fusedStagedMultimemKernel( for (int64_t vec_i = threadIdx.x; vec_i < n_vec; vec_i += blockDim.x) { const uint4 val = reinterpret_cast(a_local + local_row * k)[vec_i]; - char* dst_byte = reinterpret_cast(a_row_stage) + vec_i * 16; #if __CUDA_ARCH__ >= 900 asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" : - : "l"((void*)dst_byte), + : "l"((void*)(a_row_stage + vec_i * vec_elems)), "f"(__int_as_float(static_cast(val.x))), "f"(__int_as_float(static_cast(val.y))), "f"(__int_as_float(static_cast(val.z))), @@ -250,43 +291,25 @@ __global__ void fusedStagedMultimemKernel( // Cross-device barrier between stage-1 stores and stage-2 reads. // Each rank publishes one epoch for (my_rank,row), then waits until // all writer ranks have published the same epoch for that row. - constexpr int64_t kSemaphoreVecWidth = 4; - constexpr int64_t kMaxPollIters = 1LL << 26; const int32_t launch_epoch = launch_epoch_base + 1; - int32_t* my_row_local = stage_semaphore_local + - (my_rank * m + row) * kSemaphoreVecWidth; if (threadIdx.x == 0) { // Publish local completion to all peers' semaphore tensors. - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - my_row_local[vec_i] = launch_epoch; - } - __threadfence_system(); - for (int64_t peer = 0; peer < world_size; ++peer) { - int32_t* peer_row_remote = stage_semaphore_remote_ptrs[peer] + - (my_rank * m + row) * kSemaphoreVecWidth; - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - peer_row_remote[vec_i] = launch_epoch; - } - } - __threadfence_system(); + publishEpochToAllRanks( + stage_semaphore_remote_ptrs, + stage_semaphore_local, + my_rank, + row, + m, + world_size, + launch_epoch); } __syncthreads(); if (threadIdx.x == 0) { // Wait until every writer rank has published this row's epoch. - for (int64_t rank = 0; rank < world_size; ++rank) { - auto* rank_epoch_ptr = reinterpret_cast( - stage_semaphore_local + (rank * m + row) * kSemaphoreVecWidth); - int64_t spins = 0; - while (atomicAdd(rank_epoch_ptr, 0U) < - static_cast(launch_epoch)) { - ++spins; - if (spins > kMaxPollIters) { - asm volatile("trap;"); - } - } - } + waitForEpochFromAllRanks( + stage_semaphore_local, row, m, world_size, launch_epoch); } __syncthreads(); #else @@ -337,31 +360,7 @@ double timeFusedRemoteMatmulMs( fusedRemoteMatmulKernel<<>>( a_remote_shards, b_full, c_out, m, n, k, m_per_rank); }; - - for (int64_t i = 0; i < warmup_iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); - for (int64_t i = 0; i < iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - - float total_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(iters); + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } double timeSeparatedAllgatherMatmulThreadLoadMs( @@ -385,30 +384,7 @@ double timeSeparatedAllgatherMatmulThreadLoadMs( fusedStagedThreadLoadKernel<<>>( a_remote_shards, a_gathered, m, n, k, m_per_rank, b_full, c_out); }; - - for (int64_t i = 0; i < warmup_iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); - for (int64_t i = 0; i < iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - - float total_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(iters); + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( @@ -457,30 +433,7 @@ double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( c_out); ++launch_epoch_base; }; - - for (int64_t i = 0; i < warmup_iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); - for (int64_t i = 0; i < iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - - float total_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(iters); + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } double timeSeparatedAllgatherMatmulMultimemMs( @@ -525,30 +478,7 @@ double timeSeparatedAllgatherMatmulMultimemMs( c_out); ++launch_epoch_base; }; - - for (int64_t i = 0; i < warmup_iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); - for (int64_t i = 0; i < iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - - float total_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(iters); + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } } // namespace nvfuser From ab292b4fb32addf19dee0dd7e5acffe5c2a30952 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 05:20:35 -0800 Subject: [PATCH 11/23] refactor kernels and do p2p wait --- ..._multidevice_fused_remote_matmul_kernel.cu | 161 ++++++++++++------ 1 file changed, 106 insertions(+), 55 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 4b8307b6375..e9c60a7147e 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -42,6 +42,51 @@ __device__ inline void publishEpochToAllRanks( __threadfence_system(); } +__device__ inline void publishEpochToRank( + int32_t* remote_semaphore_for_target_rank, + int64_t writer_rank, + int64_t row, + int64_t m, + int32_t epoch) { + int32_t* row_remote = + remote_semaphore_for_target_rank + (writer_rank * m + row) * kSemaphoreVecWidth; + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + row_remote[vec_i] = epoch; + } + __threadfence_system(); +} + +__device__ inline void setLocalEpoch( + int32_t* local_semaphore, + int64_t writer_rank, + int64_t row, + int64_t m, + int32_t epoch) { + int32_t* row_local = + local_semaphore + (writer_rank * m + row) * kSemaphoreVecWidth; + for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { + row_local[vec_i] = epoch; + } + __threadfence_system(); +} + +__device__ inline void waitForEpochFromRank( + int32_t* local_semaphore, + int64_t row, + int64_t m, + int64_t writer_rank, + int32_t epoch) { + auto* rank_epoch_ptr = reinterpret_cast( + local_semaphore + (writer_rank * m + row) * kSemaphoreVecWidth); + int64_t spins = 0; + while (atomicAdd(rank_epoch_ptr, 0U) < static_cast(epoch)) { + ++spins; + if (spins > kMaxSemaphorePollIters) { + asm volatile("trap;"); + } + } +} + __device__ inline void waitForEpochFromAllRanks( int32_t* local_semaphore, int64_t row, @@ -182,21 +227,23 @@ __global__ void fusedStagedThreadLoadSynchronizedKernel( const int64_t local_row = row - owner_rank * m_per_rank; const __half* a_local = a_remote_shards[owner_rank]; if (threadIdx.x == 0) { - // Publish ready epoch to all ranks before remote reads. - publishEpochToAllRanks( - ready_semaphore_remote_ptrs, - ready_semaphore_local, - my_rank, - row, - m, - world_size, - launch_epoch); + // Owner publishes "ready"; non-owners wait only on owner readiness. + if (my_rank == owner_rank) { + publishEpochToAllRanks( + ready_semaphore_remote_ptrs, + ready_semaphore_local, + my_rank, + row, + m, + world_size, + launch_epoch); + } } __syncthreads(); - if (threadIdx.x == 0) { - waitForEpochFromAllRanks( - ready_semaphore_local, row, m, world_size, launch_epoch); + if (threadIdx.x == 0 && my_rank != owner_rank) { + waitForEpochFromRank( + ready_semaphore_local, row, m, owner_rank, launch_epoch); } __syncthreads(); @@ -218,19 +265,21 @@ __global__ void fusedStagedThreadLoadSynchronizedKernel( __syncthreads(); if (threadIdx.x == 0) { - // Publish done epoch so producers can safely mutate next iteration. - publishEpochToAllRanks( - done_semaphore_remote_ptrs, - done_semaphore_local, - my_rank, - row, - m, - world_size, - launch_epoch); + // Readers ack completion only to owner; owner waits on all readers. + if (my_rank == owner_rank) { + setLocalEpoch(done_semaphore_local, my_rank, row, m, launch_epoch); + } else { + publishEpochToRank( + done_semaphore_remote_ptrs[owner_rank], + my_rank, + row, + m, + launch_epoch); + } } __syncthreads(); - if (threadIdx.x == 0) { + if (threadIdx.x == 0 && my_rank == owner_rank) { waitForEpochFromAllRanks( done_semaphore_local, row, m, world_size, launch_epoch); } @@ -259,31 +308,32 @@ __global__ void fusedStagedMultimemKernel( const __half* a_local = a_remote_shards[owner_rank]; __half* a_row_stage = a_gathered_multicast + row * k; - // Stage 1: gather row into multicast staging buffer via 16-byte vectors. + // Owner materializes the multicast row; non-owners wait on owner readiness. constexpr int64_t vec_elems = 8; // 8 * half = 16 bytes const int64_t n_vec = k / vec_elems; - for (int64_t vec_i = threadIdx.x; vec_i < n_vec; vec_i += blockDim.x) { - const uint4 val = - reinterpret_cast(a_local + local_row * k)[vec_i]; - + if (my_rank == owner_rank) { + for (int64_t vec_i = threadIdx.x; vec_i < n_vec; vec_i += blockDim.x) { + const uint4 val = + reinterpret_cast(a_local + local_row * k)[vec_i]; #if __CUDA_ARCH__ >= 900 - asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"((void*)(a_row_stage + vec_i * vec_elems)), - "f"(__int_as_float(static_cast(val.x))), - "f"(__int_as_float(static_cast(val.y))), - "f"(__int_as_float(static_cast(val.z))), - "f"(__int_as_float(static_cast(val.w))) - : "memory"); + asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" + : + : "l"((void*)(a_row_stage + vec_i * vec_elems)), + "f"(__int_as_float(static_cast(val.x))), + "f"(__int_as_float(static_cast(val.y))), + "f"(__int_as_float(static_cast(val.z))), + "f"(__int_as_float(static_cast(val.w))) + : "memory"); #else - (void)val; - // Multimem path must never run on non-Hopper architectures. - asm volatile("trap;"); + (void)val; + // Multimem path must never run on non-Hopper architectures. + asm volatile("trap;"); #endif - } - for (int64_t kk = n_vec * vec_elems + threadIdx.x; kk < k; - kk += blockDim.x) { - a_row_stage[kk] = a_local[local_row * k + kk]; + } + for (int64_t kk = n_vec * vec_elems + threadIdx.x; kk < k; + kk += blockDim.x) { + a_row_stage[kk] = a_local[local_row * k + kk]; + } } __syncthreads(); @@ -294,22 +344,23 @@ __global__ void fusedStagedMultimemKernel( const int32_t launch_epoch = launch_epoch_base + 1; if (threadIdx.x == 0) { - // Publish local completion to all peers' semaphore tensors. - publishEpochToAllRanks( - stage_semaphore_remote_ptrs, - stage_semaphore_local, - my_rank, - row, - m, - world_size, - launch_epoch); + if (my_rank == owner_rank) { + publishEpochToAllRanks( + stage_semaphore_remote_ptrs, + stage_semaphore_local, + my_rank, + row, + m, + world_size, + launch_epoch); + } } __syncthreads(); - if (threadIdx.x == 0) { - // Wait until every writer rank has published this row's epoch. - waitForEpochFromAllRanks( - stage_semaphore_local, row, m, world_size, launch_epoch); + if (threadIdx.x == 0 && my_rank != owner_rank) { + // Non-owners only need owner readiness before reading multicast row. + waitForEpochFromRank( + stage_semaphore_local, row, m, owner_rank, launch_epoch); } __syncthreads(); #else From 1c53396b5e03864f2a87e1614358642ab9639dcd Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 06:03:44 -0800 Subject: [PATCH 12/23] add matmulTma support --- .../test_multidevice_fused_remote_matmul.cpp | 216 ++++++++++++++++- ..._multidevice_fused_remote_matmul_kernel.cu | 218 ++++++++++++++++++ 2 files changed, 428 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index ca85f894466..d445d9a9c31 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -22,6 +22,7 @@ #include "multidevice/cuda_p2p.h" #include "multidevice/ipc_handle.h" #include "multidevice/symmetric_tensor.h" +#include "runtime/matmul_tma.h" #include "tests/cpp/multidevice.h" namespace nvfuser { @@ -56,6 +57,34 @@ double timeSeparatedAllgatherMatmulThreadLoadMs( int64_t iters, cudaStream_t stream); +double timeNaiveRemoteMatmulCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + +double timeSeparatedAllgatherMatmulThreadLoadCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( const __half* const* a_remote_shards, const __half* b_full, @@ -77,6 +106,26 @@ double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( int64_t iters, cudaStream_t stream); +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + double timeSeparatedAllgatherMatmulMultimemMs( const __half* const* a_remote_shards, const __half* b_full, @@ -96,6 +145,25 @@ double timeSeparatedAllgatherMatmulMultimemMs( int64_t iters, cudaStream_t stream); +double timeSeparatedAllgatherMatmulMultimemCutlassMs( + const __half* const* a_remote_shards, + __half* a_gathered_multicast_ptr, + const at::Tensor& a_gathered_local, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + namespace { // Implementations compared by this benchmark: @@ -121,12 +189,19 @@ namespace { // - stagedAllgatherComputeMultimem: // Same fused-kernel staged structure, but stage-1 writes use multimem store // instructions on a multicast pointer before compute. +// - *CutlassCompute variants: +// Keep communication path semantics and replace compute with Hopper CUTLASS +// TMA matmul. enum class DistributedMatmulImpl { naiveFusedKernel, + naiveFusedKernelCutlassCompute, baselinePytorchEagerNccl, baselinePytorchEagerCuda, stagedAllgatherComputeThreadLoad, + stagedAllgatherComputeThreadLoadCutlassCompute, stagedAllgatherComputeThreadLoadSynchronized, + stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute, + stagedAllgatherComputeMultimemCutlassCompute, stagedAllgatherComputeMultimem }; @@ -186,14 +261,22 @@ const char* implName(DistributedMatmulImpl impl) { switch (impl) { case DistributedMatmulImpl::naiveFusedKernel: return "naiveFusedKernel"; + case DistributedMatmulImpl::naiveFusedKernelCutlassCompute: + return "naiveFusedKernelCutlassCompute"; case DistributedMatmulImpl::baselinePytorchEagerNccl: return "baselinePytorchEagerNccl"; case DistributedMatmulImpl::baselinePytorchEagerCuda: return "baselinePytorchEagerCuda"; case DistributedMatmulImpl::stagedAllgatherComputeThreadLoad: return "stagedAllgatherComputeThreadLoad"; + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute: + return "stagedAllgatherComputeThreadLoadCutlassCompute"; case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized: return "stagedAllgatherComputeThreadLoadSynchronized"; + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute: + return "stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute"; + case DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute: + return "stagedAllgatherComputeMultimemCutlassCompute"; case DistributedMatmulImpl::stagedAllgatherComputeMultimem: return "stagedAllgatherComputeMultimem"; } @@ -209,6 +292,15 @@ bool isMulticastSupported(int64_t device_id) { return result == CUDA_SUCCESS && is_multicast_supported != 0; } +bool canRunHopperCutlassCompute(const at::Tensor& a, const at::Tensor& b) { + if (!canRunMatmulTma(a, b)) { + return false; + } + auto* props = at::cuda::getDeviceProperties(a.get_device()); + // Restrict CUTLASS-compute benchmark variants to Hopper in this test. + return props->major == 9 && props->minor == 0; +} + template double benchmarkLoopMs( const BenchmarkConfig& config, @@ -313,7 +405,11 @@ BenchmarkResources initBenchmarkResources( } if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoad || - impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized) { + impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || + impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute || + impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized || + impl == + DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute) { resources.a_gathered_threadload = at::empty( {m, k}, at::TensorOptions() @@ -322,7 +418,9 @@ BenchmarkResources initBenchmarkResources( .layout(at::kStrided)); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized) { + if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized || + impl == + DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute) { // Per-rank [writer_rank, row, vec4-int] semaphores for fused ready/done. resources.threadload_ready_semaphore = SymmetricTensor::allocate( {world_size, m, 4}, at::ScalarType::Int, communicator->device()); @@ -341,7 +439,8 @@ BenchmarkResources initBenchmarkResources( "fused_remote_matmul_threadload_done"); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem) { + if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem || + impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute) { resources.a_gathered_multimem = SymmetricTensor::allocate( {m, k}, at::ScalarType::Half, communicator->device()); resources.a_gathered_multimem_sym = @@ -474,6 +573,20 @@ double timeImplementationMs( }; return benchmarkLoopMs(config, communicator, stream, run_once); } + case DistributedMatmulImpl::naiveFusedKernelCutlassCompute: + return timeNaiveRemoteMatmulCutlassMs( + device_remote_ptrs, + a_gathered_threadload, + b_full_half, + c_out_half, + m, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::baselinePytorchEagerNccl: NVF_CHECK( nccl_backend != nullptr, @@ -515,6 +628,20 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute: + return timeSeparatedAllgatherMatmulThreadLoadCutlassMs( + device_remote_ptrs, + a_gathered_threadload, + b_full_half, + c_out_half, + m, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized: NVF_CHECK( threadload_ready_semaphore_sym != nullptr && @@ -544,6 +671,61 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); + case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute: + NVF_CHECK( + threadload_ready_semaphore_sym != nullptr && + threadload_done_semaphore_sym != nullptr && + threadload_ready_semaphore_remote_ptrs != nullptr && + threadload_done_semaphore_remote_ptrs != nullptr, + "stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute " + "requires semaphore resources."); + return timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( + device_remote_ptrs, + a_gathered_threadload, + threadload_ready_semaphore_remote_ptrs, + reinterpret_cast( + threadload_ready_semaphore_sym->localTensor().data_ptr()), + threadload_done_semaphore_remote_ptrs, + reinterpret_cast( + threadload_done_semaphore_sym->localTensor().data_ptr()), + my_rank, + world_size, + b_full_half, + c_out_half, + m, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); + case DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute: + NVF_CHECK( + a_gathered_multimem_sym != nullptr && + stage_semaphore_multimem_sym != nullptr && + stage_semaphore_remote_ptrs != nullptr, + "stagedAllgatherComputeMultimemCutlassCompute requires staging and " + "semaphore tensors."); + return timeSeparatedAllgatherMatmulMultimemCutlassMs( + device_remote_ptrs, + reinterpret_cast<__half*>(a_gathered_multimem_sym->multicastPtr()), + a_gathered_multimem, + stage_semaphore_remote_ptrs, + reinterpret_cast( + stage_semaphore_multimem_sym->localTensor().data_ptr()), + my_rank, + world_size, + b_full_half, + c_out_half, + m, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::stagedAllgatherComputeMultimem: NVF_CHECK( a_gathered_multimem_sym != nullptr && @@ -609,7 +791,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const int64_t my_rank = communicator_->deviceId(); const auto impl = GetParam(); - if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem && + if ((impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem || + impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute) && !isMulticastSupported(my_rank)) { GTEST_SKIP() << "Multicast is not supported on this device."; } @@ -673,7 +856,9 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { int32_t* const* device_stage_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_ready_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_done_semaphore_remote_ptrs = nullptr; - if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized) { + if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized || + impl == + DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute) { NVF_CHECK( resources.threadload_ready_semaphore_sym != nullptr && resources.threadload_done_semaphore_sym != nullptr, @@ -686,7 +871,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { resources.threadload_done_semaphore_sym->devicePeerPointers()); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem) { + if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem || + impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute) { NVF_CHECK( resources.stage_semaphore_multimem_sym != nullptr, "Missing staged multimem semaphore resources."); @@ -695,6 +881,20 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { resources.stage_semaphore_multimem_sym->devicePeerPointers()); } + const bool needs_cutlass_compute = impl == + DistributedMatmulImpl::naiveFusedKernelCutlassCompute || + impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute || + impl == + DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute || + impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute; + if (needs_cutlass_compute && + !canRunHopperCutlassCompute( + a_gathered_threadload.defined() ? a_gathered_threadload : a_gathered_multimem, + b_full_half)) { + GTEST_SKIP() + << "CUTLASS-compute variants require Hopper SM90 with TMA support."; + } + auto run_implementation = [&](const BenchmarkConfig& config) { return timeImplementationMs( impl, @@ -761,10 +961,14 @@ INSTANTIATE_TEST_SUITE_P( FusedRemoteMatmulTest, testing::Values( DistributedMatmulImpl::naiveFusedKernel, + DistributedMatmulImpl::naiveFusedKernelCutlassCompute, DistributedMatmulImpl::baselinePytorchEagerNccl, DistributedMatmulImpl::baselinePytorchEagerCuda, DistributedMatmulImpl::stagedAllgatherComputeThreadLoad, + DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute, DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized, + DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute, + DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute, DistributedMatmulImpl::stagedAllgatherComputeMultimem), [](const testing::TestParamInfo& info) { return implName(info.param); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index e9c60a7147e..f62a86e8f34 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -9,7 +9,10 @@ #include #include +#include + #include "cuda_utils.h" +#include "runtime/matmul_tma.h" namespace nvfuser { @@ -137,6 +140,25 @@ double timeKernelLaunchesMs( return static_cast(total_ms) / static_cast(iters); } +template +double timeCommThenCutlassMs( + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream, + const at::Tensor& a_comm, + const at::Tensor& b_full, + at::Tensor& c_out, + CommLaunchFn&& launch_comm_once) { + NVF_CHECK( + canRunMatmulTma(a_comm, b_full), + "CUTLASS TMA compute requires Hopper+ and compatible half inputs."); + auto launch_once = [&]() { + launch_comm_once(); + c_out.copy_(matmulTma(a_comm, b_full), /*non_blocking=*/true); + }; + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +} + // Naive fused kernel: // - A is row-sharded across ranks (axis M) // - each output row reads from its owner rank shard via remote pointers @@ -414,6 +436,48 @@ double timeFusedRemoteMatmulMs( return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } +double timeNaiveRemoteMatmulCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + const __half* b_ptr = + reinterpret_cast(b_full.data_ptr()); + __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); + __half* a_gathered_ptr = + reinterpret_cast<__half*>(a_gathered.data_ptr()); + auto launch_comm_once = [&]() { + // Keep communication as remote thread-load gather. + fusedStagedThreadLoadKernel<<>>( + a_remote_shards, + a_gathered_ptr, + m, + /*n=*/0, + k, + m_per_rank, + b_ptr, + c_ptr); + }; + return timeCommThenCutlassMs( + warmup_iters, + iters, + stream, + a_gathered, + b_full, + c_out, + launch_comm_once); +} + double timeSeparatedAllgatherMatmulThreadLoadMs( const __half* const* a_remote_shards, const __half* b_full, @@ -438,6 +502,47 @@ double timeSeparatedAllgatherMatmulThreadLoadMs( return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } +double timeSeparatedAllgatherMatmulThreadLoadCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + const __half* b_ptr = + reinterpret_cast(b_full.data_ptr()); + __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); + __half* a_gathered_ptr = + reinterpret_cast<__half*>(a_gathered.data_ptr()); + auto launch_comm_once = [&]() { + fusedStagedThreadLoadKernel<<>>( + a_remote_shards, + a_gathered_ptr, + m, + /*n=*/0, + k, + m_per_rank, + b_ptr, + c_ptr); + }; + return timeCommThenCutlassMs( + warmup_iters, + iters, + stream, + a_gathered, + b_full, + c_out, + launch_comm_once); +} + double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( const __half* const* a_remote_shards, const __half* b_full, @@ -487,6 +592,65 @@ double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + int64_t launch_epoch_base = 0; + const __half* b_ptr = + reinterpret_cast(b_full.data_ptr()); + __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); + __half* a_gathered_ptr = + reinterpret_cast<__half*>(a_gathered.data_ptr()); + auto launch_comm_once = [&]() { + NVF_CHECK( + launch_epoch_base < std::numeric_limits::max(), + "ThreadLoad synchronized CUTLASS semaphore epoch overflow."); + fusedStagedThreadLoadSynchronizedKernel<<>>( + a_remote_shards, + a_gathered_ptr, + ready_semaphore_remote_ptrs, + ready_semaphore_local, + done_semaphore_remote_ptrs, + done_semaphore_local, + my_rank, + world_size, + static_cast(launch_epoch_base), + m, + /*n=*/0, + k, + m_per_rank, + b_ptr, + c_ptr); + ++launch_epoch_base; + }; + return timeCommThenCutlassMs( + warmup_iters, + iters, + stream, + a_gathered, + b_full, + c_out, + launch_comm_once); +} + double timeSeparatedAllgatherMatmulMultimemMs( const __half* const* a_remote_shards, const __half* b_full, @@ -532,4 +696,58 @@ double timeSeparatedAllgatherMatmulMultimemMs( return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } +double timeSeparatedAllgatherMatmulMultimemCutlassMs( + const __half* const* a_remote_shards, + __half* a_gathered_multicast_ptr, + const at::Tensor& a_gathered_local, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, + const at::Tensor& b_full, + at::Tensor& c_out, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + int64_t launch_epoch_base = 0; + const __half* b_ptr = + reinterpret_cast(b_full.data_ptr()); + __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); + auto launch_comm_once = [&]() { + NVF_CHECK( + launch_epoch_base < std::numeric_limits::max(), + "Multimem CUTLASS semaphore epoch overflow."); + fusedStagedMultimemKernel<<>>( + a_remote_shards, + a_gathered_multicast_ptr, + stage_semaphore_remote_ptrs, + stage_semaphore_local, + my_rank, + world_size, + static_cast(launch_epoch_base), + m, + /*n=*/0, + k, + m_per_rank, + b_ptr, + c_ptr); + ++launch_epoch_base; + }; + return timeCommThenCutlassMs( + warmup_iters, + iters, + stream, + a_gathered_local, + b_full, + c_out, + launch_comm_once); +} + } // namespace nvfuser From bea9cbf24bd9008d1edf78fd5651b228d887db14 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 06:08:18 -0800 Subject: [PATCH 13/23] avoid copy_ the output --- tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index f62a86e8f34..527eb9893fb 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -154,7 +154,8 @@ double timeCommThenCutlassMs( "CUTLASS TMA compute requires Hopper+ and compatible half inputs."); auto launch_once = [&]() { launch_comm_once(); - c_out.copy_(matmulTma(a_comm, b_full), /*non_blocking=*/true); + // Rebind output tensor to TMA matmul result (avoids an extra device copy). + c_out = matmulTma(a_comm, b_full); }; return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } From fd0647d1908223a985fc16363674841a570efc70 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 06:13:25 -0800 Subject: [PATCH 14/23] renaming --- .../test_multidevice_fused_remote_matmul.cpp | 108 +++++++++--------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index d445d9a9c31..32f5e306e46 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -179,14 +179,14 @@ namespace { // Same eager compute structure as the NCCL baseline, but communication uses // nvFuser's CUDA backend allgather primitives (post/wait with symmetric-memory // handles from cuda_p2p.h) before eager matmul_out. -// - stagedAllgatherComputeThreadLoad: +// - gpuAllgatherFusedNaiveCompute: // Single fused kernel with explicit internal stages: // (1) stage full A-row from remote pointers with regular thread loads, // (2) compute matmul for that row. -// - stagedAllgatherComputeThreadLoadSynchronized: +// - gpuSyncAllgatherFusedNaiveCompute: // Same thread-load staged kernel as above, but with fused cross-rank ready/done // semaphores so remote reads are safe even when producer shards are mutable. -// - stagedAllgatherComputeMultimem: +// - gpuAllgatherMultimemFusedNaiveCompute: // Same fused-kernel staged structure, but stage-1 writes use multimem store // instructions on a multicast pointer before compute. // - *CutlassCompute variants: @@ -197,12 +197,12 @@ enum class DistributedMatmulImpl { naiveFusedKernelCutlassCompute, baselinePytorchEagerNccl, baselinePytorchEagerCuda, - stagedAllgatherComputeThreadLoad, - stagedAllgatherComputeThreadLoadCutlassCompute, - stagedAllgatherComputeThreadLoadSynchronized, - stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute, - stagedAllgatherComputeMultimemCutlassCompute, - stagedAllgatherComputeMultimem + gpuAllgatherFusedNaiveCompute, + gpuAllgatherTmaCompute, + gpuSyncAllgatherFusedNaiveCompute, + gpuSyncAllgatherTmaCompute, + gpuAllgatherMultimemTmaCompute, + gpuAllgatherMultimemFusedNaiveCompute }; enum class TimeMeasurementMode { CudaEvents, CpuClock }; @@ -267,18 +267,18 @@ const char* implName(DistributedMatmulImpl impl) { return "baselinePytorchEagerNccl"; case DistributedMatmulImpl::baselinePytorchEagerCuda: return "baselinePytorchEagerCuda"; - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoad: - return "stagedAllgatherComputeThreadLoad"; - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute: - return "stagedAllgatherComputeThreadLoadCutlassCompute"; - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized: - return "stagedAllgatherComputeThreadLoadSynchronized"; - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute: - return "stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute"; - case DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute: - return "stagedAllgatherComputeMultimemCutlassCompute"; - case DistributedMatmulImpl::stagedAllgatherComputeMultimem: - return "stagedAllgatherComputeMultimem"; + case DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute: + return "gpuAllgatherFusedNaiveCompute"; + case DistributedMatmulImpl::gpuAllgatherTmaCompute: + return "gpuAllgatherTmaCompute"; + case DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute: + return "gpuSyncAllgatherFusedNaiveCompute"; + case DistributedMatmulImpl::gpuSyncAllgatherTmaCompute: + return "gpuSyncAllgatherTmaCompute"; + case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: + return "gpuAllgatherMultimemTmaCompute"; + case DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute: + return "gpuAllgatherMultimemFusedNaiveCompute"; } NVF_ERROR(false, "Unknown implementation enum value: ", static_cast(impl)); } @@ -404,12 +404,12 @@ BenchmarkResources initBenchmarkResources( resources.cuda_allgather_communication, resources.a_allgathered_half_cuda); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoad || + if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || - impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute || - impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized || + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute || impl == - DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute) { + DistributedMatmulImpl::gpuSyncAllgatherTmaCompute) { resources.a_gathered_threadload = at::empty( {m, k}, at::TensorOptions() @@ -418,9 +418,9 @@ BenchmarkResources initBenchmarkResources( .layout(at::kStrided)); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized || + if (impl == DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute || impl == - DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute) { + DistributedMatmulImpl::gpuSyncAllgatherTmaCompute) { // Per-rank [writer_rank, row, vec4-int] semaphores for fused ready/done. resources.threadload_ready_semaphore = SymmetricTensor::allocate( {world_size, m, 4}, at::ScalarType::Int, communicator->device()); @@ -439,8 +439,8 @@ BenchmarkResources initBenchmarkResources( "fused_remote_matmul_threadload_done"); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem || - impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute) { + if (impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || + impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute) { resources.a_gathered_multimem = SymmetricTensor::allocate( {m, k}, at::ScalarType::Half, communicator->device()); resources.a_gathered_multimem_sym = @@ -613,7 +613,7 @@ double timeImplementationMs( config, communicator, stream); - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoad: + case DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute: return timeSeparatedAllgatherMatmulThreadLoadMs( device_remote_ptrs, reinterpret_cast(b_full_half.data_ptr()), @@ -628,7 +628,7 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute: + case DistributedMatmulImpl::gpuAllgatherTmaCompute: return timeSeparatedAllgatherMatmulThreadLoadCutlassMs( device_remote_ptrs, a_gathered_threadload, @@ -642,13 +642,13 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized: + case DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute: NVF_CHECK( threadload_ready_semaphore_sym != nullptr && threadload_done_semaphore_sym != nullptr && threadload_ready_semaphore_remote_ptrs != nullptr && threadload_done_semaphore_remote_ptrs != nullptr, - "stagedAllgatherComputeThreadLoadSynchronized requires semaphore resources."); + "gpuSyncAllgatherFusedNaiveCompute requires semaphore resources."); return timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( device_remote_ptrs, reinterpret_cast(b_full_half.data_ptr()), @@ -671,13 +671,13 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); - case DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute: + case DistributedMatmulImpl::gpuSyncAllgatherTmaCompute: NVF_CHECK( threadload_ready_semaphore_sym != nullptr && threadload_done_semaphore_sym != nullptr && threadload_ready_semaphore_remote_ptrs != nullptr && threadload_done_semaphore_remote_ptrs != nullptr, - "stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute " + "gpuSyncAllgatherTmaCompute " "requires semaphore resources."); return timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( device_remote_ptrs, @@ -700,12 +700,12 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); - case DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute: + case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: NVF_CHECK( a_gathered_multimem_sym != nullptr && stage_semaphore_multimem_sym != nullptr && stage_semaphore_remote_ptrs != nullptr, - "stagedAllgatherComputeMultimemCutlassCompute requires staging and " + "gpuAllgatherMultimemTmaCompute requires staging and " "semaphore tensors."); return timeSeparatedAllgatherMatmulMultimemCutlassMs( device_remote_ptrs, @@ -726,11 +726,11 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); - case DistributedMatmulImpl::stagedAllgatherComputeMultimem: + case DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute: NVF_CHECK( a_gathered_multimem_sym != nullptr && stage_semaphore_multimem_sym != nullptr, - "stagedAllgatherComputeMultimem requires staging and semaphore tensors."); + "gpuAllgatherMultimemFusedNaiveCompute requires staging and semaphore tensors."); return timeSeparatedAllgatherMatmulMultimemMs( device_remote_ptrs, reinterpret_cast(b_full_half.data_ptr()), @@ -791,8 +791,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const int64_t my_rank = communicator_->deviceId(); const auto impl = GetParam(); - if ((impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem || - impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute) && + if ((impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || + impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute) && !isMulticastSupported(my_rank)) { GTEST_SKIP() << "Multicast is not supported on this device."; } @@ -856,9 +856,9 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { int32_t* const* device_stage_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_ready_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_done_semaphore_remote_ptrs = nullptr; - if (impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized || + if (impl == DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute || impl == - DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute) { + DistributedMatmulImpl::gpuSyncAllgatherTmaCompute) { NVF_CHECK( resources.threadload_ready_semaphore_sym != nullptr && resources.threadload_done_semaphore_sym != nullptr, @@ -871,8 +871,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { resources.threadload_done_semaphore_sym->devicePeerPointers()); } - if (impl == DistributedMatmulImpl::stagedAllgatherComputeMultimem || - impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute) { + if (impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || + impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute) { NVF_CHECK( resources.stage_semaphore_multimem_sym != nullptr, "Missing staged multimem semaphore resources."); @@ -883,10 +883,10 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const bool needs_cutlass_compute = impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || - impl == DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute || + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || impl == - DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute || - impl == DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute; + DistributedMatmulImpl::gpuSyncAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute; if (needs_cutlass_compute && !canRunHopperCutlassCompute( a_gathered_threadload.defined() ? a_gathered_threadload : a_gathered_multimem, @@ -964,12 +964,12 @@ INSTANTIATE_TEST_SUITE_P( DistributedMatmulImpl::naiveFusedKernelCutlassCompute, DistributedMatmulImpl::baselinePytorchEagerNccl, DistributedMatmulImpl::baselinePytorchEagerCuda, - DistributedMatmulImpl::stagedAllgatherComputeThreadLoad, - DistributedMatmulImpl::stagedAllgatherComputeThreadLoadCutlassCompute, - DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronized, - DistributedMatmulImpl::stagedAllgatherComputeThreadLoadSynchronizedCutlassCompute, - DistributedMatmulImpl::stagedAllgatherComputeMultimemCutlassCompute, - DistributedMatmulImpl::stagedAllgatherComputeMultimem), + DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute, + DistributedMatmulImpl::gpuAllgatherTmaCompute, + DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute, + DistributedMatmulImpl::gpuSyncAllgatherTmaCompute, + DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute, + DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute), [](const testing::TestParamInfo& info) { return implName(info.param); }); From 7a14b2a40a99f5464a9fe4096874f7561f21912e Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 06:29:02 -0800 Subject: [PATCH 15/23] only keep strongly synchronizing implementations --- .../test_multidevice_fused_remote_matmul.cpp | 68 +++---------------- ..._multidevice_fused_remote_matmul_kernel.cu | 65 ------------------ 2 files changed, 10 insertions(+), 123 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 32f5e306e46..5b0b0266ae2 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -182,13 +182,10 @@ namespace { // - gpuAllgatherFusedNaiveCompute: // Single fused kernel with explicit internal stages: // (1) stage full A-row from remote pointers with regular thread loads, -// (2) compute matmul for that row. -// - gpuSyncAllgatherFusedNaiveCompute: -// Same thread-load staged kernel as above, but with fused cross-rank ready/done -// semaphores so remote reads are safe even when producer shards are mutable. +// (2) compute matmul for that row. Uses fused ready/done semaphores. // - gpuAllgatherMultimemFusedNaiveCompute: // Same fused-kernel staged structure, but stage-1 writes use multimem store -// instructions on a multicast pointer before compute. +// instructions on a multicast pointer before compute, with fused semaphores. // - *CutlassCompute variants: // Keep communication path semantics and replace compute with Hopper CUTLASS // TMA matmul. @@ -199,8 +196,6 @@ enum class DistributedMatmulImpl { baselinePytorchEagerCuda, gpuAllgatherFusedNaiveCompute, gpuAllgatherTmaCompute, - gpuSyncAllgatherFusedNaiveCompute, - gpuSyncAllgatherTmaCompute, gpuAllgatherMultimemTmaCompute, gpuAllgatherMultimemFusedNaiveCompute }; @@ -271,10 +266,6 @@ const char* implName(DistributedMatmulImpl impl) { return "gpuAllgatherFusedNaiveCompute"; case DistributedMatmulImpl::gpuAllgatherTmaCompute: return "gpuAllgatherTmaCompute"; - case DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute: - return "gpuSyncAllgatherFusedNaiveCompute"; - case DistributedMatmulImpl::gpuSyncAllgatherTmaCompute: - return "gpuSyncAllgatherTmaCompute"; case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: return "gpuAllgatherMultimemTmaCompute"; case DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute: @@ -406,10 +397,7 @@ BenchmarkResources initBenchmarkResources( if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || - impl == DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute || - impl == - DistributedMatmulImpl::gpuSyncAllgatherTmaCompute) { + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute) { resources.a_gathered_threadload = at::empty( {m, k}, at::TensorOptions() @@ -418,9 +406,8 @@ BenchmarkResources initBenchmarkResources( .layout(at::kStrided)); } - if (impl == DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute || - impl == - DistributedMatmulImpl::gpuSyncAllgatherTmaCompute) { + if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute) { // Per-rank [writer_rank, row, vec4-int] semaphores for fused ready/done. resources.threadload_ready_semaphore = SymmetricTensor::allocate( {world_size, m, 4}, at::ScalarType::Int, communicator->device()); @@ -614,41 +601,12 @@ double timeImplementationMs( communicator, stream); case DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute: - return timeSeparatedAllgatherMatmulThreadLoadMs( - device_remote_ptrs, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - reinterpret_cast<__half*>(a_gathered_threadload.data_ptr()), - m, - n, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuAllgatherTmaCompute: - return timeSeparatedAllgatherMatmulThreadLoadCutlassMs( - device_remote_ptrs, - a_gathered_threadload, - b_full_half, - c_out_half, - m, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute: NVF_CHECK( threadload_ready_semaphore_sym != nullptr && threadload_done_semaphore_sym != nullptr && threadload_ready_semaphore_remote_ptrs != nullptr && threadload_done_semaphore_remote_ptrs != nullptr, - "gpuSyncAllgatherFusedNaiveCompute requires semaphore resources."); + "gpuAllgatherFusedNaiveCompute requires semaphore resources."); return timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( device_remote_ptrs, reinterpret_cast(b_full_half.data_ptr()), @@ -671,14 +629,13 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); - case DistributedMatmulImpl::gpuSyncAllgatherTmaCompute: + case DistributedMatmulImpl::gpuAllgatherTmaCompute: NVF_CHECK( threadload_ready_semaphore_sym != nullptr && threadload_done_semaphore_sym != nullptr && threadload_ready_semaphore_remote_ptrs != nullptr && threadload_done_semaphore_remote_ptrs != nullptr, - "gpuSyncAllgatherTmaCompute " - "requires semaphore resources."); + "gpuAllgatherTmaCompute requires semaphore resources."); return timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( device_remote_ptrs, a_gathered_threadload, @@ -856,9 +813,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { int32_t* const* device_stage_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_ready_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_done_semaphore_remote_ptrs = nullptr; - if (impl == DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute || - impl == - DistributedMatmulImpl::gpuSyncAllgatherTmaCompute) { + if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute) { NVF_CHECK( resources.threadload_ready_semaphore_sym != nullptr && resources.threadload_done_semaphore_sym != nullptr, @@ -884,8 +840,6 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const bool needs_cutlass_compute = impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || - impl == - DistributedMatmulImpl::gpuSyncAllgatherTmaCompute || impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute; if (needs_cutlass_compute && !canRunHopperCutlassCompute( @@ -966,8 +920,6 @@ INSTANTIATE_TEST_SUITE_P( DistributedMatmulImpl::baselinePytorchEagerCuda, DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute, DistributedMatmulImpl::gpuAllgatherTmaCompute, - DistributedMatmulImpl::gpuSyncAllgatherFusedNaiveCompute, - DistributedMatmulImpl::gpuSyncAllgatherTmaCompute, DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute, DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute), [](const testing::TestParamInfo& info) { diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 527eb9893fb..278692753f2 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -479,71 +479,6 @@ double timeNaiveRemoteMatmulCutlassMs( launch_comm_once); } -double timeSeparatedAllgatherMatmulThreadLoadMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - __half* a_gathered, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - - auto launch_once = [&]() { - fusedStagedThreadLoadKernel<<>>( - a_remote_shards, a_gathered, m, n, k, m_per_rank, b_full, c_out); - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); -} - -double timeSeparatedAllgatherMatmulThreadLoadCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - const __half* b_ptr = - reinterpret_cast(b_full.data_ptr()); - __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); - __half* a_gathered_ptr = - reinterpret_cast<__half*>(a_gathered.data_ptr()); - auto launch_comm_once = [&]() { - fusedStagedThreadLoadKernel<<>>( - a_remote_shards, - a_gathered_ptr, - m, - /*n=*/0, - k, - m_per_rank, - b_ptr, - c_ptr); - }; - return timeCommThenCutlassMs( - warmup_iters, - iters, - stream, - a_gathered, - b_full, - c_out, - launch_comm_once); -} - double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( const __half* const* a_remote_shards, const __half* b_full, From 579a9458185f84e0bb47f41ae2fda0607d0adc49 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 12 Feb 2026 13:45:45 -0800 Subject: [PATCH 16/23] add fused tma impl --- .../test_multidevice_fused_remote_matmul.cpp | 175 +++++-- ..._multidevice_fused_remote_matmul_kernel.cu | 436 ++++++++++++++++++ 2 files changed, 576 insertions(+), 35 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 5b0b0266ae2..1fbfa4ebbc2 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -42,21 +42,6 @@ double timeFusedRemoteMatmulMs( int64_t iters, cudaStream_t stream); -double timeSeparatedAllgatherMatmulThreadLoadMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - __half* a_gathered, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - double timeNaiveRemoteMatmulCutlassMs( const __half* const* a_remote_shards, at::Tensor& a_gathered, @@ -71,20 +56,6 @@ double timeNaiveRemoteMatmulCutlassMs( int64_t iters, cudaStream_t stream); -double timeSeparatedAllgatherMatmulThreadLoadCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( const __half* const* a_remote_shards, const __half* b_full, @@ -164,6 +135,53 @@ double timeSeparatedAllgatherMatmulMultimemCutlassMs( int64_t iters, cudaStream_t stream); +double timeNaiveRemoteMatmulFusedTmaMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedFusedTmaMs( + const __half* const* a_remote_shards, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + +double timeSeparatedAllgatherMatmulMultimemFusedTmaMs( + const __half* const* a_remote_shards, + __half* a_gathered_multicast, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + namespace { // Implementations compared by this benchmark: @@ -192,11 +210,14 @@ namespace { enum class DistributedMatmulImpl { naiveFusedKernel, naiveFusedKernelCutlassCompute, + naiveFusedKernelFusedTma, baselinePytorchEagerNccl, baselinePytorchEagerCuda, gpuAllgatherFusedNaiveCompute, gpuAllgatherTmaCompute, + gpuAllgatherFusedTma, gpuAllgatherMultimemTmaCompute, + gpuAllgatherMultimemFusedTma, gpuAllgatherMultimemFusedNaiveCompute }; @@ -258,6 +279,8 @@ const char* implName(DistributedMatmulImpl impl) { return "naiveFusedKernel"; case DistributedMatmulImpl::naiveFusedKernelCutlassCompute: return "naiveFusedKernelCutlassCompute"; + case DistributedMatmulImpl::naiveFusedKernelFusedTma: + return "naiveFusedKernelFusedTma"; case DistributedMatmulImpl::baselinePytorchEagerNccl: return "baselinePytorchEagerNccl"; case DistributedMatmulImpl::baselinePytorchEagerCuda: @@ -266,8 +289,12 @@ const char* implName(DistributedMatmulImpl impl) { return "gpuAllgatherFusedNaiveCompute"; case DistributedMatmulImpl::gpuAllgatherTmaCompute: return "gpuAllgatherTmaCompute"; + case DistributedMatmulImpl::gpuAllgatherFusedTma: + return "gpuAllgatherFusedTma"; case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: return "gpuAllgatherMultimemTmaCompute"; + case DistributedMatmulImpl::gpuAllgatherMultimemFusedTma: + return "gpuAllgatherMultimemFusedTma"; case DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute: return "gpuAllgatherMultimemFusedNaiveCompute"; } @@ -397,7 +424,8 @@ BenchmarkResources initBenchmarkResources( if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute) { + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { resources.a_gathered_threadload = at::empty( {m, k}, at::TensorOptions() @@ -407,7 +435,8 @@ BenchmarkResources initBenchmarkResources( } if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute) { + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { // Per-rank [writer_rank, row, vec4-int] semaphores for fused ready/done. resources.threadload_ready_semaphore = SymmetricTensor::allocate( {world_size, m, 4}, at::ScalarType::Int, communicator->device()); @@ -427,7 +456,8 @@ BenchmarkResources initBenchmarkResources( } if (impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute) { + impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma) { resources.a_gathered_multimem = SymmetricTensor::allocate( {m, k}, at::ScalarType::Half, communicator->device()); resources.a_gathered_multimem_sym = @@ -574,6 +604,18 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); + case DistributedMatmulImpl::naiveFusedKernelFusedTma: + return timeNaiveRemoteMatmulFusedTmaMs( + device_remote_ptrs, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + m, + n, + k, + m_per_rank, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::baselinePytorchEagerNccl: NVF_CHECK( nccl_backend != nullptr, @@ -657,6 +699,32 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); + case DistributedMatmulImpl::gpuAllgatherFusedTma: + NVF_CHECK( + threadload_ready_semaphore_sym != nullptr && + threadload_done_semaphore_sym != nullptr && + threadload_ready_semaphore_remote_ptrs != nullptr && + threadload_done_semaphore_remote_ptrs != nullptr, + "gpuAllgatherFusedTma requires semaphore resources."); + return timeSeparatedAllgatherMatmulThreadLoadSynchronizedFusedTmaMs( + device_remote_ptrs, + threadload_ready_semaphore_remote_ptrs, + reinterpret_cast( + threadload_ready_semaphore_sym->localTensor().data_ptr()), + threadload_done_semaphore_remote_ptrs, + reinterpret_cast( + threadload_done_semaphore_sym->localTensor().data_ptr()), + my_rank, + world_size, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + m, + n, + k, + m_per_rank, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: NVF_CHECK( a_gathered_multimem_sym != nullptr && @@ -683,6 +751,29 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); + case DistributedMatmulImpl::gpuAllgatherMultimemFusedTma: + NVF_CHECK( + a_gathered_multimem_sym != nullptr && + stage_semaphore_multimem_sym != nullptr && + stage_semaphore_remote_ptrs != nullptr, + "gpuAllgatherMultimemFusedTma requires staging and semaphore tensors."); + return timeSeparatedAllgatherMatmulMultimemFusedTmaMs( + device_remote_ptrs, + reinterpret_cast<__half*>(a_gathered_multimem_sym->multicastPtr()), + stage_semaphore_remote_ptrs, + reinterpret_cast( + stage_semaphore_multimem_sym->localTensor().data_ptr()), + my_rank, + world_size, + reinterpret_cast(b_full_half.data_ptr()), + reinterpret_cast<__half*>(c_out_half.data_ptr()), + m, + n, + k, + m_per_rank, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute: NVF_CHECK( a_gathered_multimem_sym != nullptr && @@ -749,7 +840,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const auto impl = GetParam(); if ((impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute) && + impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma) && !isMulticastSupported(my_rank)) { GTEST_SKIP() << "Multicast is not supported on this device."; } @@ -814,7 +906,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { int32_t* const* device_threadload_ready_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_done_semaphore_remote_ptrs = nullptr; if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute) { + impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { NVF_CHECK( resources.threadload_ready_semaphore_sym != nullptr && resources.threadload_done_semaphore_sym != nullptr, @@ -828,7 +921,8 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { } if (impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute) { + impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma) { NVF_CHECK( resources.stage_semaphore_multimem_sym != nullptr, "Missing staged multimem semaphore resources."); @@ -841,6 +935,10 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { DistributedMatmulImpl::naiveFusedKernelCutlassCompute || impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute; + const bool needs_fused_tma_compute = impl == + DistributedMatmulImpl::naiveFusedKernelFusedTma || + impl == DistributedMatmulImpl::gpuAllgatherFusedTma || + impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma; if (needs_cutlass_compute && !canRunHopperCutlassCompute( a_gathered_threadload.defined() ? a_gathered_threadload : a_gathered_multimem, @@ -848,6 +946,10 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { GTEST_SKIP() << "CUTLASS-compute variants require Hopper SM90 with TMA support."; } + if (needs_fused_tma_compute && + !canRunHopperCutlassCompute(a_local_half, b_full_half)) { + GTEST_SKIP() << "FusedTma variants require Hopper SM90 support."; + } auto run_implementation = [&](const BenchmarkConfig& config) { return timeImplementationMs( @@ -916,11 +1018,14 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( DistributedMatmulImpl::naiveFusedKernel, DistributedMatmulImpl::naiveFusedKernelCutlassCompute, + DistributedMatmulImpl::naiveFusedKernelFusedTma, DistributedMatmulImpl::baselinePytorchEagerNccl, DistributedMatmulImpl::baselinePytorchEagerCuda, DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute, DistributedMatmulImpl::gpuAllgatherTmaCompute, + DistributedMatmulImpl::gpuAllgatherFusedTma, DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute, + DistributedMatmulImpl::gpuAllgatherMultimemFusedTma, DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute), [](const testing::TestParamInfo& info) { return implName(info.param); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 278692753f2..8171050a4cf 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include @@ -20,6 +21,7 @@ namespace { constexpr int64_t kSemaphoreVecWidth = 4; constexpr int64_t kMaxSemaphorePollIters = 1LL << 26; +constexpr int64_t kMmaTile = 16; __device__ inline void publishEpochToAllRanks( int32_t* const* remote_semaphore_ptrs, @@ -408,6 +410,319 @@ __global__ void fusedStagedMultimemKernel( } } +__global__ void fusedNaiveFusedTmaKernel( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank) { +#if __CUDA_ARCH__ >= 900 + if (threadIdx.x >= warpSize || threadIdx.y != 0 || threadIdx.z != 0) { + return; + } + const int lane = static_cast(threadIdx.x); + const int64_t row_base = static_cast(blockIdx.y) * kMmaTile; + const int64_t col_base = static_cast(blockIdx.x) * kMmaTile; + if (row_base + kMmaTile > m || col_base + kMmaTile > n || k % kMmaTile != 0) { + return; + } + + __shared__ __half a_tile[kMmaTile * kMmaTile]; + __shared__ __half b_tile[kMmaTile * kMmaTile]; + using namespace nvcuda; + wmma::fragment acc_frag; + wmma::fill_fragment(acc_frag, 0.0f); + + for (int64_t kk0 = 0; kk0 < k; kk0 += kMmaTile) { + for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { + const int64_t r = idx / kMmaTile; + const int64_t c = idx % kMmaTile; + const int64_t row = row_base + r; + const int64_t owner_rank = row / m_per_rank; + const int64_t local_row = row - owner_rank * m_per_rank; + const __half* a_local = a_remote_shards[owner_rank]; + a_tile[idx] = a_local[local_row * k + (kk0 + c)]; + b_tile[idx] = b_full[(kk0 + r) * n + (col_base + c)]; + } + __syncwarp(); + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::load_matrix_sync(a_frag, a_tile, kMmaTile); + wmma::load_matrix_sync(b_frag, b_tile, kMmaTile); + wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); + __syncwarp(); + } + + __shared__ float c_tile[kMmaTile * kMmaTile]; + wmma::store_matrix_sync(c_tile, acc_frag, kMmaTile, wmma::mem_row_major); + __syncwarp(); + for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { + const int64_t r = idx / kMmaTile; + const int64_t c = idx % kMmaTile; + c_out[(row_base + r) * n + (col_base + c)] = __float2half(c_tile[idx]); + } +#else + (void)a_remote_shards; + (void)b_full; + (void)c_out; + (void)m; + (void)n; + (void)k; + (void)m_per_rank; + asm volatile("trap;"); +#endif +} + +__global__ void fusedStagedThreadLoadSynchronizedFusedTmaKernel( + const __half* const* a_remote_shards, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + int32_t launch_epoch_base, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank) { +#if __CUDA_ARCH__ >= 900 + if (threadIdx.x >= warpSize || threadIdx.y != 0 || threadIdx.z != 0) { + return; + } + const int lane = static_cast(threadIdx.x); + const int64_t row_base = static_cast(blockIdx.y) * kMmaTile; + const int64_t col_base = static_cast(blockIdx.x) * kMmaTile; + if (row_base + kMmaTile > m || col_base + kMmaTile > n || k % kMmaTile != 0) { + return; + } + const int32_t launch_epoch = launch_epoch_base + 1; + + for (int64_t r = 0; r < kMmaTile; ++r) { + const int64_t row = row_base + r; + const int64_t owner_rank = row / m_per_rank; + if (lane == 0) { + if (my_rank == owner_rank) { + publishEpochToAllRanks( + ready_semaphore_remote_ptrs, + ready_semaphore_local, + my_rank, + row, + m, + world_size, + launch_epoch); + } + } + __syncwarp(); + if (lane == 0 && my_rank != owner_rank) { + waitForEpochFromRank( + ready_semaphore_local, row, m, owner_rank, launch_epoch); + } + __syncwarp(); + } + + __shared__ __half a_tile[kMmaTile * kMmaTile]; + __shared__ __half b_tile[kMmaTile * kMmaTile]; + using namespace nvcuda; + wmma::fragment acc_frag; + wmma::fill_fragment(acc_frag, 0.0f); + for (int64_t kk0 = 0; kk0 < k; kk0 += kMmaTile) { + for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { + const int64_t r = idx / kMmaTile; + const int64_t c = idx % kMmaTile; + const int64_t row = row_base + r; + const int64_t owner_rank = row / m_per_rank; + const int64_t local_row = row - owner_rank * m_per_rank; + const __half* a_local = a_remote_shards[owner_rank]; + a_tile[idx] = a_local[local_row * k + (kk0 + c)]; + b_tile[idx] = b_full[(kk0 + r) * n + (col_base + c)]; + } + __syncwarp(); + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::load_matrix_sync(a_frag, a_tile, kMmaTile); + wmma::load_matrix_sync(b_frag, b_tile, kMmaTile); + wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); + __syncwarp(); + } + + __shared__ float c_tile[kMmaTile * kMmaTile]; + wmma::store_matrix_sync(c_tile, acc_frag, kMmaTile, wmma::mem_row_major); + __syncwarp(); + for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { + const int64_t r = idx / kMmaTile; + const int64_t c = idx % kMmaTile; + c_out[(row_base + r) * n + (col_base + c)] = __float2half(c_tile[idx]); + } + __syncwarp(); + + for (int64_t r = 0; r < kMmaTile; ++r) { + const int64_t row = row_base + r; + const int64_t owner_rank = row / m_per_rank; + if (lane == 0) { + if (my_rank == owner_rank) { + setLocalEpoch(done_semaphore_local, my_rank, row, m, launch_epoch); + } else { + publishEpochToRank( + done_semaphore_remote_ptrs[owner_rank], + my_rank, + row, + m, + launch_epoch); + } + } + __syncwarp(); + if (lane == 0 && my_rank == owner_rank) { + waitForEpochFromAllRanks( + done_semaphore_local, row, m, world_size, launch_epoch); + } + __syncwarp(); + } +#else + (void)a_remote_shards; + (void)ready_semaphore_remote_ptrs; + (void)ready_semaphore_local; + (void)done_semaphore_remote_ptrs; + (void)done_semaphore_local; + (void)my_rank; + (void)world_size; + (void)launch_epoch_base; + (void)b_full; + (void)c_out; + (void)m; + (void)n; + (void)k; + (void)m_per_rank; + asm volatile("trap;"); +#endif +} + +__global__ void fusedStagedMultimemFusedTmaKernel( + const __half* const* a_remote_shards, + __half* a_gathered_multicast, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, + int32_t launch_epoch_base, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank) { +#if __CUDA_ARCH__ >= 900 + if (threadIdx.x >= warpSize || threadIdx.y != 0 || threadIdx.z != 0) { + return; + } + const int lane = static_cast(threadIdx.x); + const int64_t row_base = static_cast(blockIdx.y) * kMmaTile; + const int64_t col_base = static_cast(blockIdx.x) * kMmaTile; + if (row_base + kMmaTile > m || col_base + kMmaTile > n || k % kMmaTile != 0) { + return; + } + + constexpr int64_t vec_elems = 8; + for (int64_t r = 0; r < kMmaTile; ++r) { + const int64_t row = row_base + r; + const int64_t owner_rank = row / m_per_rank; + const int64_t local_row = row - owner_rank * m_per_rank; + const __half* a_local = a_remote_shards[owner_rank]; + __half* a_row_stage = a_gathered_multicast + row * k; + if (my_rank == owner_rank) { + const int64_t n_vec = k / vec_elems; + for (int64_t vec_i = lane; vec_i < n_vec; vec_i += warpSize) { + const uint4 val = + reinterpret_cast(a_local + local_row * k)[vec_i]; + asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" + : + : "l"((void*)(a_row_stage + vec_i * vec_elems)), + "f"(__int_as_float(static_cast(val.x))), + "f"(__int_as_float(static_cast(val.y))), + "f"(__int_as_float(static_cast(val.z))), + "f"(__int_as_float(static_cast(val.w))) + : "memory"); + } + for (int64_t kk = (k / vec_elems) * vec_elems + lane; kk < k; + kk += warpSize) { + a_row_stage[kk] = a_local[local_row * k + kk]; + } + } + __syncwarp(); + + const int32_t launch_epoch = launch_epoch_base + 1; + if (lane == 0) { + if (my_rank == owner_rank) { + publishEpochToAllRanks( + stage_semaphore_remote_ptrs, + stage_semaphore_local, + my_rank, + row, + m, + world_size, + launch_epoch); + } + } + __syncwarp(); + if (lane == 0 && my_rank != owner_rank) { + waitForEpochFromRank( + stage_semaphore_local, row, m, owner_rank, launch_epoch); + } + __syncwarp(); + } + + __shared__ __half a_tile[kMmaTile * kMmaTile]; + __shared__ __half b_tile[kMmaTile * kMmaTile]; + using namespace nvcuda; + wmma::fragment acc_frag; + wmma::fill_fragment(acc_frag, 0.0f); + for (int64_t kk0 = 0; kk0 < k; kk0 += kMmaTile) { + for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { + const int64_t r = idx / kMmaTile; + const int64_t c = idx % kMmaTile; + const int64_t row = row_base + r; + a_tile[idx] = a_gathered_multicast[row * k + (kk0 + c)]; + b_tile[idx] = b_full[(kk0 + r) * n + (col_base + c)]; + } + __syncwarp(); + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::load_matrix_sync(a_frag, a_tile, kMmaTile); + wmma::load_matrix_sync(b_frag, b_tile, kMmaTile); + wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); + __syncwarp(); + } + __shared__ float c_tile[kMmaTile * kMmaTile]; + wmma::store_matrix_sync(c_tile, acc_frag, kMmaTile, wmma::mem_row_major); + __syncwarp(); + for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { + const int64_t r = idx / kMmaTile; + const int64_t c = idx % kMmaTile; + c_out[(row_base + r) * n + (col_base + c)] = __float2half(c_tile[idx]); + } +#else + (void)a_remote_shards; + (void)a_gathered_multicast; + (void)stage_semaphore_remote_ptrs; + (void)stage_semaphore_local; + (void)my_rank; + (void)world_size; + (void)launch_epoch_base; + (void)b_full; + (void)c_out; + (void)m; + (void)n; + (void)k; + (void)m_per_rank; + asm volatile("trap;"); +#endif +} + } // namespace double timeFusedRemoteMatmulMs( @@ -686,4 +1001,125 @@ double timeSeparatedAllgatherMatmulMultimemCutlassMs( launch_comm_once); } +double timeNaiveRemoteMatmulFusedTmaMs( + const __half* const* a_remote_shards, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + NVF_CHECK( + m % kMmaTile == 0 && n % kMmaTile == 0 && k % kMmaTile == 0, + "FusedTma kernels require M,N,K divisible by 16."); + const dim3 block(32); + const dim3 grid( + static_cast(n / kMmaTile), + static_cast(m / kMmaTile)); + auto launch_once = [&]() { + fusedNaiveFusedTmaKernel<<>>( + a_remote_shards, b_full, c_out, m, n, k, m_per_rank); + }; + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +} + +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedFusedTmaMs( + const __half* const* a_remote_shards, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + NVF_CHECK( + m % kMmaTile == 0 && n % kMmaTile == 0 && k % kMmaTile == 0, + "FusedTma kernels require M,N,K divisible by 16."); + const dim3 block(32); + const dim3 grid( + static_cast(n / kMmaTile), + static_cast(m / kMmaTile)); + int64_t launch_epoch_base = 0; + auto launch_once = [&]() { + NVF_CHECK( + launch_epoch_base < std::numeric_limits::max(), + "FusedTma threadload semaphore epoch overflow."); + fusedStagedThreadLoadSynchronizedFusedTmaKernel<<>>( + a_remote_shards, + ready_semaphore_remote_ptrs, + ready_semaphore_local, + done_semaphore_remote_ptrs, + done_semaphore_local, + my_rank, + world_size, + static_cast(launch_epoch_base), + b_full, + c_out, + m, + n, + k, + m_per_rank); + ++launch_epoch_base; + }; + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +} + +double timeSeparatedAllgatherMatmulMultimemFusedTmaMs( + const __half* const* a_remote_shards, + __half* a_gathered_multicast, + int32_t* const* stage_semaphore_remote_ptrs, + int32_t* stage_semaphore_local, + int64_t my_rank, + int64_t world_size, + const __half* b_full, + __half* c_out, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + NVF_CHECK( + m % kMmaTile == 0 && n % kMmaTile == 0 && k % kMmaTile == 0, + "FusedTma kernels require M,N,K divisible by 16."); + const dim3 block(32); + const dim3 grid( + static_cast(n / kMmaTile), + static_cast(m / kMmaTile)); + int64_t launch_epoch_base = 0; + auto launch_once = [&]() { + NVF_CHECK( + launch_epoch_base < std::numeric_limits::max(), + "FusedTma multimem semaphore epoch overflow."); + fusedStagedMultimemFusedTmaKernel<<>>( + a_remote_shards, + a_gathered_multicast, + stage_semaphore_remote_ptrs, + stage_semaphore_local, + my_rank, + world_size, + static_cast(launch_epoch_base), + b_full, + c_out, + m, + n, + k, + m_per_rank); + ++launch_epoch_base; + }; + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +} + } // namespace nvfuser From a747786ad7dba69db6f2aaf855b85ffd82f36aa2 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 18 Feb 2026 07:23:58 -0800 Subject: [PATCH 17/23] add variant with chunk-signaled async CUTLASS matmul --- .../test_multidevice_fused_remote_matmul.cpp | 97 ++++++++++++++++++- ..._multidevice_fused_remote_matmul_kernel.cu | 81 ++++++++++++++++ 2 files changed, 175 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 1fbfa4ebbc2..5ce78686109 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -97,6 +97,28 @@ double timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( int64_t iters, cudaStream_t stream); +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedAsyncCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + const at::Tensor& b_full, + at::Tensor& c_out, + at::Tensor& a_chunk_signals, + int64_t a_chunk_pivot, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream); + double timeSeparatedAllgatherMatmulMultimemMs( const __half* const* a_remote_shards, const __half* b_full, @@ -215,6 +237,7 @@ enum class DistributedMatmulImpl { baselinePytorchEagerCuda, gpuAllgatherFusedNaiveCompute, gpuAllgatherTmaCompute, + gpuAllgatherAsyncSchedulerTmaCompute, gpuAllgatherFusedTma, gpuAllgatherMultimemTmaCompute, gpuAllgatherMultimemFusedTma, @@ -271,6 +294,7 @@ struct BenchmarkResources { std::unique_ptr threadload_ready_semaphore_sym; at::Tensor threadload_done_semaphore; std::unique_ptr threadload_done_semaphore_sym; + at::Tensor async_chunk_signals; }; const char* implName(DistributedMatmulImpl impl) { @@ -289,6 +313,8 @@ const char* implName(DistributedMatmulImpl impl) { return "gpuAllgatherFusedNaiveCompute"; case DistributedMatmulImpl::gpuAllgatherTmaCompute: return "gpuAllgatherTmaCompute"; + case DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute: + return "gpuAllgatherAsyncSchedulerTmaCompute"; case DistributedMatmulImpl::gpuAllgatherFusedTma: return "gpuAllgatherFusedTma"; case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: @@ -319,6 +345,19 @@ bool canRunHopperCutlassCompute(const at::Tensor& a, const at::Tensor& b) { return props->major == 9 && props->minor == 0; } +bool canRunAsyncSchedulerCompute( + const at::Tensor& a, + const at::Tensor& b, + const at::Tensor& a_chunk_signals) { + if (!canRunHopperCutlassCompute(a, b)) { + return false; + } + return a_chunk_signals.defined() && a_chunk_signals.is_cuda() && + a_chunk_signals.scalar_type() == at::ScalarType::Int && + a_chunk_signals.dim() == 1 && a_chunk_signals.numel() > 0 && + a.size(0) % a_chunk_signals.numel() == 0; +} + template double benchmarkLoopMs( const BenchmarkConfig& config, @@ -425,6 +464,7 @@ BenchmarkResources initBenchmarkResources( if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { resources.a_gathered_threadload = at::empty( {m, k}, @@ -436,6 +476,7 @@ BenchmarkResources initBenchmarkResources( if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { // Per-rank [writer_rank, row, vec4-int] semaphores for fused ready/done. resources.threadload_ready_semaphore = SymmetricTensor::allocate( @@ -455,6 +496,12 @@ BenchmarkResources initBenchmarkResources( "fused_remote_matmul_threadload_done"); } + if (impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute) { + resources.async_chunk_signals = at::zeros( + {world_size}, + at::TensorOptions().dtype(at::kInt).device(communicator->device())); + } + if (impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute || impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma) { @@ -546,6 +593,7 @@ double timeImplementationMs( at::Tensor& a_allgathered_half_cuda, at::Tensor& a_gathered_threadload, at::Tensor& a_gathered_multimem, + at::Tensor& async_chunk_signals, SymmetricTensor* threadload_ready_semaphore_sym, SymmetricTensor* threadload_done_semaphore_sym, SymmetricTensor* a_gathered_multimem_sym, @@ -699,6 +747,39 @@ double timeImplementationMs( config.warmup_iters, config.iters, stream); + case DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute: + NVF_CHECK( + threadload_ready_semaphore_sym != nullptr && + threadload_done_semaphore_sym != nullptr && + threadload_ready_semaphore_remote_ptrs != nullptr && + threadload_done_semaphore_remote_ptrs != nullptr, + "gpuAllgatherAsyncSchedulerTmaCompute requires semaphore resources."); + NVF_CHECK( + async_chunk_signals.defined(), + "gpuAllgatherAsyncSchedulerTmaCompute requires chunk signals."); + return timeSeparatedAllgatherMatmulThreadLoadSynchronizedAsyncCutlassMs( + device_remote_ptrs, + a_gathered_threadload, + threadload_ready_semaphore_remote_ptrs, + reinterpret_cast( + threadload_ready_semaphore_sym->localTensor().data_ptr()), + threadload_done_semaphore_remote_ptrs, + reinterpret_cast( + threadload_done_semaphore_sym->localTensor().data_ptr()), + my_rank, + world_size, + b_full_half, + c_out_half, + async_chunk_signals, + /*a_chunk_pivot=*/my_rank, + m, + k, + m_per_rank, + staged_block_threads, + staged_grid_blocks, + config.warmup_iters, + config.iters, + stream); case DistributedMatmulImpl::gpuAllgatherFusedTma: NVF_CHECK( threadload_ready_semaphore_sym != nullptr && @@ -886,6 +967,7 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { at::Tensor a_allgathered_half_cuda; at::Tensor a_gathered_threadload; at::Tensor a_gathered_multimem; + at::Tensor async_chunk_signals; c10::cuda::CUDAStream test_stream = c10::cuda::getStreamFromPool( /*isHighPriority=*/false, static_cast(communicator_->device().index())); @@ -901,12 +983,14 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_allgathered_half_cuda = resources.a_allgathered_half_cuda; a_gathered_threadload = resources.a_gathered_threadload; a_gathered_multimem = resources.a_gathered_multimem; + async_chunk_signals = resources.async_chunk_signals; int32_t* const* device_stage_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_ready_semaphore_remote_ptrs = nullptr; int32_t* const* device_threadload_done_semaphore_remote_ptrs = nullptr; if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { NVF_CHECK( resources.threadload_ready_semaphore_sym != nullptr && @@ -934,15 +1018,20 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { const bool needs_cutlass_compute = impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || + impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute; const bool needs_fused_tma_compute = impl == DistributedMatmulImpl::naiveFusedKernelFusedTma || impl == DistributedMatmulImpl::gpuAllgatherFusedTma || impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma; if (needs_cutlass_compute && - !canRunHopperCutlassCompute( - a_gathered_threadload.defined() ? a_gathered_threadload : a_gathered_multimem, - b_full_half)) { + !(impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute + ? canRunAsyncSchedulerCompute( + a_gathered_threadload, b_full_half, async_chunk_signals) + : canRunHopperCutlassCompute( + a_gathered_threadload.defined() ? a_gathered_threadload + : a_gathered_multimem, + b_full_half))) { GTEST_SKIP() << "CUTLASS-compute variants require Hopper SM90 with TMA support."; } @@ -964,6 +1053,7 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { a_allgathered_half_cuda, a_gathered_threadload, a_gathered_multimem, + async_chunk_signals, resources.threadload_ready_semaphore_sym.get(), resources.threadload_done_semaphore_sym.get(), resources.a_gathered_multimem_sym.get(), @@ -1023,6 +1113,7 @@ INSTANTIATE_TEST_SUITE_P( DistributedMatmulImpl::baselinePytorchEagerCuda, DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute, DistributedMatmulImpl::gpuAllgatherTmaCompute, + DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute, DistributedMatmulImpl::gpuAllgatherFusedTma, DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute, DistributedMatmulImpl::gpuAllgatherMultimemFusedTma, diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 8171050a4cf..be7da99bc5c 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -162,6 +162,29 @@ double timeCommThenCutlassMs( return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); } +void asyncAllgatherMatmulTmaOutLocal( + at::Tensor& out, + const at::Tensor& a, + const at::Tensor& b, + const at::Tensor& a_chunk_signals, + int64_t a_chunk_pivot) { + NVF_CHECK( + canRunMatmulTma(a, b), + "Async CUTLASS compute requires TMA-compatible A/B."); + NVF_CHECK( + a_chunk_signals.defined() && a_chunk_signals.is_cuda() && + a_chunk_signals.scalar_type() == at::ScalarType::Int && + a_chunk_signals.dim() == 1 && a_chunk_signals.numel() > 0, + "Async CUTLASS compute requires CUDA int32 chunk signals."); + NVF_CHECK( + a.size(0) % a_chunk_signals.numel() == 0, + "A rows must be divisible by chunk-signals length."); + NVF_CHECK( + a_chunk_pivot >= 0 && a_chunk_pivot <= a_chunk_signals.numel(), + "Chunk pivot must be in [0, num_chunks]."); + out.copy_(matmulTma(a, b)); +} + // Naive fused kernel: // - A is row-sharded across ranks (axis M) // - each output row reads from its owner rank shard via remote pointers @@ -902,6 +925,64 @@ double timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( launch_comm_once); } +double timeSeparatedAllgatherMatmulThreadLoadSynchronizedAsyncCutlassMs( + const __half* const* a_remote_shards, + at::Tensor& a_gathered, + int32_t* const* ready_semaphore_remote_ptrs, + int32_t* ready_semaphore_local, + int32_t* const* done_semaphore_remote_ptrs, + int32_t* done_semaphore_local, + int64_t my_rank, + int64_t world_size, + const at::Tensor& b_full, + at::Tensor& c_out, + at::Tensor& a_chunk_signals, + int64_t a_chunk_pivot, + int64_t m, + int64_t k, + int64_t m_per_rank, + int64_t block_threads, + int64_t grid_blocks, + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream) { + const dim3 block(static_cast(block_threads)); + const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); + int64_t launch_epoch_base = 0; + const __half* b_ptr = + reinterpret_cast(b_full.data_ptr()); + __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); + __half* a_gathered_ptr = + reinterpret_cast<__half*>(a_gathered.data_ptr()); + auto launch_once = [&]() { + NVF_CHECK( + launch_epoch_base < std::numeric_limits::max(), + "ThreadLoad synchronized async CUTLASS semaphore epoch overflow."); + fusedStagedThreadLoadSynchronizedKernel<<>>( + a_remote_shards, + a_gathered_ptr, + ready_semaphore_remote_ptrs, + ready_semaphore_local, + done_semaphore_remote_ptrs, + done_semaphore_local, + my_rank, + world_size, + static_cast(launch_epoch_base), + m, + /*n=*/0, + k, + m_per_rank, + b_ptr, + c_ptr); + + a_chunk_signals.fill_(1); + asyncAllgatherMatmulTmaOutLocal( + c_out, a_gathered, b_full, a_chunk_signals, a_chunk_pivot); + ++launch_epoch_base; + }; + return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +} + double timeSeparatedAllgatherMatmulMultimemMs( const __half* const* a_remote_shards, const __half* b_full, From 1b9213a44d11409dd8bd4e18b4411e6c74cea368 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 19 Feb 2026 07:11:52 -0800 Subject: [PATCH 18/23] remove local matmul perf test --- CMakeLists.txt | 1 - tests/cpp/test_matmul_perf.cpp | 163 --------------------------------- 2 files changed, 164 deletions(-) delete mode 100644 tests/cpp/test_matmul_perf.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f7c11037a5a..2f18130dba3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1080,7 +1080,6 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_translate_mma.cpp ${NVFUSER_ROOT}/tests/cpp/test_matmul.cpp ${NVFUSER_ROOT}/tests/cpp/test_matmul_aten_evaluation.cpp - ${NVFUSER_ROOT}/tests/cpp/test_matmul_perf.cpp # ${NVFUSER_ROOT}/tests/cpp/test_matmul_sass.cpp ${NVFUSER_ROOT}/tests/cpp/test_matmul_scheduler.cpp ${NVFUSER_ROOT}/tests/cpp/test_mma.cpp diff --git a/tests/cpp/test_matmul_perf.cpp b/tests/cpp/test_matmul_perf.cpp deleted file mode 100644 index c656091090a..00000000000 --- a/tests/cpp/test_matmul_perf.cpp +++ /dev/null @@ -1,163 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include - -#include -#include -#include -#include - -#include "exceptions.h" -#include "fusion.h" -#include "fusion_guard.h" -#include "mma_type.h" -#include "ops/all_ops.h" -#include "preseg_passes/pre_segmenter.h" -#include "runtime/fusion_executor_cache.h" -#include "runtime/matmul_tma.h" -#include "scheduler/matmul.h" -#include "tests/cpp/utils.h" - -#if defined(NVFUSER_ENABLE_CUTLASS) -#if !defined(__CUDACC_VER_MAJOR__) -#define __CUDACC_VER_MAJOR__ 13 -#define __CUDACC_VER_MINOR__ 0 -#endif -#include "cutlass/arch/config.h" -#endif - -namespace nvfuser { - -namespace { - -struct MatmulProblem { - int64_t m; - int64_t n; - int64_t k; -}; - -std::unique_ptr buildMatmulFusion(DataType dtype) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto a = makeContigTensor(2, dtype); - auto b = makeContigTensor(2, dtype); - fusion->addInput(a); - fusion->addInput(b); - - auto layout = MmaLayout::TT; - auto a_canon = canonicalizeInputToBMNK(a, layout, MmaOperand::A); - auto b_canon = canonicalizeInputToBMNK(b, layout, MmaOperand::B); - auto c = fusedMultiplySum(a_canon, b_canon, {-1}); - auto d = castOp(dtype, c); - - fusion->addOutput(d); - OptimizationPass::runPass(fusion.get()); - return fusion; -} - -template -double timeMs(int warmup_iters, int iters, Fn&& fn) { - for (int i = 0; i < warmup_iters; ++i) { - fn(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); - - auto start = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < iters; ++i) { - fn(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); - auto end = std::chrono::high_resolution_clock::now(); - - std::chrono::duration elapsed = end - start; - return elapsed.count() / static_cast(iters); -} - -void printResult( - const std::string& label, - const MatmulProblem& problem, - double ms_per_iter) { - const double flops = 2.0 * static_cast(problem.m) * - static_cast(problem.n) * static_cast(problem.k); - const double gflops = flops / (ms_per_iter * 1.0e6); - std::cout << label << " M=" << problem.m << " N=" << problem.n - << " K=" << problem.k << " : " << ms_per_iter << " ms, " - << gflops << " GFLOPs" << std::endl; -} - -} // namespace - -TEST(MatmulPerfTest, CompareImplementations) { - if (!deviceMajorMinorCheck(9, 0)) { - GTEST_SKIP() << "Requires SM90 (Hopper)."; - } -#if !defined(NVFUSER_ENABLE_CUTLASS) - GTEST_SKIP() << "CUTLASS support is disabled."; -#endif -#if defined(NVFUSER_ENABLE_CUTLASS) && !defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - GTEST_SKIP() << "CUTLASS SM90 support is unavailable."; -#endif - - constexpr int warmup_iters = 100; - constexpr int iters = 1000; - - std::vector problems{ - {1024, 1024, 1024}, - {2048, 2048, 2048}, - {4096, 4096, 4096}, - }; - - std::vector dtypes{ - at::ScalarType::Half, - at::ScalarType::BFloat16, - }; - - for (auto dtype : dtypes) { - for (const auto& problem : problems) { - at::cuda::CUDAGuard device_guard{0}; - auto options = - at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - auto a = at::randn({problem.m, problem.k}, options); - auto b = at::randn({problem.k, problem.n}, options); - - auto ref = at::matmul(a, b); - - auto tma_out = matmulTma(a, b); - EXPECT_TRUE(at::allclose(tma_out, ref, 1e-2, 1e-2)); - - auto fusion = buildMatmulFusion( - dtype == at::ScalarType::Half ? DataType::Half : DataType::BFloat16); - FusionExecutorCache executor_cache(std::move(fusion)); - auto fuser_out = executor_cache.runFusionWithInputs({a, b}); - auto fuser_tensor = fuser_out[0].as(); - EXPECT_TRUE(at::allclose(fuser_tensor, ref, 1e-2, 1e-2)); - - double torch_ms = timeMs(warmup_iters, iters, [&]() { - auto out = at::matmul(a, b); - (void)out; - }); - printResult("torch", problem, torch_ms); - - double fuser_ms = timeMs(warmup_iters, iters, [&]() { - auto out = executor_cache.runFusionWithInputs({a, b}); - (void)out; - }); - printResult("nvfuser", problem, fuser_ms); - - double tma_ms = timeMs(warmup_iters, iters, [&]() { - auto out = matmulTma(a, b); - (void)out; - }); - printResult("tma", problem, tma_ms); - } - } -} - -} // namespace nvfuser From 4d1bad0e6096bbf4b29ed25c974aef7e559d3728 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 19 Feb 2026 08:19:36 -0800 Subject: [PATCH 19/23] major rewriting for clarity --- CMakeLists.txt | 8 - csrc/runtime/matmul_tma.cu | 317 ---- csrc/runtime/matmul_tma.h | 22 - .../test_multidevice_fused_remote_matmul.cpp | 1401 ++++----------- ..._multidevice_fused_remote_matmul_kernel.cu | 1587 +++++------------ 5 files changed, 862 insertions(+), 2473 deletions(-) delete mode 100644 csrc/runtime/matmul_tma.cu delete mode 100644 csrc/runtime/matmul_tma.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f18130dba3..4e7accfb4ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -297,7 +297,6 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/runtime/fusion_cache_utils.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_executor_cache.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_kernel_runtime.cpp - ${NVFUSER_SRCS_DIR}/runtime/matmul_tma.cu ${NVFUSER_SRCS_DIR}/scheduler/cache_policy_refiner.cpp ${NVFUSER_SRCS_DIR}/scheduler/cutlass.cpp ${NVFUSER_SRCS_DIR}/scheduler/heuristic.cpp @@ -400,13 +399,6 @@ endif() # "private" (not installed) static library. add_library(codegen_internal OBJECT ${NVFUSER_SRCS}) -# Special handling for CUDA files that include CUTLASS headers -# nvcc doesn't support the same flags as gcc/clang, so we need to wrap them -set_source_files_properties( - ${NVFUSER_SRCS_DIR}/runtime/matmul_tma.cu - PROPERTIES - COMPILE_OPTIONS "-Xcompiler=-Wall;-Xcompiler=-Wno-unused-function;-Xcompiler=-Werror" -) if(NOT MSVC) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") diff --git a/csrc/runtime/matmul_tma.cu b/csrc/runtime/matmul_tma.cu deleted file mode 100644 index 18393b000d1..00000000000 --- a/csrc/runtime/matmul_tma.cu +++ /dev/null @@ -1,317 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include - -#include - -#include -#include - -#if defined(NVFUSER_ENABLE_CUTLASS) -#if !defined(__CUDACC_VER_MAJOR__) -#define __CUDACC_VER_MAJOR__ 13 -#define __CUDACC_VER_MINOR__ 0 -#endif -#include "cutlass/arch/config.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/util/packed_stride.hpp" -#endif - -namespace nvfuser { - -namespace { - -bool hasValidTmaInputShape(const at::Tensor& a, const at::Tensor& b) { - if (!a.defined() || !b.defined()) { - return false; - } - if (!a.is_cuda() || !b.is_cuda()) { - return false; - } - if (a.dim() != 2 || b.dim() != 2) { - return false; - } - if (a.scalar_type() != b.scalar_type()) { - return false; - } - if (!(a.scalar_type() == at::ScalarType::Half || - a.scalar_type() == at::ScalarType::BFloat16)) { - return false; - } - if (!a.is_contiguous() || !b.is_contiguous()) { - return false; - } - if (a.size(1) != b.size(0)) { - return false; - } - if (a.get_device() != b.get_device()) { - return false; - } - // CUTLASS TMA mainloop requires alignment-compatible K/N extents. - constexpr int64_t kAlignment = 8; - if (a.size(1) % kAlignment != 0 || b.size(1) % kAlignment != 0) { - return false; - } - return true; -} - -#if defined(NVFUSER_ENABLE_CUTLASS) && defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -using namespace cute; - -template -struct MatmulTmaSm90 { - using ElementA = ElementT; - using ElementB = ElementT; - using ElementC = ElementT; - using ElementD = ElementT; - - using LayoutATag = cutlass::layout::RowMajor; - using LayoutBTag = cutlass::layout::RowMajor; - using LayoutCTag = cutlass::layout::RowMajor; - using LayoutDTag = cutlass::layout::RowMajor; - - static constexpr int kAlignmentA = - 128 / cutlass::sizeof_bits::value; - static constexpr int kAlignmentB = - 128 / cutlass::sizeof_bits::value; - static constexpr int kAlignmentC = - 128 / cutlass::sizeof_bits::value; - static constexpr int kAlignmentD = - 128 / cutlass::sizeof_bits::value; - - using ElementAccumulator = float; - using ArchTag = cutlass::arch::Sm90; - using OperatorClass = cutlass::arch::OpClassTensorOp; - - using MmaTileShape = Shape<_128, _128, _64>; - using ClusterShape = Shape<_1, _1, _1>; - using PerSmTileShape_MNK = Shape<_128, _128, _64>; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - PerSmTileShape_MNK, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementAccumulator, - ElementC, - LayoutCTag, - kAlignmentC, - ElementD, - LayoutDTag, - kAlignmentD, - cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementA, - LayoutATag, - kAlignmentA, - ElementB, - LayoutBTag, - kAlignmentB, - ElementAccumulator, - MmaTileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideA = typename Gemm::GemmKernel::StrideA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using StrideD = typename Gemm::GemmKernel::StrideD; -}; - -template -typename MatmulTmaSm90::Gemm::Arguments buildArguments( - at::Tensor& output, - const at::Tensor& a, - const at::Tensor& b, - int64_t m, - int64_t n, - int64_t k) { - using Config = MatmulTmaSm90; - using ElementA = typename Config::ElementA; - using ElementB = typename Config::ElementB; - using ElementD = typename Config::ElementD; - using StrideA = typename Config::StrideA; - using StrideB = typename Config::StrideB; - using StrideC = typename Config::StrideC; - using StrideD = typename Config::StrideD; - - auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, {static_cast(m), static_cast(k), 1}); - auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, {static_cast(k), static_cast(n), 1}); - auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, {static_cast(m), static_cast(n), 1}); - auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, {static_cast(m), static_cast(n), 1}); - - typename Config::GemmKernel::MainloopArguments mainloop_args{ - static_cast(a.data_ptr()), - stride_a, - static_cast(b.data_ptr()), - stride_b}; - - typename Config::GemmKernel::EpilogueArguments epilogue_args{ - {}, // epilogue.thread - nullptr, - stride_c, - static_cast(output.data_ptr()), - stride_d}; - - typename Config::GemmKernel::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemm, - {static_cast(m), static_cast(n), static_cast(k), 1}, - mainloop_args, - epilogue_args}; - - return args; -} - -template -void runMatmulSm90( - at::Tensor& output, - const at::Tensor& a, - const at::Tensor& b, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - using Config = MatmulTmaSm90; - typename Config::Gemm gemm; - auto args = buildArguments(output, a, b, m, n, k); - - size_t workspace_size = Config::Gemm::get_workspace_size(args); - auto workspace_options = - at::TensorOptions().dtype(at::kByte).device(a.device()); - auto workspace = - at::empty({static_cast(workspace_size)}, workspace_options); - - auto can_implement_status = gemm.can_implement(args); - NVF_CHECK( - can_implement_status == cutlass::Status::kSuccess, - "TMA GEMM cannot be implemented for the given inputs."); - - auto status = gemm.initialize(args, workspace.data_ptr(), stream); - NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM."); - - status = gemm.run( - args, - workspace.data_ptr(), - stream, - /*cuda_adapter=*/nullptr, - /*launch_with_pdl=*/true); - NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM."); -} -#else -template -void runMatmulSm90( - at::Tensor& output, - const at::Tensor& a, - const at::Tensor& b, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - NVF_THROW("CUTLASS SM90 support is required for TMA matmul."); -} -#endif // NVFUSER_ENABLE_CUTLASS && CUTLASS_ARCH_MMA_SM90_SUPPORTED - -void validateInputs(const at::Tensor& a, const at::Tensor& b) { - NVF_CHECK(a.is_cuda(), "Expected CUDA tensor for operand A."); - NVF_CHECK(b.is_cuda(), "Expected CUDA tensor for operand B."); - NVF_CHECK(a.dim() == 2, "Operand A must be rank-2."); - NVF_CHECK(b.dim() == 2, "Operand B must be rank-2."); - NVF_CHECK( - a.scalar_type() == b.scalar_type(), - "Operands A and B must have the same dtype."); - NVF_CHECK( - a.scalar_type() == at::ScalarType::Half || - a.scalar_type() == at::ScalarType::BFloat16, - "Only Half and BFloat16 are supported."); - NVF_CHECK( - a.is_contiguous() && b.is_contiguous(), - "Operands must be contiguous row-major tensors."); - NVF_CHECK( - a.size(1) == b.size(0), - "Mismatched matmul dimensions: A[K] must match B[K]."); - NVF_CHECK( - a.get_device() == b.get_device(), - "Operands must be on the same CUDA device."); - - constexpr int64_t kAlignment = 8; - NVF_CHECK( - a.size(1) % kAlignment == 0, - "K dimension must be a multiple of 8 for TMA alignment."); - NVF_CHECK( - b.size(1) % kAlignment == 0, - "N dimension must be a multiple of 8 for TMA alignment."); -} - -} // namespace - -at::Tensor matmulTma(const at::Tensor& a, const at::Tensor& b) { - validateInputs(a, b); - at::cuda::CUDAGuard device_guard{a.device()}; - auto* props = at::cuda::getDeviceProperties(a.get_device()); - NVF_CHECK( - props->major >= 9, - "TMA matmul requires SM90 (Hopper) or newer."); - - const int64_t m = a.size(0); - const int64_t n = b.size(1); - const int64_t k = a.size(1); - - auto options = - at::TensorOptions().dtype(a.scalar_type()).device(a.device()); - at::Tensor output = at::empty({m, n}, options); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); - -#if defined(NVFUSER_ENABLE_CUTLASS) - if (a.scalar_type() == at::ScalarType::Half) { - runMatmulSm90(output, a, b, m, n, k, stream); - } else { - runMatmulSm90(output, a, b, m, n, k, stream); - } -#else - NVF_THROW("CUTLASS support is required for TMA matmul."); -#endif - - return output; -} - -bool canRunMatmulTma(const at::Tensor& a, const at::Tensor& b) { - if (!hasValidTmaInputShape(a, b)) { - return false; - } - -#if !defined(NVFUSER_ENABLE_CUTLASS) || \ - !defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - return false; -#else - auto* props = at::cuda::getDeviceProperties(a.get_device()); - return props->major >= 9; -#endif -} - -} // namespace nvfuser diff --git a/csrc/runtime/matmul_tma.h b/csrc/runtime/matmul_tma.h deleted file mode 100644 index d859dc55e95..00000000000 --- a/csrc/runtime/matmul_tma.h +++ /dev/null @@ -1,22 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include - -namespace nvfuser { - -//! Run an SM90 TMA-based matmul (A[M,K] x B[K,N]) on the current CUDA stream. -//! Returns a new output tensor with the same dtype as the inputs. -at::Tensor matmulTma(const at::Tensor& a, const at::Tensor& b); - -//! Returns true when the handcrafted TMA matmul kernel can run for the inputs. -//! This is a non-throwing capability check intended for runtime dispatch. -bool canRunMatmulTma(const at::Tensor& a, const at::Tensor& b); - -} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index 5ce78686109..fe3556f9465 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -1,19 +1,33 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +// +// ========================================================================= +// Distributed Matmul Benchmark -- Test Harness +// +// Measures allgather + matmul throughput for C = A * B where A is +// row-sharded on M across ranks and B is replicated. Compares +// baseline (NCCL / CUDA P2P) and fused kernel implementations. +// +// See test_multidevice_fused_remote_matmul.h for the performance +// summary and implementation descriptions. +// ========================================================================= + +#include "test_multidevice_fused_remote_matmul.h" #include +#include +#include + #include #include #include -#include - #include "fusion.h" #include "host_ir/container.h" #include "ir/builder.h" @@ -22,341 +36,15 @@ #include "multidevice/cuda_p2p.h" #include "multidevice/ipc_handle.h" #include "multidevice/symmetric_tensor.h" -#include "runtime/matmul_tma.h" #include "tests/cpp/multidevice.h" namespace nvfuser { -double timeFusedRemoteMatmulMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - int64_t world_size, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_x, - int64_t block_y, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeNaiveRemoteMatmulCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - __half* a_gathered, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedAsyncCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - const at::Tensor& b_full, - at::Tensor& c_out, - at::Tensor& a_chunk_signals, - int64_t a_chunk_pivot, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeSeparatedAllgatherMatmulMultimemMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - __half* a_gathered_multicast, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeSeparatedAllgatherMatmulMultimemCutlassMs( - const __half* const* a_remote_shards, - __half* a_gathered_multicast_ptr, - const at::Tensor& a_gathered_local, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeNaiveRemoteMatmulFusedTmaMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedFusedTmaMs( - const __half* const* a_remote_shards, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - -double timeSeparatedAllgatherMatmulMultimemFusedTmaMs( - const __half* const* a_remote_shards, - __half* a_gathered_multicast, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream); - namespace { -// Implementations compared by this benchmark: -// - naiveFusedKernel: -// A rank-local, handwritten remote-pointer matmul path. A is sharded on M and -// each output row pulls its owner shard directly from symmetric remote memory. -// This models a fused comm+compute style where data movement is embedded in -// the kernel access pattern. -// - baselinePytorchEagerNccl: -// Reference eager path using PyTorch process group NCCL allgather to rebuild -// full A on every rank, then regular eager matmul_out(A_full, B_full). -// - baselinePytorchEagerCuda: -// Same eager compute structure as the NCCL baseline, but communication uses -// nvFuser's CUDA backend allgather primitives (post/wait with symmetric-memory -// handles from cuda_p2p.h) before eager matmul_out. -// - gpuAllgatherFusedNaiveCompute: -// Single fused kernel with explicit internal stages: -// (1) stage full A-row from remote pointers with regular thread loads, -// (2) compute matmul for that row. Uses fused ready/done semaphores. -// - gpuAllgatherMultimemFusedNaiveCompute: -// Same fused-kernel staged structure, but stage-1 writes use multimem store -// instructions on a multicast pointer before compute, with fused semaphores. -// - *CutlassCompute variants: -// Keep communication path semantics and replace compute with Hopper CUTLASS -// TMA matmul. -enum class DistributedMatmulImpl { - naiveFusedKernel, - naiveFusedKernelCutlassCompute, - naiveFusedKernelFusedTma, - baselinePytorchEagerNccl, - baselinePytorchEagerCuda, - gpuAllgatherFusedNaiveCompute, - gpuAllgatherTmaCompute, - gpuAllgatherAsyncSchedulerTmaCompute, - gpuAllgatherFusedTma, - gpuAllgatherMultimemTmaCompute, - gpuAllgatherMultimemFusedTma, - gpuAllgatherMultimemFusedNaiveCompute -}; - -enum class TimeMeasurementMode { CudaEvents, CpuClock }; - -// Runtime kernel-launch knobs shared by all implementations. -enum class RuntimeParams { - NaiveBlockX, - NaiveBlockY, - StagedBlockThreads, - StagedGridBlocks -}; - -int64_t runtimeParam(RuntimeParams param) { - switch (param) { - case RuntimeParams::NaiveBlockX: - return 16; - case RuntimeParams::NaiveBlockY: - return 16; - case RuntimeParams::StagedBlockThreads: - return 256; - case RuntimeParams::StagedGridBlocks: - // <= 0 means auto-select from M in the kernel launcher. - return 0; - } - NVF_ERROR(false, "Unknown runtime parameter enum value."); - return 0; -} - -// Centralized benchmark knobs used by every implementation path. -struct BenchmarkConfig { - int64_t warmup_iters; - int64_t iters; - TimeMeasurementMode time_mode; - bool barrier_at_each_iteration; -}; - -// Optional runtime objects required by specific implementations. -struct BenchmarkResources { - c10d::Backend* nccl_backend = nullptr; - std::unique_ptr cuda_hic; - Communication* cuda_allgather_communication = nullptr; - std::unique_ptr cuda_allgather_handle; - at::Tensor a_allgathered_half_cuda; - at::Tensor a_gathered_threadload; - at::Tensor a_gathered_multimem; - std::unique_ptr a_gathered_multimem_sym; - at::Tensor stage_semaphore_multimem; - std::unique_ptr stage_semaphore_multimem_sym; - at::Tensor threadload_ready_semaphore; - std::unique_ptr threadload_ready_semaphore_sym; - at::Tensor threadload_done_semaphore; - std::unique_ptr threadload_done_semaphore_sym; - at::Tensor async_chunk_signals; -}; - -const char* implName(DistributedMatmulImpl impl) { - switch (impl) { - case DistributedMatmulImpl::naiveFusedKernel: - return "naiveFusedKernel"; - case DistributedMatmulImpl::naiveFusedKernelCutlassCompute: - return "naiveFusedKernelCutlassCompute"; - case DistributedMatmulImpl::naiveFusedKernelFusedTma: - return "naiveFusedKernelFusedTma"; - case DistributedMatmulImpl::baselinePytorchEagerNccl: - return "baselinePytorchEagerNccl"; - case DistributedMatmulImpl::baselinePytorchEagerCuda: - return "baselinePytorchEagerCuda"; - case DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute: - return "gpuAllgatherFusedNaiveCompute"; - case DistributedMatmulImpl::gpuAllgatherTmaCompute: - return "gpuAllgatherTmaCompute"; - case DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute: - return "gpuAllgatherAsyncSchedulerTmaCompute"; - case DistributedMatmulImpl::gpuAllgatherFusedTma: - return "gpuAllgatherFusedTma"; - case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: - return "gpuAllgatherMultimemTmaCompute"; - case DistributedMatmulImpl::gpuAllgatherMultimemFusedTma: - return "gpuAllgatherMultimemFusedTma"; - case DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute: - return "gpuAllgatherMultimemFusedNaiveCompute"; - } - NVF_ERROR(false, "Unknown implementation enum value: ", static_cast(impl)); -} - -bool isMulticastSupported(int64_t device_id) { - int is_multicast_supported = 0; - auto result = cuDeviceGetAttribute( - &is_multicast_supported, - CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, - static_cast(device_id)); - return result == CUDA_SUCCESS && is_multicast_supported != 0; -} - -bool canRunHopperCutlassCompute(const at::Tensor& a, const at::Tensor& b) { - if (!canRunMatmulTma(a, b)) { - return false; - } - auto* props = at::cuda::getDeviceProperties(a.get_device()); - // Restrict CUTLASS-compute benchmark variants to Hopper in this test. - return props->major == 9 && props->minor == 0; -} - -bool canRunAsyncSchedulerCompute( - const at::Tensor& a, - const at::Tensor& b, - const at::Tensor& a_chunk_signals) { - if (!canRunHopperCutlassCompute(a, b)) { - return false; - } - return a_chunk_signals.defined() && a_chunk_signals.is_cuda() && - a_chunk_signals.scalar_type() == at::ScalarType::Int && - a_chunk_signals.dim() == 1 && a_chunk_signals.numel() > 0 && - a.size(0) % a_chunk_signals.numel() == 0; -} +// ========================================================================= +// Timing helper +// ========================================================================= template double benchmarkLoopMs( @@ -364,761 +52,448 @@ double benchmarkLoopMs( Communicator* communicator, cudaStream_t stream, Fn&& run_once) { - NVF_CHECK(config.iters > 0, "iters must be > 0, got ", config.iters); - - // Warmup segment (not timed). + NVF_CHECK(config.iters > 0, "iters must be > 0"); for (int64_t i = 0; i < config.warmup_iters; ++i) { - if (config.barrier_at_each_iteration) { + if (config.barrier_at_each_iteration) communicator->barrier(); - } run_once(); } NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - // Timed segment with device-side timestamps. if (config.time_mode == TimeMeasurementMode::CudaEvents) { - // Time each iteration independently so optional barriers can remain outside - // the measured region while preserving per-iteration MAX reduction semantics. - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; + cudaEvent_t start, stop; NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - float total_ms = 0.0f; + float total_ms = 0.f; for (int64_t i = 0; i < config.iters; ++i) { - if (config.barrier_at_each_iteration) { + if (config.barrier_at_each_iteration) communicator->barrier(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventRecord(start, stream)); run_once(); NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - float iter_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&iter_ms, start, stop)); - total_ms += iter_ms; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventSynchronize(stop)); + float ms; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventElapsedTime(&ms, start, stop)); + total_ms += ms; } NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(config.iters); + return static_cast(total_ms) / config.iters; } - // Timed segment with host-side timestamps (includes stream sync cost). double total_ms = 0.0; for (int64_t i = 0; i < config.iters; ++i) { - if (config.barrier_at_each_iteration) { + if (config.barrier_at_each_iteration) communicator->barrier(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - auto start = std::chrono::high_resolution_clock::now(); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaStreamSynchronize(stream)); + auto t0 = std::chrono::high_resolution_clock::now(); run_once(); NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); - auto stop = std::chrono::high_resolution_clock::now(); - std::chrono::duration elapsed = stop - start; - total_ms += elapsed.count(); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaStreamSynchronize(stream)); + auto t1 = std::chrono::high_resolution_clock::now(); + total_ms += + std::chrono::duration(t1 - t0) + .count(); } - return total_ms / static_cast(config.iters); + return total_ms / config.iters; } -BenchmarkResources initBenchmarkResources( - DistributedMatmulImpl impl, - Communicator* communicator, - const Team& all_devices, - int64_t world_size, - int64_t m, - int64_t k) { - BenchmarkResources resources; - // NCCL eager baseline resources. - if (impl == DistributedMatmulImpl::baselinePytorchEagerNccl) { - if (!communicator->isBackendAvailable(CommunicatorBackend::kNccl)) { - return resources; - } - resources.nccl_backend = - communicator->getBackendForTeam(all_devices, CommunicatorBackend::kNccl); - } +// ========================================================================= +// Resource helpers +// ========================================================================= - // CUDA backend eager baseline resources (symmetric allgather handle). - if (impl == DistributedMatmulImpl::baselinePytorchEagerCuda) { - resources.cuda_hic = std::make_unique(); - FusionGuard fg(resources.cuda_hic.get()); - auto* in_tv = makeContigTensor(2); - auto* out_tv = makeContigTensor(2); - DeviceMesh mesh = DeviceMesh::createForNumDevices(world_size); - in_tv->setDeviceMesh(mesh); - out_tv->setDeviceMesh(mesh); - resources.cuda_allgather_communication = IrBuilder::create( - CommunicationType::Allgather, - out_tv, - in_tv, - all_devices, - /*root=*/-1, - RedOpType::UNUSED, - CommunicatorBackend::kCuda); - resources.a_allgathered_half_cuda = SymmetricTensor::allocate( - {m, k}, at::ScalarType::Half, communicator->device()); - resources.cuda_allgather_handle = std::make_unique( - resources.cuda_allgather_communication, resources.a_allgathered_half_cuda); - } +bool needsThreadloadRes(DistributedMatmulImpl impl) { + using I = DistributedMatmulImpl; + return impl == I::threadloadGatherScalarCompute || + impl == I::threadloadGatherCutlassCompute; +} - if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || - impl == DistributedMatmulImpl::naiveFusedKernelCutlassCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { - resources.a_gathered_threadload = at::empty( - {m, k}, - at::TensorOptions() - .dtype(at::kHalf) - .device(communicator->device()) - .layout(at::kStrided)); - } +bool needsMultimemRes(DistributedMatmulImpl impl) { + using I = DistributedMatmulImpl; + return impl == I::multimemGatherScalarCompute || + impl == I::multimemGatherCutlassCompute; +} - if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { - // Per-rank [writer_rank, row, vec4-int] semaphores for fused ready/done. - resources.threadload_ready_semaphore = SymmetricTensor::allocate( - {world_size, m, 4}, at::ScalarType::Int, communicator->device()); - resources.threadload_ready_semaphore.zero_(); - resources.threadload_ready_semaphore_sym = std::make_unique( - resources.threadload_ready_semaphore); - resources.threadload_ready_semaphore_sym->setupRemoteHandles( - "fused_remote_matmul_threadload_ready"); +bool needsCutlass(DistributedMatmulImpl impl) { + using I = DistributedMatmulImpl; + return impl == I::threadloadGatherCutlassCompute || + impl == I::multimemGatherCutlassCompute; +} - resources.threadload_done_semaphore = SymmetricTensor::allocate( - {world_size, m, 4}, at::ScalarType::Int, communicator->device()); - resources.threadload_done_semaphore.zero_(); - resources.threadload_done_semaphore_sym = std::make_unique( - resources.threadload_done_semaphore); - resources.threadload_done_semaphore_sym->setupRemoteHandles( - "fused_remote_matmul_threadload_done"); - } +struct OwnedResources { + std::unique_ptr a_sym; + std::unique_ptr ready_sym; + std::unique_ptr done_sym; + std::unique_ptr stage_sym; + std::unique_ptr multimem_sym; + std::unique_ptr cuda_hic; + std::unique_ptr cuda_ag_handle; + at::Tensor ready_t, done_t, stage_t; +}; - if (impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute) { - resources.async_chunk_signals = at::zeros( - {world_size}, - at::TensorOptions().dtype(at::kInt).device(communicator->device())); +void initResources( + DistributedMatmulImpl impl, + Communicator* comm, + const Team& team, + int64_t ws, int64_t m, int64_t k, + OwnedResources& res, + DistributedMatmulContext& ctx) { + using I = DistributedMatmulImpl; + auto dev = comm->device(); + + if (impl == I::baselineNcclAllgatherMatmul) { + if (comm->isBackendAvailable(CommunicatorBackend::kNccl)) + ctx.nccl_backend = comm->getBackendForTeam( + team, CommunicatorBackend::kNccl); } - if (impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma) { - resources.a_gathered_multimem = SymmetricTensor::allocate( - {m, k}, at::ScalarType::Half, communicator->device()); - resources.a_gathered_multimem_sym = - std::make_unique(resources.a_gathered_multimem); - resources.a_gathered_multimem_sym->setupMulticast( - /*exporter_rank=*/0, "fused_remote_matmul_staged_multimem"); + if (impl == I::baselineCudaAllgatherMatmul) { + res.cuda_hic = + std::make_unique(); + FusionGuard fg(res.cuda_hic.get()); + auto* itv = makeContigTensor(2); + auto* otv = makeContigTensor(2); + auto mesh = DeviceMesh::createForNumDevices(ws); + itv->setDeviceMesh(mesh); + otv->setDeviceMesh(mesh); + auto* cir = IrBuilder::create( + CommunicationType::Allgather, otv, itv, + team, -1, RedOpType::UNUSED, + CommunicatorBackend::kCuda); + ctx.a_allgathered_cuda = SymmetricTensor::allocate( + {m, k}, at::ScalarType::Half, dev); + res.cuda_ag_handle = + std::make_unique( + cir, ctx.a_allgathered_cuda); + ctx.cuda_comm = cir; + ctx.cuda_handle = res.cuda_ag_handle.get(); + } - // Per-rank semaphore rows used by the fused multimem kernel barrier. - // Shape is [writer_rank, row, vec4-int] so each writer can publish one - // epoch per row, and each reader can wait on all writers for that row. - resources.stage_semaphore_multimem = SymmetricTensor::allocate( - {world_size, m, 4}, at::ScalarType::Int, communicator->device()); - resources.stage_semaphore_multimem.zero_(); - resources.stage_semaphore_multimem_sym = - std::make_unique(resources.stage_semaphore_multimem); - resources.stage_semaphore_multimem_sym->setupRemoteHandles( - "fused_remote_matmul_stage_semaphore"); + if (needsThreadloadRes(impl)) { + ctx.a_gathered = at::empty( + {m, k}, + at::TensorOptions().dtype(at::kHalf).device(dev)); + + auto make_sem = [&](const char* tag) + -> std::pair> { + at::Tensor t = SymmetricTensor::allocate( + {ws, m, 4}, at::ScalarType::Int, dev); + t.zero_(); + auto s = std::make_unique(t); + s->setupRemoteHandles(tag); + return {t, std::move(s)}; + }; + + auto [rt, rs] = make_sem("fused_matmul_ready"); + res.ready_t = rt; + res.ready_sym = std::move(rs); + ctx.ready_sem_remote = + reinterpret_cast( + res.ready_sym->devicePeerPointers()); + ctx.ready_sem_local = reinterpret_cast( + res.ready_sym->localTensor().data_ptr()); + + auto [dt, ds] = make_sem("fused_matmul_done"); + res.done_t = dt; + res.done_sym = std::move(ds); + ctx.done_sem_remote = + reinterpret_cast( + res.done_sym->devicePeerPointers()); + ctx.done_sem_local = reinterpret_cast( + res.done_sym->localTensor().data_ptr()); } - return resources; -} -double reduceMaxTimeMs(Communicator* communicator, double local_ms_per_iter) { - // Reduce per-rank timing with MAX so throughput reflects slowest rank. - at::Tensor max_time_tensor = at::tensor( - {static_cast(local_ms_per_iter)}, - at::TensorOptions().dtype(at::kFloat).device(communicator->device())); - std::vector time_tensors = {max_time_tensor}; - communicator->getWorld()->allreduce(time_tensors, {c10d::ReduceOp::MAX})->wait(); - return static_cast(max_time_tensor.item()); + if (needsMultimemRes(impl)) { + ctx.a_gathered_multimem = SymmetricTensor::allocate( + {m, k}, at::ScalarType::Half, dev); + res.multimem_sym = std::make_unique( + ctx.a_gathered_multimem); + res.multimem_sym->setupMulticast( + 0, "fused_matmul_mc"); + ctx.multicast_ptr = reinterpret_cast<__half*>( + res.multimem_sym->multicastPtr()); + + res.stage_t = SymmetricTensor::allocate( + {ws, m, 4}, at::ScalarType::Int, dev); + res.stage_t.zero_(); + res.stage_sym = + std::make_unique(res.stage_t); + res.stage_sym->setupRemoteHandles( + "fused_matmul_stage"); + ctx.stage_sem_remote = + reinterpret_cast( + res.stage_sym->devicePeerPointers()); + ctx.stage_sem_local = reinterpret_cast( + res.stage_sym->localTensor().data_ptr()); + } } -double timeBaselinePytorchEagerMs( - c10d::Backend* backend, - at::Tensor& a_local_half, - const at::Tensor& b_full_half, - at::Tensor& c_out_half, - const BenchmarkConfig& config, - Communicator* communicator, - cudaStream_t stream) { - at::Tensor a_allgathered_half = at::empty( - {a_local_half.size(0) * backend->getSize(), a_local_half.size(1)}, - a_local_half.options()); - auto run_once = [&]() { - backend->_allgather_base(a_allgathered_half, a_local_half)->wait(); - at::matmul_out(c_out_half, a_allgathered_half, b_full_half); - }; - return benchmarkLoopMs(config, communicator, stream, run_once); +double reduceMaxTimeMs( + Communicator* comm, double local_ms) { + at::Tensor t = at::tensor( + {static_cast(local_ms)}, + at::TensorOptions() + .dtype(at::kFloat) + .device(comm->device())); + std::vector tv = {t}; + comm->getWorld() + ->allreduce(tv, {c10d::ReduceOp::MAX}) + ->wait(); + return static_cast(t.item()); } -double timeBaselinePytorchEagerCudaMs( - Communication* communication, - SymMemForAllgather* allgather_handle, - at::Tensor& a_local_half, - at::Tensor& a_allgathered_half, - const at::Tensor& b_full_half, - at::Tensor& c_out_half, - const BenchmarkConfig& config, - Communicator* communicator, - cudaStream_t stream) { - auto run_once = [&]() { - postWithCudaBackend( - communication, - a_local_half, - allgather_handle, - (CUstream)stream, - /*root=*/-1); - waitWithCudaBackend( - communication, - allgather_handle, - (CUstream)stream, - /*root=*/-1); - at::matmul_out(c_out_half, a_allgathered_half, b_full_half); - }; - return benchmarkLoopMs(config, communicator, stream, run_once); -} +// ========================================================================= +// Implementation dispatcher +// +// Each case builds a run_once lambda and wraps it with +// benchmarkLoopMs. Kernel launchers live in the .cu file. +// ========================================================================= -double timeImplementationMs( +double runImplementation( DistributedMatmulImpl impl, - const BenchmarkConfig& config, - Communicator* communicator, - c10d::Backend* nccl_backend, - Communication* cuda_allgather_communication, - SymMemForAllgather* cuda_allgather_handle, - const __half* const* device_remote_ptrs, - at::Tensor& a_local_half, - at::Tensor& a_allgathered_half_cuda, - at::Tensor& a_gathered_threadload, - at::Tensor& a_gathered_multimem, - at::Tensor& async_chunk_signals, - SymmetricTensor* threadload_ready_semaphore_sym, - SymmetricTensor* threadload_done_semaphore_sym, - SymmetricTensor* a_gathered_multimem_sym, - SymmetricTensor* stage_semaphore_multimem_sym, - int32_t* const* threadload_ready_semaphore_remote_ptrs, - int32_t* const* threadload_done_semaphore_remote_ptrs, - int32_t* const* stage_semaphore_remote_ptrs, - const at::Tensor& b_full_half, - at::Tensor& c_out_half, - int64_t my_rank, - int64_t world_size, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - cudaStream_t stream) { - // Dispatch to implementation-specific execution path. - (void)a_gathered_multimem; - const int64_t naive_block_x = runtimeParam(RuntimeParams::NaiveBlockX); - const int64_t naive_block_y = runtimeParam(RuntimeParams::NaiveBlockY); - const int64_t staged_block_threads = - runtimeParam(RuntimeParams::StagedBlockThreads); - const int64_t staged_grid_blocks = - runtimeParam(RuntimeParams::StagedGridBlocks); + DistributedMatmulContext& ctx, + const BenchmarkConfig& config) { + using I = DistributedMatmulImpl; switch (impl) { - case DistributedMatmulImpl::naiveFusedKernel: { - auto run_once = [&]() { - timeFusedRemoteMatmulMs( - device_remote_ptrs, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - world_size, - m, - n, - k, - m_per_rank, - naive_block_x, - naive_block_y, - /*warmup_iters=*/0, - /*iters=*/1, - stream); + case I::baselineNcclAllgatherMatmul: { + at::Tensor a_full = at::empty( + {ctx.m, ctx.k}, ctx.a_local_half.options()); + auto run = [&]() { + ctx.nccl_backend + ->_allgather_base(a_full, ctx.a_local_half) + ->wait(); + at::matmul_out( + ctx.c_out_half, a_full, ctx.b_full_half); + }; + return benchmarkLoopMs( + config, ctx.communicator, ctx.stream, run); + } + case I::baselineCudaAllgatherMatmul: { + auto run = [&]() { + postWithCudaBackend( + ctx.cuda_comm, ctx.a_local_half, + ctx.cuda_handle, + (CUstream)ctx.stream, -1); + waitWithCudaBackend( + ctx.cuda_comm, ctx.cuda_handle, + (CUstream)ctx.stream, -1); + at::matmul_out( + ctx.c_out_half, ctx.a_allgathered_cuda, + ctx.b_full_half); + }; + return benchmarkLoopMs( + config, ctx.communicator, ctx.stream, run); + } + case I::naiveRemoteRead: { + auto run = [&]() { + launchNaiveRemoteRead(ctx); + }; + return benchmarkLoopMs( + config, ctx.communicator, ctx.stream, run); + } + case I::threadloadGatherScalarCompute: { + int64_t epoch = 0; + auto run = [&]() { + launchThreadloadGather( + ctx, static_cast(epoch), true); + ++epoch; }; - return benchmarkLoopMs(config, communicator, stream, run_once); + return benchmarkLoopMs( + config, ctx.communicator, ctx.stream, run); + } + case I::threadloadGatherCutlassCompute: { + int64_t epoch = 0; + auto run = [&]() { + launchThreadloadGather( + ctx, static_cast(epoch), false); + ctx.c_out_half = + matmulTma(ctx.a_gathered, ctx.b_full_half); + ++epoch; + }; + return benchmarkLoopMs( + config, ctx.communicator, ctx.stream, run); + } + case I::multimemGatherScalarCompute: { + int64_t epoch = 0; + auto run = [&]() { + launchMultimemGather( + ctx, static_cast(epoch), true); + ++epoch; + }; + return benchmarkLoopMs( + config, ctx.communicator, ctx.stream, run); + } + case I::multimemGatherCutlassCompute: { + int64_t epoch = 0; + auto run = [&]() { + launchMultimemGather( + ctx, static_cast(epoch), false); + ctx.c_out_half = matmulTma( + ctx.a_gathered_multimem, ctx.b_full_half); + ++epoch; + }; + return benchmarkLoopMs( + config, ctx.communicator, ctx.stream, run); } - case DistributedMatmulImpl::naiveFusedKernelCutlassCompute: - return timeNaiveRemoteMatmulCutlassMs( - device_remote_ptrs, - a_gathered_threadload, - b_full_half, - c_out_half, - m, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::naiveFusedKernelFusedTma: - return timeNaiveRemoteMatmulFusedTmaMs( - device_remote_ptrs, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - m, - n, - k, - m_per_rank, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::baselinePytorchEagerNccl: - NVF_CHECK( - nccl_backend != nullptr, - "baselinePytorchEagerNccl requires a valid NCCL process group backend."); - return timeBaselinePytorchEagerMs( - nccl_backend, - a_local_half, - b_full_half, - c_out_half, - config, - communicator, - stream); - case DistributedMatmulImpl::baselinePytorchEagerCuda: - NVF_CHECK( - cuda_allgather_communication != nullptr && cuda_allgather_handle != nullptr, - "baselinePytorchEagerCuda requires initialized CUDA allgather resources."); - return timeBaselinePytorchEagerCudaMs( - cuda_allgather_communication, - cuda_allgather_handle, - a_local_half, - a_allgathered_half_cuda, - b_full_half, - c_out_half, - config, - communicator, - stream); - case DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute: - NVF_CHECK( - threadload_ready_semaphore_sym != nullptr && - threadload_done_semaphore_sym != nullptr && - threadload_ready_semaphore_remote_ptrs != nullptr && - threadload_done_semaphore_remote_ptrs != nullptr, - "gpuAllgatherFusedNaiveCompute requires semaphore resources."); - return timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( - device_remote_ptrs, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - reinterpret_cast<__half*>(a_gathered_threadload.data_ptr()), - threadload_ready_semaphore_remote_ptrs, - reinterpret_cast( - threadload_ready_semaphore_sym->localTensor().data_ptr()), - threadload_done_semaphore_remote_ptrs, - reinterpret_cast( - threadload_done_semaphore_sym->localTensor().data_ptr()), - my_rank, - world_size, - m, - n, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuAllgatherTmaCompute: - NVF_CHECK( - threadload_ready_semaphore_sym != nullptr && - threadload_done_semaphore_sym != nullptr && - threadload_ready_semaphore_remote_ptrs != nullptr && - threadload_done_semaphore_remote_ptrs != nullptr, - "gpuAllgatherTmaCompute requires semaphore resources."); - return timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( - device_remote_ptrs, - a_gathered_threadload, - threadload_ready_semaphore_remote_ptrs, - reinterpret_cast( - threadload_ready_semaphore_sym->localTensor().data_ptr()), - threadload_done_semaphore_remote_ptrs, - reinterpret_cast( - threadload_done_semaphore_sym->localTensor().data_ptr()), - my_rank, - world_size, - b_full_half, - c_out_half, - m, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute: - NVF_CHECK( - threadload_ready_semaphore_sym != nullptr && - threadload_done_semaphore_sym != nullptr && - threadload_ready_semaphore_remote_ptrs != nullptr && - threadload_done_semaphore_remote_ptrs != nullptr, - "gpuAllgatherAsyncSchedulerTmaCompute requires semaphore resources."); - NVF_CHECK( - async_chunk_signals.defined(), - "gpuAllgatherAsyncSchedulerTmaCompute requires chunk signals."); - return timeSeparatedAllgatherMatmulThreadLoadSynchronizedAsyncCutlassMs( - device_remote_ptrs, - a_gathered_threadload, - threadload_ready_semaphore_remote_ptrs, - reinterpret_cast( - threadload_ready_semaphore_sym->localTensor().data_ptr()), - threadload_done_semaphore_remote_ptrs, - reinterpret_cast( - threadload_done_semaphore_sym->localTensor().data_ptr()), - my_rank, - world_size, - b_full_half, - c_out_half, - async_chunk_signals, - /*a_chunk_pivot=*/my_rank, - m, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuAllgatherFusedTma: - NVF_CHECK( - threadload_ready_semaphore_sym != nullptr && - threadload_done_semaphore_sym != nullptr && - threadload_ready_semaphore_remote_ptrs != nullptr && - threadload_done_semaphore_remote_ptrs != nullptr, - "gpuAllgatherFusedTma requires semaphore resources."); - return timeSeparatedAllgatherMatmulThreadLoadSynchronizedFusedTmaMs( - device_remote_ptrs, - threadload_ready_semaphore_remote_ptrs, - reinterpret_cast( - threadload_ready_semaphore_sym->localTensor().data_ptr()), - threadload_done_semaphore_remote_ptrs, - reinterpret_cast( - threadload_done_semaphore_sym->localTensor().data_ptr()), - my_rank, - world_size, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - m, - n, - k, - m_per_rank, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute: - NVF_CHECK( - a_gathered_multimem_sym != nullptr && - stage_semaphore_multimem_sym != nullptr && - stage_semaphore_remote_ptrs != nullptr, - "gpuAllgatherMultimemTmaCompute requires staging and " - "semaphore tensors."); - return timeSeparatedAllgatherMatmulMultimemCutlassMs( - device_remote_ptrs, - reinterpret_cast<__half*>(a_gathered_multimem_sym->multicastPtr()), - a_gathered_multimem, - stage_semaphore_remote_ptrs, - reinterpret_cast( - stage_semaphore_multimem_sym->localTensor().data_ptr()), - my_rank, - world_size, - b_full_half, - c_out_half, - m, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuAllgatherMultimemFusedTma: - NVF_CHECK( - a_gathered_multimem_sym != nullptr && - stage_semaphore_multimem_sym != nullptr && - stage_semaphore_remote_ptrs != nullptr, - "gpuAllgatherMultimemFusedTma requires staging and semaphore tensors."); - return timeSeparatedAllgatherMatmulMultimemFusedTmaMs( - device_remote_ptrs, - reinterpret_cast<__half*>(a_gathered_multimem_sym->multicastPtr()), - stage_semaphore_remote_ptrs, - reinterpret_cast( - stage_semaphore_multimem_sym->localTensor().data_ptr()), - my_rank, - world_size, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - m, - n, - k, - m_per_rank, - config.warmup_iters, - config.iters, - stream); - case DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute: - NVF_CHECK( - a_gathered_multimem_sym != nullptr && - stage_semaphore_multimem_sym != nullptr, - "gpuAllgatherMultimemFusedNaiveCompute requires staging and semaphore tensors."); - return timeSeparatedAllgatherMatmulMultimemMs( - device_remote_ptrs, - reinterpret_cast(b_full_half.data_ptr()), - reinterpret_cast<__half*>(c_out_half.data_ptr()), - reinterpret_cast<__half*>(a_gathered_multimem_sym->multicastPtr()), - stage_semaphore_remote_ptrs, - reinterpret_cast( - stage_semaphore_multimem_sym->localTensor().data_ptr()), - my_rank, - world_size, - m, - n, - k, - m_per_rank, - staged_block_threads, - staged_grid_blocks, - config.warmup_iters, - config.iters, - stream); } - NVF_ERROR(false, "Unsupported implementation enum: ", static_cast(impl)); + NVF_ERROR(false, "Unknown implementation."); } -} // namespace +} // anonymous namespace + +// ========================================================================= +// Test fixture +// ========================================================================= -class FusedRemoteMatmulTest : public MultiDeviceTest, - public testing::WithParamInterface< - DistributedMatmulImpl> { +class FusedRemoteMatmulTest + : public MultiDeviceTest, + public testing::WithParamInterface< + DistributedMatmulImpl> { protected: - static constexpr BenchmarkConfig kBenchmarkConfig = { + static constexpr BenchmarkConfig kConfig = { /*warmup_iters=*/8, /*iters=*/30, /*time_mode=*/TimeMeasurementMode::CpuClock, /*barrier_at_each_iteration=*/false}; - static constexpr BenchmarkConfig kCorrectnessConfig = { - /*warmup_iters=*/3, - /*iters=*/1, - /*time_mode=*/TimeMeasurementMode::CpuClock, - /*barrier_at_each_iteration=*/false}; }; -// Benchmark context: -// - A is sharded on M across ranks, B is replicated. -// - We compare three execution paths under identical setup/validation: -// fused remote-pointer kernel, NCCL allgather+eager matmul, and CUDA-backend -// allgather+eager matmul. -// - Rank 0 reports throughput using MAX latency reduced across ranks. -TEST_P(FusedRemoteMatmulTest, DistributedMatmulRemotePointerFused) { - // ---------- Preconditions ---------- - if (!communicator_->is_available()) { - GTEST_SKIP() << "Communicator is unavailable."; - } - if (communicator_->size() == 1) { - GTEST_SKIP() << "Needs at least 2 devices."; - } +TEST_P(FusedRemoteMatmulTest, DistributedMatmul) { + if (!communicator_->is_available()) + GTEST_SKIP() << "Communicator unavailable."; + if (communicator_->size() == 1) + GTEST_SKIP() << "Needs >= 2 devices."; - const int64_t world_size = communicator_->size(); - const int64_t my_rank = communicator_->deviceId(); + const int64_t ws = communicator_->size(); + const int64_t rank = communicator_->deviceId(); const auto impl = GetParam(); - if ((impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma) && - !isMulticastSupported(my_rank)) { - GTEST_SKIP() << "Multicast is not supported on this device."; - } - - // ---------- Problem shape ---------- - Team all_devices(world_size); - std::iota(all_devices.begin(), all_devices.end(), 0); - - constexpr int64_t m = 1024; - constexpr int64_t k = 1024; - constexpr int64_t n = 1024; - NVF_ERROR(m % world_size == 0, "M must be divisible by world size."); - const int64_t m_per_rank = m / world_size; + if (needsMultimemRes(impl) && + !isMulticastSupported(rank)) + GTEST_SKIP() << "Multicast unsupported."; - // ---------- Inputs ---------- - const auto cpu_float_opts = - at::TensorOptions().dtype(at::kFloat).device(at::kCPU); - const auto gpu_half_opts = - at::TensorOptions().dtype(at::kHalf).device(communicator_->device()); + // ---- Problem shape ---- + constexpr int64_t m = 1024, k = 1024, n = 1024; + NVF_ERROR(m % ws == 0); + const int64_t mpr = m / ws; + Team team(ws); + std::iota(team.begin(), team.end(), 0); - // Deterministic inputs on every rank for fair cross-impl comparison. + // ---- Inputs ---- at::manual_seed(0); - at::Tensor a_full_cpu = at::randn({m, k}, cpu_float_opts); - at::Tensor b_full_cpu = at::randn({k, n}, cpu_float_opts); - - at::Tensor a_local_half = a_full_cpu - .slice(0, my_rank * m_per_rank, (my_rank + 1) * m_per_rank) - .to(gpu_half_opts.device(), at::kHalf); - at::Tensor b_full_half = b_full_cpu.to(gpu_half_opts.device(), at::kHalf); - - at::Tensor a_local_sym = SymmetricTensor::allocate( - {m_per_rank, k}, at::ScalarType::Half, communicator_->device()); - a_local_sym.copy_(a_local_half); - SymmetricTensor symmetric_a(a_local_sym); - symmetric_a.setupRemoteHandles("fused_remote_matmul_a"); - - const __half* const* device_remote_ptrs = - reinterpret_cast(symmetric_a.devicePeerPointers()); - - // ---------- Outputs and stream ---------- - at::Tensor c_out_half = at::zeros({m, n}, gpu_half_opts); - at::Tensor a_allgathered_half_cuda; - at::Tensor a_gathered_threadload; - at::Tensor a_gathered_multimem; - at::Tensor async_chunk_signals; - c10::cuda::CUDAStream test_stream = c10::cuda::getStreamFromPool( - /*isHighPriority=*/false, - static_cast(communicator_->device().index())); - c10::cuda::CUDAStreamGuard stream_guard(test_stream); - cudaStream_t stream = test_stream.stream(); - - auto resources = - initBenchmarkResources(impl, communicator_, all_devices, world_size, m, k); - if (impl == DistributedMatmulImpl::baselinePytorchEagerNccl && - resources.nccl_backend == nullptr) { - GTEST_SKIP() << "NCCL backend unavailable for baselinePytorchEagerNccl."; + auto cpu_f = at::TensorOptions().dtype(at::kFloat); + auto gpu_h = at::TensorOptions() + .dtype(at::kHalf) + .device(communicator_->device()); + at::Tensor a_full = at::randn({m, k}, cpu_f); + at::Tensor b_full = at::randn({k, n}, cpu_f); + at::Tensor a_local = + a_full.slice(0, rank * mpr, (rank + 1) * mpr) + .to(gpu_h.device(), at::kHalf); + at::Tensor b_gpu = b_full.to(gpu_h.device(), at::kHalf); + + at::Tensor a_sym = SymmetricTensor::allocate( + {mpr, k}, at::ScalarType::Half, + communicator_->device()); + a_sym.copy_(a_local); + OwnedResources res; + res.a_sym = std::make_unique(a_sym); + res.a_sym->setupRemoteHandles("fused_matmul_a"); + + // ---- Build context ---- + DistributedMatmulContext ctx; + ctx.m = m; + ctx.n = n; + ctx.k = k; + ctx.m_per_rank = mpr; + ctx.my_rank = rank; + ctx.world_size = ws; + ctx.device_remote_ptrs = + reinterpret_cast( + res.a_sym->devicePeerPointers()); + ctx.a_local_half = a_local; + ctx.b_full_half = b_gpu; + ctx.c_out_half = at::zeros({m, n}, gpu_h); + ctx.communicator = communicator_; + + c10::cuda::CUDAStream test_stream = + c10::cuda::getStreamFromPool( + false, + static_cast( + communicator_->device().index())); + c10::cuda::CUDAStreamGuard guard(test_stream); + ctx.stream = test_stream.stream(); + + initResources( + impl, communicator_, team, ws, m, k, res, ctx); + + // ---- Capability gates ---- + if (impl == + DistributedMatmulImpl:: + baselineNcclAllgatherMatmul && + ctx.nccl_backend == nullptr) + GTEST_SKIP() << "NCCL backend unavailable."; + + if (needsCutlass(impl)) { + at::Tensor ref = ctx.a_gathered.defined() + ? ctx.a_gathered + : ctx.a_gathered_multimem; + if (!canRunCutlassCompute(ref, b_gpu)) + GTEST_SKIP() << "CUTLASS needs Hopper SM90."; } - a_allgathered_half_cuda = resources.a_allgathered_half_cuda; - a_gathered_threadload = resources.a_gathered_threadload; - a_gathered_multimem = resources.a_gathered_multimem; - async_chunk_signals = resources.async_chunk_signals; - int32_t* const* device_stage_semaphore_remote_ptrs = nullptr; - int32_t* const* device_threadload_ready_semaphore_remote_ptrs = nullptr; - int32_t* const* device_threadload_done_semaphore_remote_ptrs = nullptr; - if (impl == DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherFusedTma) { - NVF_CHECK( - resources.threadload_ready_semaphore_sym != nullptr && - resources.threadload_done_semaphore_sym != nullptr, - "Missing synchronized threadload semaphore resources."); - device_threadload_ready_semaphore_remote_ptrs = - reinterpret_cast( - resources.threadload_ready_semaphore_sym->devicePeerPointers()); - device_threadload_done_semaphore_remote_ptrs = - reinterpret_cast( - resources.threadload_done_semaphore_sym->devicePeerPointers()); - } - - if (impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma) { - NVF_CHECK( - resources.stage_semaphore_multimem_sym != nullptr, - "Missing staged multimem semaphore resources."); - device_stage_semaphore_remote_ptrs = - reinterpret_cast( - resources.stage_semaphore_multimem_sym->devicePeerPointers()); - } - - const bool needs_cutlass_compute = impl == - DistributedMatmulImpl::naiveFusedKernelCutlassCompute || - impl == DistributedMatmulImpl::gpuAllgatherTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute || - impl == DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute; - const bool needs_fused_tma_compute = impl == - DistributedMatmulImpl::naiveFusedKernelFusedTma || - impl == DistributedMatmulImpl::gpuAllgatherFusedTma || - impl == DistributedMatmulImpl::gpuAllgatherMultimemFusedTma; - if (needs_cutlass_compute && - !(impl == DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute - ? canRunAsyncSchedulerCompute( - a_gathered_threadload, b_full_half, async_chunk_signals) - : canRunHopperCutlassCompute( - a_gathered_threadload.defined() ? a_gathered_threadload - : a_gathered_multimem, - b_full_half))) { - GTEST_SKIP() - << "CUTLASS-compute variants require Hopper SM90 with TMA support."; - } - if (needs_fused_tma_compute && - !canRunHopperCutlassCompute(a_local_half, b_full_half)) { - GTEST_SKIP() << "FusedTma variants require Hopper SM90 support."; - } - - auto run_implementation = [&](const BenchmarkConfig& config) { - return timeImplementationMs( - impl, - config, - communicator_, - resources.nccl_backend, - resources.cuda_allgather_communication, - resources.cuda_allgather_handle.get(), - device_remote_ptrs, - a_local_half, - a_allgathered_half_cuda, - a_gathered_threadload, - a_gathered_multimem, - async_chunk_signals, - resources.threadload_ready_semaphore_sym.get(), - resources.threadload_done_semaphore_sym.get(), - resources.a_gathered_multimem_sym.get(), - resources.stage_semaphore_multimem_sym.get(), - device_threadload_ready_semaphore_remote_ptrs, - device_threadload_done_semaphore_remote_ptrs, - device_stage_semaphore_remote_ptrs, - b_full_half, - c_out_half, - my_rank, - world_size, - m, - n, - k, - m_per_rank, - stream); - }; - - // ---------- Correctness ---------- - // Run once before validation to execute the selected implementation path. - (void)run_implementation(kCorrectnessConfig); - - at::Tensor c_ref_cpu = at::matmul(a_full_cpu, b_full_cpu); - at::Tensor c_out_cpu = c_out_half.cpu().to(at::kFloat); - EXPECT_TRUE(c_out_cpu.allclose(c_ref_cpu, 2e-1, 2e-1)) - << "Fused remote-pointer matmul output mismatch."; - - // ---------- Benchmark ---------- + // ---- Correctness (1 iteration, no warmup) ---- + (void)runImplementation( + impl, ctx, + {0, 1, TimeMeasurementMode::CpuClock, false}); + at::Tensor c_ref = at::matmul(a_full, b_full); + EXPECT_TRUE( + ctx.c_out_half.cpu().to(at::kFloat).allclose( + c_ref, 2e-1, 2e-1)) + << "Mismatch for " << implName(impl); + + // ---- Benchmark ---- communicator_->barrier(); - const double local_ms_per_iter = run_implementation(kBenchmarkConfig); + double local_ms = + runImplementation(impl, ctx, kConfig); communicator_->barrier(); - // Distributed throughput is constrained by the slowest rank. - const double global_ms_per_iter = - reduceMaxTimeMs(communicator_, local_ms_per_iter); - - // ---------- Reporting ---------- - const double flops = 2.0 * static_cast(m) * static_cast(n) * - static_cast(k); - const double tflops = flops / (global_ms_per_iter * 1.0e9); - if (my_rank == 0) { - std::cout << "[perf] fused_remote_matmul impl=" << implName(impl) + double global_ms = + reduceMaxTimeMs(communicator_, local_ms); + + // ---- Report ---- + double tflops = + 2.0 * m * n * k / (global_ms * 1e9); + if (rank == 0) { + std::cout << "[perf] fused_remote_matmul" + << " impl=" << implName(impl) << " M=" << m << " N=" << n << " K=" << k - << " world_size=" << world_size << " : " << global_ms_per_iter - << " ms/iter, " << tflops << " TFLOP/s" << std::endl; + << " world_size=" << ws << " : " + << global_ms << " ms/iter, " << tflops + << " TFLOP/s" << std::endl; } - } INSTANTIATE_TEST_SUITE_P( , FusedRemoteMatmulTest, testing::Values( - DistributedMatmulImpl::naiveFusedKernel, - DistributedMatmulImpl::naiveFusedKernelCutlassCompute, - DistributedMatmulImpl::naiveFusedKernelFusedTma, - DistributedMatmulImpl::baselinePytorchEagerNccl, - DistributedMatmulImpl::baselinePytorchEagerCuda, - DistributedMatmulImpl::gpuAllgatherFusedNaiveCompute, - DistributedMatmulImpl::gpuAllgatherTmaCompute, - DistributedMatmulImpl::gpuAllgatherAsyncSchedulerTmaCompute, - DistributedMatmulImpl::gpuAllgatherFusedTma, - DistributedMatmulImpl::gpuAllgatherMultimemTmaCompute, - DistributedMatmulImpl::gpuAllgatherMultimemFusedTma, - DistributedMatmulImpl::gpuAllgatherMultimemFusedNaiveCompute), - [](const testing::TestParamInfo& info) { + DistributedMatmulImpl::baselineNcclAllgatherMatmul, + DistributedMatmulImpl::baselineCudaAllgatherMatmul, + DistributedMatmulImpl::naiveRemoteRead, + DistributedMatmulImpl::threadloadGatherScalarCompute, + DistributedMatmulImpl::multimemGatherScalarCompute, + DistributedMatmulImpl::threadloadGatherCutlassCompute, + DistributedMatmulImpl::multimemGatherCutlassCompute), + [](const testing::TestParamInfo< + DistributedMatmulImpl>& info) { return implName(info.param); }); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index be7da99bc5c..d916aba7e1e 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -1,1206 +1,567 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +// +// ========================================================================= +// Distributed Matmul Kernels +// +// This file contains CUDA kernels and host-side launcher functions for +// the distributed matmul benchmark. Each kernel combines a +// communication strategy with a compute strategy: +// +// Communication strategies: +// - Naive remote read: each thread reads A directly from the owner +// rank via remote pointers. +// - Threadload gather: cooperative thread loads stage A rows into a +// local buffer, synchronized via ready/done semaphores. +// - Multimem gather: owner rank writes A rows to a multicast buffer +// using multimem.st (Hopper SM90+), synchronized via semaphores. +// +// Compute strategies: +// - Scalar: each thread accumulates one output element. +// - CUTLASS TMA: host-launched Hopper GEMM (in the .cpp file, not +// here -- the kernel is launched with n=0 to skip in-kernel compute). +// +// The CUTLASS TMA matmul wrapper (matmulTma) is also defined here, +// moved from csrc/runtime/matmul_tma.cu for self-containment. +// ========================================================================= + +#include "test_multidevice_fused_remote_matmul.h" #include -#include -#include -#include +#include +#include +#include -#include "cuda_utils.h" -#include "runtime/matmul_tma.h" +#include + +// CUTLASS TMA matmul (Hopper SM90) +#if defined(NVFUSER_ENABLE_CUTLASS) +#if !defined(__CUDACC_VER_MAJOR__) +#define __CUDACC_VER_MAJOR__ 13 +#define __CUDACC_VER_MINOR__ 0 +#endif +#include "cutlass/arch/config.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +#endif namespace nvfuser { namespace { -constexpr int64_t kSemaphoreVecWidth = 4; -constexpr int64_t kMaxSemaphorePollIters = 1LL << 26; -constexpr int64_t kMmaTile = 16; - -__device__ inline void publishEpochToAllRanks( - int32_t* const* remote_semaphore_ptrs, - int32_t* local_semaphore, - int64_t writer_rank, - int64_t row, - int64_t m, - int64_t world_size, - int32_t epoch) { - int32_t* my_row_local = - local_semaphore + (writer_rank * m + row) * kSemaphoreVecWidth; - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - my_row_local[vec_i] = epoch; - } - __threadfence_system(); - for (int64_t peer = 0; peer < world_size; ++peer) { - int32_t* peer_row_remote = - remote_semaphore_ptrs[peer] + (writer_rank * m + row) * kSemaphoreVecWidth; - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - peer_row_remote[vec_i] = epoch; - } - } - __threadfence_system(); -} +// ========================================================================= +// Section 1: CUTLASS TMA matmul wrapper +// +// Provides matmulTma() -- a Hopper SM90 GEMM using CUTLASS 3.x with +// TMA loads. Moved from csrc/runtime/matmul_tma.cu so the benchmark +// is self-contained. +// ========================================================================= -__device__ inline void publishEpochToRank( - int32_t* remote_semaphore_for_target_rank, - int64_t writer_rank, - int64_t row, - int64_t m, - int32_t epoch) { - int32_t* row_remote = - remote_semaphore_for_target_rank + (writer_rank * m + row) * kSemaphoreVecWidth; - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - row_remote[vec_i] = epoch; - } - __threadfence_system(); +bool hasValidTmaShape( + const at::Tensor& a, + const at::Tensor& b) { + if (!a.defined() || !b.defined()) return false; + if (!a.is_cuda() || !b.is_cuda()) return false; + if (a.dim() != 2 || b.dim() != 2) return false; + if (a.scalar_type() != b.scalar_type()) return false; + if (!(a.scalar_type() == at::ScalarType::Half || + a.scalar_type() == at::ScalarType::BFloat16)) + return false; + if (!a.is_contiguous() || !b.is_contiguous()) return false; + if (a.size(1) != b.size(0)) return false; + if (a.get_device() != b.get_device()) return false; + constexpr int64_t kAlign = 8; + if (a.size(1) % kAlign != 0 || b.size(1) % kAlign != 0) + return false; + return true; } -__device__ inline void setLocalEpoch( - int32_t* local_semaphore, - int64_t writer_rank, - int64_t row, - int64_t m, - int32_t epoch) { - int32_t* row_local = - local_semaphore + (writer_rank * m + row) * kSemaphoreVecWidth; - for (int64_t vec_i = 0; vec_i < kSemaphoreVecWidth; ++vec_i) { - row_local[vec_i] = epoch; - } - __threadfence_system(); +#if defined(NVFUSER_ENABLE_CUTLASS) && \ + defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +using namespace cute; + +template +struct TmaSm90Config { + using EA = ElementT; + using EB = ElementT; + using EC = ElementT; + using ED = ElementT; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int kAA = + 128 / cutlass::sizeof_bits::value; + static constexpr int kAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int kAC = + 128 / cutlass::sizeof_bits::value; + static constexpr int kAD = + 128 / cutlass::sizeof_bits::value; + using Acc = float; + using Arch = cutlass::arch::Sm90; + using Op = cutlass::arch::OpClassTensorOp; + using Tile = Shape<_128, _128, _64>; + using Cluster = Shape<_1, _1, _1>; + using SmTile = Shape<_128, _128, _64>; + + using Epi = + typename cutlass::epilogue::collective::CollectiveBuilder< + Arch, Op, SmTile, Cluster, + cutlass::epilogue::collective::EpilogueTileAuto, + Acc, Acc, EC, LayoutC, kAC, ED, LayoutD, kAD, + cutlass::epilogue::collective:: + EpilogueScheduleAuto>::CollectiveOp; + + using Main = + typename cutlass::gemm::collective::CollectiveBuilder< + Arch, Op, EA, LayoutA, kAA, EB, LayoutB, kAB, + Acc, Tile, Cluster, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast( + sizeof(typename Epi::SharedStorage))>, + cutlass::gemm::collective:: + KernelScheduleAuto>::CollectiveOp; + + using Kernel = cutlass::gemm::kernel::GemmUniversal< + Shape, Main, Epi, void>; + using Gemm = + cutlass::gemm::device::GemmUniversalAdapter; + using SA = typename Gemm::GemmKernel::StrideA; + using SB = typename Gemm::GemmKernel::StrideB; + using SC = typename Gemm::GemmKernel::StrideC; + using SD = typename Gemm::GemmKernel::StrideD; +}; + +template +void runGemmSm90( + at::Tensor& out, + const at::Tensor& a, + const at::Tensor& b, + int64_t m, int64_t n, int64_t k, + cudaStream_t stream) { + using C = TmaSm90Config; + auto sa = cutlass::make_cute_packed_stride( + typename C::SA{}, {(int)m, (int)k, 1}); + auto sb = cutlass::make_cute_packed_stride( + typename C::SB{}, {(int)k, (int)n, 1}); + auto sc = cutlass::make_cute_packed_stride( + typename C::SC{}, {(int)m, (int)n, 1}); + auto sd = cutlass::make_cute_packed_stride( + typename C::SD{}, {(int)m, (int)n, 1}); + typename C::Kernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + {(int)m, (int)n, (int)k, 1}, + {static_cast(a.data_ptr()), sa, + static_cast(b.data_ptr()), sb}, + {{}, nullptr, sc, + static_cast(out.data_ptr()), sd}}; + typename C::Gemm gemm; + size_t ws = C::Gemm::get_workspace_size(args); + auto wt = at::empty( + {(int64_t)ws}, + at::TensorOptions().dtype(at::kByte).device( + a.device())); + NVF_CHECK( + gemm.can_implement(args) == cutlass::Status::kSuccess, + "CUTLASS cannot implement this GEMM."); + NVF_CHECK( + gemm.initialize(args, wt.data_ptr(), stream) == + cutlass::Status::kSuccess, + "CUTLASS init failed."); + NVF_CHECK( + gemm.run(args, wt.data_ptr(), stream, nullptr, true) == + cutlass::Status::kSuccess, + "CUTLASS run failed."); } -__device__ inline void waitForEpochFromRank( - int32_t* local_semaphore, - int64_t row, - int64_t m, - int64_t writer_rank, - int32_t epoch) { - auto* rank_epoch_ptr = reinterpret_cast( - local_semaphore + (writer_rank * m + row) * kSemaphoreVecWidth); - int64_t spins = 0; - while (atomicAdd(rank_epoch_ptr, 0U) < static_cast(epoch)) { - ++spins; - if (spins > kMaxSemaphorePollIters) { - asm volatile("trap;"); - } - } -} +#else -__device__ inline void waitForEpochFromAllRanks( - int32_t* local_semaphore, - int64_t row, - int64_t m, - int64_t world_size, - int32_t epoch) { - for (int64_t rank = 0; rank < world_size; ++rank) { - auto* rank_epoch_ptr = reinterpret_cast( - local_semaphore + (rank * m + row) * kSemaphoreVecWidth); - int64_t spins = 0; - while (atomicAdd(rank_epoch_ptr, 0U) < static_cast(epoch)) { - ++spins; - if (spins > kMaxSemaphorePollIters) { - asm volatile("trap;"); - } - } - } +template +void runGemmSm90( + at::Tensor&, const at::Tensor&, const at::Tensor&, + int64_t, int64_t, int64_t, cudaStream_t) { + NVF_THROW("CUTLASS SM90 support required for TMA matmul."); } -template -double timeKernelLaunchesMs( - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream, - LaunchFn&& launch_once) { - for (int64_t i = 0; i < warmup_iters; ++i) { - launch_once(); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); +#endif - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); - for (int64_t i = 0; i < iters; ++i) { - launch_once(); +// ========================================================================= +// Section 2: Remote semaphore helpers +// +// Device-side helpers for inter-rank synchronization. Each semaphore +// row stores kVecW int32 epochs. Owner publishes; readers poll. +// ========================================================================= + +constexpr int64_t kVecW = 4; +constexpr int64_t kMaxPoll = 1LL << 26; + +__device__ inline void publishToAll( + int32_t* const* remote, int32_t* local, + int64_t writer, int64_t row, int64_t m, + int64_t ws, int32_t epoch) { + int32_t* my = local + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) my[i] = epoch; + __threadfence_system(); + for (int64_t p = 0; p < ws; ++p) { + int32_t* d = remote[p] + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) d[i] = epoch; } - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); - - float total_ms = 0.0f; - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); - return static_cast(total_ms) / static_cast(iters); + __threadfence_system(); } -template -double timeCommThenCutlassMs( - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream, - const at::Tensor& a_comm, - const at::Tensor& b_full, - at::Tensor& c_out, - CommLaunchFn&& launch_comm_once) { - NVF_CHECK( - canRunMatmulTma(a_comm, b_full), - "CUTLASS TMA compute requires Hopper+ and compatible half inputs."); - auto launch_once = [&]() { - launch_comm_once(); - // Rebind output tensor to TMA matmul result (avoids an extra device copy). - c_out = matmulTma(a_comm, b_full); - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +__device__ inline void publishToOne( + int32_t* target, int64_t writer, + int64_t row, int64_t m, int32_t epoch) { + int32_t* d = target + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) d[i] = epoch; + __threadfence_system(); } -void asyncAllgatherMatmulTmaOutLocal( - at::Tensor& out, - const at::Tensor& a, - const at::Tensor& b, - const at::Tensor& a_chunk_signals, - int64_t a_chunk_pivot) { - NVF_CHECK( - canRunMatmulTma(a, b), - "Async CUTLASS compute requires TMA-compatible A/B."); - NVF_CHECK( - a_chunk_signals.defined() && a_chunk_signals.is_cuda() && - a_chunk_signals.scalar_type() == at::ScalarType::Int && - a_chunk_signals.dim() == 1 && a_chunk_signals.numel() > 0, - "Async CUTLASS compute requires CUDA int32 chunk signals."); - NVF_CHECK( - a.size(0) % a_chunk_signals.numel() == 0, - "A rows must be divisible by chunk-signals length."); - NVF_CHECK( - a_chunk_pivot >= 0 && a_chunk_pivot <= a_chunk_signals.numel(), - "Chunk pivot must be in [0, num_chunks]."); - out.copy_(matmulTma(a, b)); +__device__ inline void setLocal( + int32_t* local, int64_t writer, + int64_t row, int64_t m, int32_t epoch) { + int32_t* d = local + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) d[i] = epoch; + __threadfence_system(); } -// Naive fused kernel: -// - A is row-sharded across ranks (axis M) -// - each output row reads from its owner rank shard via remote pointers -// - B is replicated -__global__ void fusedRemoteMatmulKernel( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank) { - const int64_t row = blockIdx.y * blockDim.y + threadIdx.y; - const int64_t col = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= m || col >= n) { - return; - } - - // Map global row to the rank-local row in that rank's A shard. - const int64_t owner_rank = row / m_per_rank; - const int64_t local_row = row - owner_rank * m_per_rank; - const __half* a_local = a_remote_shards[owner_rank]; +__device__ inline void waitOne( + int32_t* local, int64_t row, int64_t m, + int64_t writer, int32_t epoch) { + auto* p = reinterpret_cast( + local + (writer * m + row) * kVecW); + int64_t s = 0; + while (atomicAdd(p, 0U) < (unsigned)epoch) + if (++s > kMaxPoll) asm volatile("trap;"); +} - float acc = 0.0f; - for (int64_t kk = 0; kk < k; ++kk) { - const float a = __half2float(a_local[local_row * k + kk]); - const float b = __half2float(b_full[kk * n + col]); - acc += a * b; +__device__ inline void waitAll( + int32_t* local, int64_t row, int64_t m, + int64_t ws, int32_t epoch) { + for (int64_t r = 0; r < ws; ++r) { + auto* p = reinterpret_cast( + local + (r * m + row) * kVecW); + int64_t s = 0; + while (atomicAdd(p, 0U) < (unsigned)epoch) + if (++s > kMaxPoll) asm volatile("trap;"); } - c_out[row * n + col] = __float2half(acc); } -// Fused kernel with explicit internal stages: -// 1) Allgather stage: materialize one full A row from remote shard. -// 2) Compute stage: matmul for that row across all output columns. -__global__ void fusedStagedThreadLoadKernel( - const __half* const* a_remote_shards, - __half* a_gathered, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - const __half* b_full, - __half* c_out) { - for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { - const int64_t owner_rank = row / m_per_rank; - const int64_t local_row = row - owner_rank * m_per_rank; - const __half* a_local = a_remote_shards[owner_rank]; - - // Stage 1: gather this row into staged global buffer. - for (int64_t kk = threadIdx.x; kk < k; kk += blockDim.x) { - a_gathered[row * k + kk] = a_local[local_row * k + kk]; - } - __syncthreads(); - - // Stage 2: compute this row from staged global A. - for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { - float acc = 0.0f; - for (int64_t kk = 0; kk < k; ++kk) { - acc += __half2float(a_gathered[row * k + kk]) * - __half2float(b_full[kk * n + col]); - } - c_out[row * n + col] = __float2half(acc); - } - __syncthreads(); - } +// ========================================================================= +// Section 3: Kernel definitions +// ========================================================================= + +// --- 3a. naiveRemoteReadKernel --- +// Each thread computes one C[row,col]. A is read directly from the +// owner rank's shard via remote pointers -- no staging, no gather. +__global__ void naiveRemoteReadKernel( + const __half* const* a_shards, + const __half* b, __half* c, + int64_t m, int64_t n, int64_t k, + int64_t m_per_rank) { + int64_t row = blockIdx.y * blockDim.y + threadIdx.y; + int64_t col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= m || col >= n) return; + int64_t owner = row / m_per_rank; + int64_t lr = row - owner * m_per_rank; + const __half* a = a_shards[owner]; + float acc = 0.f; + for (int64_t kk = 0; kk < k; ++kk) + acc += __half2float(a[lr * k + kk]) * + __half2float(b[kk * n + col]); + c[row * n + col] = __float2half(acc); } -__global__ void fusedStagedThreadLoadSynchronizedKernel( - const __half* const* a_remote_shards, +// --- 3b. threadloadGatherKernel --- +// Two-stage fused kernel with synchronized P2P gather: +// Stage 1: cooperative thread loads copy one A row from the owner +// rank's remote shard into a local staging buffer. +// Stage 2: scalar matmul from staged A. Skipped when n==0 +// (CUTLASS variants launch host-side CUTLASS instead). +// Owner signals readiness; non-owners wait. After compute, readers +// ack completion; owner waits for all readers. +__global__ void threadloadGatherKernel( + const __half* const* a_shards, __half* a_gathered, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - int32_t launch_epoch_base, - int64_t m, - int64_t n, - int64_t k, + int32_t* const* ready_r, int32_t* ready_l, + int32_t* const* done_r, int32_t* done_l, + int64_t rank, int64_t ws, int32_t epoch_base, + int64_t m, int64_t n, int64_t k, int64_t m_per_rank, - const __half* b_full, - __half* c_out) { - const int32_t launch_epoch = launch_epoch_base + 1; - for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { - const int64_t owner_rank = row / m_per_rank; - const int64_t local_row = row - owner_rank * m_per_rank; - const __half* a_local = a_remote_shards[owner_rank]; - if (threadIdx.x == 0) { - // Owner publishes "ready"; non-owners wait only on owner readiness. - if (my_rank == owner_rank) { - publishEpochToAllRanks( - ready_semaphore_remote_ptrs, - ready_semaphore_local, - my_rank, - row, - m, - world_size, - launch_epoch); - } - } + const __half* b, __half* c) { + const int32_t epoch = epoch_base + 1; + for (int64_t row = blockIdx.x; row < m; + row += gridDim.x) { + int64_t owner = row / m_per_rank; + int64_t lr = row - owner * m_per_rank; + const __half* a = a_shards[owner]; + + // --- Semaphore: owner signals readiness --- + if (threadIdx.x == 0 && rank == owner) + publishToAll( + ready_r, ready_l, rank, row, m, ws, epoch); __syncthreads(); - - if (threadIdx.x == 0 && my_rank != owner_rank) { - waitForEpochFromRank( - ready_semaphore_local, row, m, owner_rank, launch_epoch); - } + if (threadIdx.x == 0 && rank != owner) + waitOne(ready_l, row, m, owner, epoch); __syncthreads(); - // Stage 1: gather this row into staged global buffer. - for (int64_t kk = threadIdx.x; kk < k; kk += blockDim.x) { - a_gathered[row * k + kk] = a_local[local_row * k + kk]; - } + // --- Stage 1: P2P gather via thread loads --- + for (int64_t kk = threadIdx.x; kk < k; + kk += blockDim.x) + a_gathered[row * k + kk] = a[lr * k + kk]; __syncthreads(); - // Stage 2: compute this row from staged global A. - for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { - float acc = 0.0f; - for (int64_t kk = 0; kk < k; ++kk) { + // --- Stage 2: scalar matmul (skip when n==0) --- + for (int64_t col = threadIdx.x; col < n; + col += blockDim.x) { + float acc = 0.f; + for (int64_t kk = 0; kk < k; ++kk) acc += __half2float(a_gathered[row * k + kk]) * - __half2float(b_full[kk * n + col]); - } - c_out[row * n + col] = __float2half(acc); + __half2float(b[kk * n + col]); + c[row * n + col] = __float2half(acc); } __syncthreads(); + // --- Semaphore: readers ack, owner waits --- if (threadIdx.x == 0) { - // Readers ack completion only to owner; owner waits on all readers. - if (my_rank == owner_rank) { - setLocalEpoch(done_semaphore_local, my_rank, row, m, launch_epoch); - } else { - publishEpochToRank( - done_semaphore_remote_ptrs[owner_rank], - my_rank, - row, - m, - launch_epoch); - } + if (rank == owner) + setLocal(done_l, rank, row, m, epoch); + else + publishToOne(done_r[owner], rank, row, m, epoch); } __syncthreads(); - - if (threadIdx.x == 0 && my_rank == owner_rank) { - waitForEpochFromAllRanks( - done_semaphore_local, row, m, world_size, launch_epoch); - } + if (threadIdx.x == 0 && rank == owner) + waitAll(done_l, row, m, ws, epoch); __syncthreads(); } } -// Same fused structure as above, but stage-1 writes use multimem stores. -__global__ void fusedStagedMultimemKernel( - const __half* const* a_remote_shards, - __half* a_gathered_multicast, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - int32_t launch_epoch_base, - int64_t m, - int64_t n, - int64_t k, +// --- 3c. multimemGatherKernel --- +// Two-stage fused kernel using Hopper multimem stores: +// Stage 1: the owner rank writes each A row to a multicast buffer +// via multimem.st.global.v4.f32 (hardware broadcast to all peers). +// Stage 2: scalar matmul from multicast buffer. Skipped when n==0. +// Requires SM90+ and multicast-capable symmetric memory. +__global__ void multimemGatherKernel( + const __half* const* a_shards, + __half* a_mc, + int32_t* const* sem_r, int32_t* sem_l, + int64_t rank, int64_t ws, int32_t epoch_base, + int64_t m, int64_t n, int64_t k, int64_t m_per_rank, - const __half* b_full, - __half* c_out) { - for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { - const int64_t owner_rank = row / m_per_rank; - const int64_t local_row = row - owner_rank * m_per_rank; - const __half* a_local = a_remote_shards[owner_rank]; - __half* a_row_stage = a_gathered_multicast + row * k; - - // Owner materializes the multicast row; non-owners wait on owner readiness. - constexpr int64_t vec_elems = 8; // 8 * half = 16 bytes - const int64_t n_vec = k / vec_elems; - if (my_rank == owner_rank) { - for (int64_t vec_i = threadIdx.x; vec_i < n_vec; vec_i += blockDim.x) { - const uint4 val = - reinterpret_cast(a_local + local_row * k)[vec_i]; + const __half* b, __half* c) { + for (int64_t row = blockIdx.x; row < m; + row += gridDim.x) { + int64_t owner = row / m_per_rank; + int64_t lr = row - owner * m_per_rank; + const __half* a = a_shards[owner]; + __half* arow = a_mc + row * k; + + // --- Stage 1: multimem store (owner only) --- + constexpr int64_t kVec = 8; + int64_t nvec = k / kVec; + if (rank == owner) { + for (int64_t vi = threadIdx.x; vi < nvec; + vi += blockDim.x) { + uint4 val = reinterpret_cast( + a + lr * k)[vi]; #if __CUDA_ARCH__ >= 900 - asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"((void*)(a_row_stage + vec_i * vec_elems)), - "f"(__int_as_float(static_cast(val.x))), - "f"(__int_as_float(static_cast(val.y))), - "f"(__int_as_float(static_cast(val.z))), - "f"(__int_as_float(static_cast(val.w))) - : "memory"); + asm volatile( + "multimem.st.global.v4.f32 [%0]," + " {%1, %2, %3, %4};" + : + : "l"((void*)(arow + vi * kVec)), + "f"(__int_as_float((int)val.x)), + "f"(__int_as_float((int)val.y)), + "f"(__int_as_float((int)val.z)), + "f"(__int_as_float((int)val.w)) + : "memory"); #else (void)val; - // Multimem path must never run on non-Hopper architectures. asm volatile("trap;"); #endif } - for (int64_t kk = n_vec * vec_elems + threadIdx.x; kk < k; - kk += blockDim.x) { - a_row_stage[kk] = a_local[local_row * k + kk]; - } + for (int64_t kk = nvec * kVec + threadIdx.x; + kk < k; kk += blockDim.x) + arow[kk] = a[lr * k + kk]; } __syncthreads(); + // --- Semaphore barrier --- #if __CUDA_ARCH__ >= 900 - // Cross-device barrier between stage-1 stores and stage-2 reads. - // Each rank publishes one epoch for (my_rank,row), then waits until - // all writer ranks have published the same epoch for that row. - const int32_t launch_epoch = launch_epoch_base + 1; - - if (threadIdx.x == 0) { - if (my_rank == owner_rank) { - publishEpochToAllRanks( - stage_semaphore_remote_ptrs, - stage_semaphore_local, - my_rank, - row, - m, - world_size, - launch_epoch); - } - } + const int32_t epoch = epoch_base + 1; + if (threadIdx.x == 0 && rank == owner) + publishToAll( + sem_r, sem_l, rank, row, m, ws, epoch); __syncthreads(); - - if (threadIdx.x == 0 && my_rank != owner_rank) { - // Non-owners only need owner readiness before reading multicast row. - waitForEpochFromRank( - stage_semaphore_local, row, m, owner_rank, launch_epoch); - } + if (threadIdx.x == 0 && rank != owner) + waitOne(sem_l, row, m, owner, epoch); __syncthreads(); #else - (void)stage_semaphore_remote_ptrs; - (void)stage_semaphore_local; - (void)my_rank; - (void)world_size; - (void)launch_epoch_base; + (void)sem_r; (void)sem_l; + (void)rank; (void)ws; (void)epoch_base; asm volatile("trap;"); #endif - // Stage 2: compute from staged multicast-backed row. - for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { - float acc = 0.0f; - for (int64_t kk = 0; kk < k; ++kk) { - acc += - __half2float(a_row_stage[kk]) * __half2float(b_full[kk * n + col]); - } - c_out[row * n + col] = __float2half(acc); + // --- Stage 2: scalar matmul (skip when n==0) --- + for (int64_t col = threadIdx.x; col < n; + col += blockDim.x) { + float acc = 0.f; + for (int64_t kk = 0; kk < k; ++kk) + acc += __half2float(arow[kk]) * + __half2float(b[kk * n + col]); + c[row * n + col] = __float2half(acc); } __syncthreads(); } } -__global__ void fusedNaiveFusedTmaKernel( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank) { -#if __CUDA_ARCH__ >= 900 - if (threadIdx.x >= warpSize || threadIdx.y != 0 || threadIdx.z != 0) { - return; - } - const int lane = static_cast(threadIdx.x); - const int64_t row_base = static_cast(blockIdx.y) * kMmaTile; - const int64_t col_base = static_cast(blockIdx.x) * kMmaTile; - if (row_base + kMmaTile > m || col_base + kMmaTile > n || k % kMmaTile != 0) { - return; - } - - __shared__ __half a_tile[kMmaTile * kMmaTile]; - __shared__ __half b_tile[kMmaTile * kMmaTile]; - using namespace nvcuda; - wmma::fragment acc_frag; - wmma::fill_fragment(acc_frag, 0.0f); - - for (int64_t kk0 = 0; kk0 < k; kk0 += kMmaTile) { - for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { - const int64_t r = idx / kMmaTile; - const int64_t c = idx % kMmaTile; - const int64_t row = row_base + r; - const int64_t owner_rank = row / m_per_rank; - const int64_t local_row = row - owner_rank * m_per_rank; - const __half* a_local = a_remote_shards[owner_rank]; - a_tile[idx] = a_local[local_row * k + (kk0 + c)]; - b_tile[idx] = b_full[(kk0 + r) * n + (col_base + c)]; - } - __syncwarp(); - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::load_matrix_sync(a_frag, a_tile, kMmaTile); - wmma::load_matrix_sync(b_frag, b_tile, kMmaTile); - wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); - __syncwarp(); - } - - __shared__ float c_tile[kMmaTile * kMmaTile]; - wmma::store_matrix_sync(c_tile, acc_frag, kMmaTile, wmma::mem_row_major); - __syncwarp(); - for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { - const int64_t r = idx / kMmaTile; - const int64_t c = idx % kMmaTile; - c_out[(row_base + r) * n + (col_base + c)] = __float2half(c_tile[idx]); - } -#else - (void)a_remote_shards; - (void)b_full; - (void)c_out; - (void)m; - (void)n; - (void)k; - (void)m_per_rank; - asm volatile("trap;"); -#endif +} // anonymous namespace + +// ========================================================================= +// Section 4: Public launcher functions +// +// Thin wrappers that set up grid/block dims and launch kernels once. +// Timing and iteration loops live in the .cpp file. +// ========================================================================= + +void launchNaiveRemoteRead(DistributedMatmulContext& ctx) { + constexpr int64_t kB = 16; + dim3 block(kB, kB); + dim3 grid( + (ctx.n + kB - 1) / kB, (ctx.m + kB - 1) / kB); + naiveRemoteReadKernel<<>>( + ctx.device_remote_ptrs, + reinterpret_cast( + ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>( + ctx.c_out_half.data_ptr()), + ctx.m, ctx.n, ctx.k, ctx.m_per_rank); } -__global__ void fusedStagedThreadLoadSynchronizedFusedTmaKernel( - const __half* const* a_remote_shards, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - int32_t launch_epoch_base, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank) { -#if __CUDA_ARCH__ >= 900 - if (threadIdx.x >= warpSize || threadIdx.y != 0 || threadIdx.z != 0) { - return; - } - const int lane = static_cast(threadIdx.x); - const int64_t row_base = static_cast(blockIdx.y) * kMmaTile; - const int64_t col_base = static_cast(blockIdx.x) * kMmaTile; - if (row_base + kMmaTile > m || col_base + kMmaTile > n || k % kMmaTile != 0) { - return; - } - const int32_t launch_epoch = launch_epoch_base + 1; - - for (int64_t r = 0; r < kMmaTile; ++r) { - const int64_t row = row_base + r; - const int64_t owner_rank = row / m_per_rank; - if (lane == 0) { - if (my_rank == owner_rank) { - publishEpochToAllRanks( - ready_semaphore_remote_ptrs, - ready_semaphore_local, - my_rank, - row, - m, - world_size, - launch_epoch); - } - } - __syncwarp(); - if (lane == 0 && my_rank != owner_rank) { - waitForEpochFromRank( - ready_semaphore_local, row, m, owner_rank, launch_epoch); - } - __syncwarp(); - } - - __shared__ __half a_tile[kMmaTile * kMmaTile]; - __shared__ __half b_tile[kMmaTile * kMmaTile]; - using namespace nvcuda; - wmma::fragment acc_frag; - wmma::fill_fragment(acc_frag, 0.0f); - for (int64_t kk0 = 0; kk0 < k; kk0 += kMmaTile) { - for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { - const int64_t r = idx / kMmaTile; - const int64_t c = idx % kMmaTile; - const int64_t row = row_base + r; - const int64_t owner_rank = row / m_per_rank; - const int64_t local_row = row - owner_rank * m_per_rank; - const __half* a_local = a_remote_shards[owner_rank]; - a_tile[idx] = a_local[local_row * k + (kk0 + c)]; - b_tile[idx] = b_full[(kk0 + r) * n + (col_base + c)]; - } - __syncwarp(); - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::load_matrix_sync(a_frag, a_tile, kMmaTile); - wmma::load_matrix_sync(b_frag, b_tile, kMmaTile); - wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); - __syncwarp(); - } +void launchThreadloadGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute) { + dim3 block(256); + dim3 grid(ctx.m); + threadloadGatherKernel<<>>( + ctx.device_remote_ptrs, + reinterpret_cast<__half*>( + ctx.a_gathered.data_ptr()), + ctx.ready_sem_remote, ctx.ready_sem_local, + ctx.done_sem_remote, ctx.done_sem_local, + ctx.my_rank, ctx.world_size, epoch, + ctx.m, compute ? ctx.n : 0, + ctx.k, ctx.m_per_rank, + reinterpret_cast( + ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>( + ctx.c_out_half.data_ptr())); +} - __shared__ float c_tile[kMmaTile * kMmaTile]; - wmma::store_matrix_sync(c_tile, acc_frag, kMmaTile, wmma::mem_row_major); - __syncwarp(); - for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { - const int64_t r = idx / kMmaTile; - const int64_t c = idx % kMmaTile; - c_out[(row_base + r) * n + (col_base + c)] = __float2half(c_tile[idx]); - } - __syncwarp(); +void launchMultimemGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute) { + dim3 block(256); + dim3 grid(ctx.m); + multimemGatherKernel<<>>( + ctx.device_remote_ptrs, + ctx.multicast_ptr, + ctx.stage_sem_remote, ctx.stage_sem_local, + ctx.my_rank, ctx.world_size, epoch, + ctx.m, compute ? ctx.n : 0, + ctx.k, ctx.m_per_rank, + reinterpret_cast( + ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>( + ctx.c_out_half.data_ptr())); +} - for (int64_t r = 0; r < kMmaTile; ++r) { - const int64_t row = row_base + r; - const int64_t owner_rank = row / m_per_rank; - if (lane == 0) { - if (my_rank == owner_rank) { - setLocalEpoch(done_semaphore_local, my_rank, row, m, launch_epoch); - } else { - publishEpochToRank( - done_semaphore_remote_ptrs[owner_rank], - my_rank, - row, - m, - launch_epoch); - } - } - __syncwarp(); - if (lane == 0 && my_rank == owner_rank) { - waitForEpochFromAllRanks( - done_semaphore_local, row, m, world_size, launch_epoch); - } - __syncwarp(); - } +at::Tensor matmulTma( + const at::Tensor& a, + const at::Tensor& b) { + NVF_CHECK(a.is_cuda() && b.is_cuda()); + NVF_CHECK(a.dim() == 2 && b.dim() == 2); + NVF_CHECK(a.size(1) == b.size(0)); + at::cuda::CUDAGuard guard{a.device()}; + auto* props = + at::cuda::getDeviceProperties(a.get_device()); + NVF_CHECK(props->major >= 9, "Requires Hopper+."); + int64_t m = a.size(0), n = b.size(1), k = a.size(1); + at::Tensor out = at::empty( + {m, n}, + at::TensorOptions() + .dtype(a.scalar_type()) + .device(a.device())); + cudaStream_t stream = + at::cuda::getCurrentCUDAStream(a.get_device()); +#if defined(NVFUSER_ENABLE_CUTLASS) + if (a.scalar_type() == at::ScalarType::Half) + runGemmSm90( + out, a, b, m, n, k, stream); + else + runGemmSm90( + out, a, b, m, n, k, stream); #else - (void)a_remote_shards; - (void)ready_semaphore_remote_ptrs; - (void)ready_semaphore_local; - (void)done_semaphore_remote_ptrs; - (void)done_semaphore_local; - (void)my_rank; - (void)world_size; - (void)launch_epoch_base; - (void)b_full; - (void)c_out; - (void)m; - (void)n; - (void)k; - (void)m_per_rank; - asm volatile("trap;"); + NVF_THROW("CUTLASS support required."); #endif + return out; } -__global__ void fusedStagedMultimemFusedTmaKernel( - const __half* const* a_remote_shards, - __half* a_gathered_multicast, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - int32_t launch_epoch_base, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank) { -#if __CUDA_ARCH__ >= 900 - if (threadIdx.x >= warpSize || threadIdx.y != 0 || threadIdx.z != 0) { - return; - } - const int lane = static_cast(threadIdx.x); - const int64_t row_base = static_cast(blockIdx.y) * kMmaTile; - const int64_t col_base = static_cast(blockIdx.x) * kMmaTile; - if (row_base + kMmaTile > m || col_base + kMmaTile > n || k % kMmaTile != 0) { - return; - } - - constexpr int64_t vec_elems = 8; - for (int64_t r = 0; r < kMmaTile; ++r) { - const int64_t row = row_base + r; - const int64_t owner_rank = row / m_per_rank; - const int64_t local_row = row - owner_rank * m_per_rank; - const __half* a_local = a_remote_shards[owner_rank]; - __half* a_row_stage = a_gathered_multicast + row * k; - if (my_rank == owner_rank) { - const int64_t n_vec = k / vec_elems; - for (int64_t vec_i = lane; vec_i < n_vec; vec_i += warpSize) { - const uint4 val = - reinterpret_cast(a_local + local_row * k)[vec_i]; - asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"((void*)(a_row_stage + vec_i * vec_elems)), - "f"(__int_as_float(static_cast(val.x))), - "f"(__int_as_float(static_cast(val.y))), - "f"(__int_as_float(static_cast(val.z))), - "f"(__int_as_float(static_cast(val.w))) - : "memory"); - } - for (int64_t kk = (k / vec_elems) * vec_elems + lane; kk < k; - kk += warpSize) { - a_row_stage[kk] = a_local[local_row * k + kk]; - } - } - __syncwarp(); - - const int32_t launch_epoch = launch_epoch_base + 1; - if (lane == 0) { - if (my_rank == owner_rank) { - publishEpochToAllRanks( - stage_semaphore_remote_ptrs, - stage_semaphore_local, - my_rank, - row, - m, - world_size, - launch_epoch); - } - } - __syncwarp(); - if (lane == 0 && my_rank != owner_rank) { - waitForEpochFromRank( - stage_semaphore_local, row, m, owner_rank, launch_epoch); - } - __syncwarp(); - } - - __shared__ __half a_tile[kMmaTile * kMmaTile]; - __shared__ __half b_tile[kMmaTile * kMmaTile]; - using namespace nvcuda; - wmma::fragment acc_frag; - wmma::fill_fragment(acc_frag, 0.0f); - for (int64_t kk0 = 0; kk0 < k; kk0 += kMmaTile) { - for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { - const int64_t r = idx / kMmaTile; - const int64_t c = idx % kMmaTile; - const int64_t row = row_base + r; - a_tile[idx] = a_gathered_multicast[row * k + (kk0 + c)]; - b_tile[idx] = b_full[(kk0 + r) * n + (col_base + c)]; - } - __syncwarp(); - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::load_matrix_sync(a_frag, a_tile, kMmaTile); - wmma::load_matrix_sync(b_frag, b_tile, kMmaTile); - wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); - __syncwarp(); - } - __shared__ float c_tile[kMmaTile * kMmaTile]; - wmma::store_matrix_sync(c_tile, acc_frag, kMmaTile, wmma::mem_row_major); - __syncwarp(); - for (int64_t idx = lane; idx < kMmaTile * kMmaTile; idx += warpSize) { - const int64_t r = idx / kMmaTile; - const int64_t c = idx % kMmaTile; - c_out[(row_base + r) * n + (col_base + c)] = __float2half(c_tile[idx]); - } +bool canRunCutlassCompute( + const at::Tensor& a, + const at::Tensor& b) { + if (!hasValidTmaShape(a, b)) return false; +#if !defined(NVFUSER_ENABLE_CUTLASS) || \ + !defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + return false; #else - (void)a_remote_shards; - (void)a_gathered_multicast; - (void)stage_semaphore_remote_ptrs; - (void)stage_semaphore_local; - (void)my_rank; - (void)world_size; - (void)launch_epoch_base; - (void)b_full; - (void)c_out; - (void)m; - (void)n; - (void)k; - (void)m_per_rank; - asm volatile("trap;"); + auto* props = + at::cuda::getDeviceProperties(a.get_device()); + return props->major == 9 && props->minor == 0; #endif } -} // namespace - -double timeFusedRemoteMatmulMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - int64_t world_size, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_x, - int64_t block_y, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - (void)world_size; - const dim3 block(static_cast(block_x), static_cast(block_y)); - const dim3 grid( - static_cast((n + block.x - 1) / block.x), - static_cast((m + block.y - 1) / block.y)); - - auto launch_once = [&]() { - fusedRemoteMatmulKernel<<>>( - a_remote_shards, b_full, c_out, m, n, k, m_per_rank); - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); -} - -double timeNaiveRemoteMatmulCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - const __half* b_ptr = - reinterpret_cast(b_full.data_ptr()); - __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); - __half* a_gathered_ptr = - reinterpret_cast<__half*>(a_gathered.data_ptr()); - auto launch_comm_once = [&]() { - // Keep communication as remote thread-load gather. - fusedStagedThreadLoadKernel<<>>( - a_remote_shards, - a_gathered_ptr, - m, - /*n=*/0, - k, - m_per_rank, - b_ptr, - c_ptr); - }; - return timeCommThenCutlassMs( - warmup_iters, - iters, - stream, - a_gathered, - b_full, - c_out, - launch_comm_once); -} - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - __half* a_gathered, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - int64_t launch_epoch_base = 0; - - auto launch_once = [&]() { - NVF_CHECK( - launch_epoch_base < std::numeric_limits::max(), - "ThreadLoad synchronized semaphore epoch overflow."); - fusedStagedThreadLoadSynchronizedKernel<<>>( - a_remote_shards, - a_gathered, - ready_semaphore_remote_ptrs, - ready_semaphore_local, - done_semaphore_remote_ptrs, - done_semaphore_local, - my_rank, - world_size, - static_cast(launch_epoch_base), - m, - n, - k, - m_per_rank, - b_full, - c_out); - ++launch_epoch_base; - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); -} - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - int64_t launch_epoch_base = 0; - const __half* b_ptr = - reinterpret_cast(b_full.data_ptr()); - __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); - __half* a_gathered_ptr = - reinterpret_cast<__half*>(a_gathered.data_ptr()); - auto launch_comm_once = [&]() { - NVF_CHECK( - launch_epoch_base < std::numeric_limits::max(), - "ThreadLoad synchronized CUTLASS semaphore epoch overflow."); - fusedStagedThreadLoadSynchronizedKernel<<>>( - a_remote_shards, - a_gathered_ptr, - ready_semaphore_remote_ptrs, - ready_semaphore_local, - done_semaphore_remote_ptrs, - done_semaphore_local, - my_rank, - world_size, - static_cast(launch_epoch_base), - m, - /*n=*/0, - k, - m_per_rank, - b_ptr, - c_ptr); - ++launch_epoch_base; - }; - return timeCommThenCutlassMs( - warmup_iters, - iters, - stream, - a_gathered, - b_full, - c_out, - launch_comm_once); -} - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedAsyncCutlassMs( - const __half* const* a_remote_shards, - at::Tensor& a_gathered, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - const at::Tensor& b_full, - at::Tensor& c_out, - at::Tensor& a_chunk_signals, - int64_t a_chunk_pivot, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - int64_t launch_epoch_base = 0; - const __half* b_ptr = - reinterpret_cast(b_full.data_ptr()); - __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); - __half* a_gathered_ptr = - reinterpret_cast<__half*>(a_gathered.data_ptr()); - auto launch_once = [&]() { - NVF_CHECK( - launch_epoch_base < std::numeric_limits::max(), - "ThreadLoad synchronized async CUTLASS semaphore epoch overflow."); - fusedStagedThreadLoadSynchronizedKernel<<>>( - a_remote_shards, - a_gathered_ptr, - ready_semaphore_remote_ptrs, - ready_semaphore_local, - done_semaphore_remote_ptrs, - done_semaphore_local, - my_rank, - world_size, - static_cast(launch_epoch_base), - m, - /*n=*/0, - k, - m_per_rank, - b_ptr, - c_ptr); - - a_chunk_signals.fill_(1); - asyncAllgatherMatmulTmaOutLocal( - c_out, a_gathered, b_full, a_chunk_signals, a_chunk_pivot); - ++launch_epoch_base; - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); -} - -double timeSeparatedAllgatherMatmulMultimemMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - __half* a_gathered_multicast, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - int64_t launch_epoch_base = 0; - - auto launch_once = [&]() { - NVF_CHECK( - launch_epoch_base < std::numeric_limits::max(), - "Multimem semaphore epoch overflow."); - fusedStagedMultimemKernel<<>>( - a_remote_shards, - a_gathered_multicast, - stage_semaphore_remote_ptrs, - stage_semaphore_local, - my_rank, - world_size, - static_cast(launch_epoch_base), - m, - n, - k, - m_per_rank, - b_full, - c_out); - ++launch_epoch_base; - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); -} - -double timeSeparatedAllgatherMatmulMultimemCutlassMs( - const __half* const* a_remote_shards, - __half* a_gathered_multicast_ptr, - const at::Tensor& a_gathered_local, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - const at::Tensor& b_full, - at::Tensor& c_out, - int64_t m, - int64_t k, - int64_t m_per_rank, - int64_t block_threads, - int64_t grid_blocks, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - const dim3 block(static_cast(block_threads)); - const dim3 grid(static_cast(grid_blocks <= 0 ? m : grid_blocks)); - int64_t launch_epoch_base = 0; - const __half* b_ptr = - reinterpret_cast(b_full.data_ptr()); - __half* c_ptr = reinterpret_cast<__half*>(c_out.data_ptr()); - auto launch_comm_once = [&]() { - NVF_CHECK( - launch_epoch_base < std::numeric_limits::max(), - "Multimem CUTLASS semaphore epoch overflow."); - fusedStagedMultimemKernel<<>>( - a_remote_shards, - a_gathered_multicast_ptr, - stage_semaphore_remote_ptrs, - stage_semaphore_local, - my_rank, - world_size, - static_cast(launch_epoch_base), - m, - /*n=*/0, - k, - m_per_rank, - b_ptr, - c_ptr); - ++launch_epoch_base; - }; - return timeCommThenCutlassMs( - warmup_iters, - iters, - stream, - a_gathered_local, - b_full, - c_out, - launch_comm_once); -} - -double timeNaiveRemoteMatmulFusedTmaMs( - const __half* const* a_remote_shards, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - NVF_CHECK( - m % kMmaTile == 0 && n % kMmaTile == 0 && k % kMmaTile == 0, - "FusedTma kernels require M,N,K divisible by 16."); - const dim3 block(32); - const dim3 grid( - static_cast(n / kMmaTile), - static_cast(m / kMmaTile)); - auto launch_once = [&]() { - fusedNaiveFusedTmaKernel<<>>( - a_remote_shards, b_full, c_out, m, n, k, m_per_rank); - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); -} - -double timeSeparatedAllgatherMatmulThreadLoadSynchronizedFusedTmaMs( - const __half* const* a_remote_shards, - int32_t* const* ready_semaphore_remote_ptrs, - int32_t* ready_semaphore_local, - int32_t* const* done_semaphore_remote_ptrs, - int32_t* done_semaphore_local, - int64_t my_rank, - int64_t world_size, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - NVF_CHECK( - m % kMmaTile == 0 && n % kMmaTile == 0 && k % kMmaTile == 0, - "FusedTma kernels require M,N,K divisible by 16."); - const dim3 block(32); - const dim3 grid( - static_cast(n / kMmaTile), - static_cast(m / kMmaTile)); - int64_t launch_epoch_base = 0; - auto launch_once = [&]() { - NVF_CHECK( - launch_epoch_base < std::numeric_limits::max(), - "FusedTma threadload semaphore epoch overflow."); - fusedStagedThreadLoadSynchronizedFusedTmaKernel<<>>( - a_remote_shards, - ready_semaphore_remote_ptrs, - ready_semaphore_local, - done_semaphore_remote_ptrs, - done_semaphore_local, - my_rank, - world_size, - static_cast(launch_epoch_base), - b_full, - c_out, - m, - n, - k, - m_per_rank); - ++launch_epoch_base; - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +const char* implName(DistributedMatmulImpl impl) { + switch (impl) { + case DistributedMatmulImpl::baselineNcclAllgatherMatmul: + return "baselineNcclAllgatherMatmul"; + case DistributedMatmulImpl::baselineCudaAllgatherMatmul: + return "baselineCudaAllgatherMatmul"; + case DistributedMatmulImpl::naiveRemoteRead: + return "naiveRemoteRead"; + case DistributedMatmulImpl::threadloadGatherScalarCompute: + return "threadloadGatherScalarCompute"; + case DistributedMatmulImpl::multimemGatherScalarCompute: + return "multimemGatherScalarCompute"; + case DistributedMatmulImpl::threadloadGatherCutlassCompute: + return "threadloadGatherCutlassCompute"; + case DistributedMatmulImpl::multimemGatherCutlassCompute: + return "multimemGatherCutlassCompute"; + } + return "unknown"; } -double timeSeparatedAllgatherMatmulMultimemFusedTmaMs( - const __half* const* a_remote_shards, - __half* a_gathered_multicast, - int32_t* const* stage_semaphore_remote_ptrs, - int32_t* stage_semaphore_local, - int64_t my_rank, - int64_t world_size, - const __half* b_full, - __half* c_out, - int64_t m, - int64_t n, - int64_t k, - int64_t m_per_rank, - int64_t warmup_iters, - int64_t iters, - cudaStream_t stream) { - NVF_CHECK( - m % kMmaTile == 0 && n % kMmaTile == 0 && k % kMmaTile == 0, - "FusedTma kernels require M,N,K divisible by 16."); - const dim3 block(32); - const dim3 grid( - static_cast(n / kMmaTile), - static_cast(m / kMmaTile)); - int64_t launch_epoch_base = 0; - auto launch_once = [&]() { - NVF_CHECK( - launch_epoch_base < std::numeric_limits::max(), - "FusedTma multimem semaphore epoch overflow."); - fusedStagedMultimemFusedTmaKernel<<>>( - a_remote_shards, - a_gathered_multicast, - stage_semaphore_remote_ptrs, - stage_semaphore_local, - my_rank, - world_size, - static_cast(launch_epoch_base), - b_full, - c_out, - m, - n, - k, - m_per_rank); - ++launch_epoch_base; - }; - return timeKernelLaunchesMs(warmup_iters, iters, stream, launch_once); +bool isMulticastSupported(int64_t device_id) { + int val = 0; + auto r = cuDeviceGetAttribute( + &val, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, + static_cast(device_id)); + return r == CUDA_SUCCESS && val != 0; } } // namespace nvfuser From 1c82f61ca1913a00235e76a8b2d16ddbe310db40 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Feb 2026 07:13:31 -0800 Subject: [PATCH 20/23] add header and batch cuda events timing --- .../test_multidevice_fused_remote_matmul.cpp | 108 +++++++++------ .../test_multidevice_fused_remote_matmul.h | 128 ++++++++++++++++++ 2 files changed, 198 insertions(+), 38 deletions(-) create mode 100644 tests/cpp/test_multidevice_fused_remote_matmul.h diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index fe3556f9465..d4eb8056235 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -43,9 +43,43 @@ namespace nvfuser { namespace { // ========================================================================= -// Timing helper +// Timing helpers // ========================================================================= +// Batched GPU timing: one cuda-event pair around all iterations. +// No per-iteration host sync, matching the original kernel timing +// methodology. Avoids ~15us cudaStreamSynchronize overhead that +// dominates sub-100us kernels like CUTLASS TMA matmuls. +template +double batchedKernelTimeMs( + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream, + Fn&& run_once) { + for (int64_t i = 0; i < warmup_iters; ++i) + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + cudaEvent_t start, stop; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + float total_ms; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast(total_ms) / iters; +} + +// Per-iteration timing for baselines with host-blocking waits. template double benchmarkLoopMs( const BenchmarkConfig& config, @@ -263,6 +297,8 @@ double runImplementation( DistributedMatmulContext& ctx, const BenchmarkConfig& config) { using I = DistributedMatmulImpl; + const int64_t wu = config.warmup_iters; + const int64_t it = config.iters; switch (impl) { case I::baselineNcclAllgatherMatmul: { at::Tensor a_full = at::empty( @@ -294,55 +330,51 @@ double runImplementation( config, ctx.communicator, ctx.stream, run); } case I::naiveRemoteRead: { - auto run = [&]() { - launchNaiveRemoteRead(ctx); - }; - return benchmarkLoopMs( - config, ctx.communicator, ctx.stream, run); + return batchedKernelTimeMs( + wu, it, ctx.stream, [&]() { + launchNaiveRemoteRead(ctx); + }); } case I::threadloadGatherScalarCompute: { int64_t epoch = 0; - auto run = [&]() { - launchThreadloadGather( - ctx, static_cast(epoch), true); - ++epoch; - }; - return benchmarkLoopMs( - config, ctx.communicator, ctx.stream, run); + return batchedKernelTimeMs( + wu, it, ctx.stream, [&]() { + launchThreadloadGather( + ctx, static_cast(epoch), true); + ++epoch; + }); } case I::threadloadGatherCutlassCompute: { int64_t epoch = 0; - auto run = [&]() { - launchThreadloadGather( - ctx, static_cast(epoch), false); - ctx.c_out_half = - matmulTma(ctx.a_gathered, ctx.b_full_half); - ++epoch; - }; - return benchmarkLoopMs( - config, ctx.communicator, ctx.stream, run); + return batchedKernelTimeMs( + wu, it, ctx.stream, [&]() { + launchThreadloadGather( + ctx, static_cast(epoch), false); + ctx.c_out_half = matmulTma( + ctx.a_gathered, ctx.b_full_half); + ++epoch; + }); } case I::multimemGatherScalarCompute: { int64_t epoch = 0; - auto run = [&]() { - launchMultimemGather( - ctx, static_cast(epoch), true); - ++epoch; - }; - return benchmarkLoopMs( - config, ctx.communicator, ctx.stream, run); + return batchedKernelTimeMs( + wu, it, ctx.stream, [&]() { + launchMultimemGather( + ctx, static_cast(epoch), true); + ++epoch; + }); } case I::multimemGatherCutlassCompute: { int64_t epoch = 0; - auto run = [&]() { - launchMultimemGather( - ctx, static_cast(epoch), false); - ctx.c_out_half = matmulTma( - ctx.a_gathered_multimem, ctx.b_full_half); - ++epoch; - }; - return benchmarkLoopMs( - config, ctx.communicator, ctx.stream, run); + return batchedKernelTimeMs( + wu, it, ctx.stream, [&]() { + launchMultimemGather( + ctx, static_cast(epoch), false); + ctx.c_out_half = matmulTma( + ctx.a_gathered_multimem, + ctx.b_full_half); + ++epoch; + }); } } NVF_ERROR(false, "Unknown implementation."); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.h b/tests/cpp/test_multidevice_fused_remote_matmul.h new file mode 100644 index 00000000000..1404d624af1 --- /dev/null +++ b/tests/cpp/test_multidevice_fused_remote_matmul.h @@ -0,0 +1,128 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +#include + +namespace c10d { +class Backend; +} + +namespace nvfuser { + +class Communicator; +class Communication; +class SymMemForAllgather; + +// ========================================================================= +// Distributed Matmul Benchmark -- shared types +// +// Computes C[M,N] = A[M,K] x B[K,N] where A is row-sharded across ranks +// on axis M and B is replicated. Each implementation varies the +// communication strategy (how A shards are gathered) and the compute +// strategy (how the matmul is performed). +// +// Performance on 8xH100 DGX (M=N=K=1024, half precision): +// +// Implementation | TFLOP/s +// ---------------------------------------+-------- +// baselineNcclAllgatherMatmul | 39.9 +// baselineCudaAllgatherMatmul | 24.1 +// naiveRemoteRead | 2.67 +// threadloadGatherScalarCompute | 4.05 +// multimemGatherScalarCompute | 3.69 +// threadloadGatherCutlassCompute | 46.5 +// multimemGatherCutlassCompute | 45.6 +// ========================================================================= + +enum class DistributedMatmulImpl { + // -- Baselines: separate allgather then PyTorch eager matmul -- + baselineNcclAllgatherMatmul, + baselineCudaAllgatherMatmul, + // -- Fused kernels: comm + scalar matmul in one kernel -- + naiveRemoteRead, + threadloadGatherScalarCompute, + multimemGatherScalarCompute, + // -- Two-phase: handwritten comm kernel + CUTLASS TMA matmul -- + threadloadGatherCutlassCompute, + multimemGatherCutlassCompute, +}; + +enum class TimeMeasurementMode { CudaEvents, CpuClock }; + +struct BenchmarkConfig { + int64_t warmup_iters; + int64_t iters; + TimeMeasurementMode time_mode; + bool barrier_at_each_iteration; +}; + +// All data any implementation may need. Unused fields are +// null/undefined for a given implementation. +struct DistributedMatmulContext { + // Problem dimensions + int64_t m = 0, n = 0, k = 0, m_per_rank = 0; + int64_t my_rank = 0, world_size = 0; + + // Remote A shard pointers (device array of const __half*) + const __half* const* device_remote_ptrs = nullptr; + + // Input / output tensors + at::Tensor a_local_half; // [m_per_rank, k] + at::Tensor b_full_half; // [k, n] + at::Tensor c_out_half; // [m, n] + + // Staging buffers for gather-then-compute paths + at::Tensor a_gathered; // [m, k] threadload staging + at::Tensor a_gathered_multimem; // [m, k] multicast-backed + __half* multicast_ptr = nullptr; + + // Threadload semaphores (ready / done handshake) + int32_t* const* ready_sem_remote = nullptr; + int32_t* ready_sem_local = nullptr; + int32_t* const* done_sem_remote = nullptr; + int32_t* done_sem_local = nullptr; + + // Multimem semaphores (stage barrier) + int32_t* const* stage_sem_remote = nullptr; + int32_t* stage_sem_local = nullptr; + + // Baseline-only resources + c10d::Backend* nccl_backend = nullptr; + Communication* cuda_comm = nullptr; + SymMemForAllgather* cuda_handle = nullptr; + at::Tensor a_allgathered_cuda; + + // Runtime + Communicator* communicator = nullptr; + cudaStream_t stream = nullptr; +}; + +// --- Defined in .cu (kernel launchers, CUTLASS wrapper) --- +void launchNaiveRemoteRead(DistributedMatmulContext& ctx); +void launchThreadloadGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute); +void launchMultimemGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute); +at::Tensor matmulTma( + const at::Tensor& a, + const at::Tensor& b); +bool canRunCutlassCompute( + const at::Tensor& a, + const at::Tensor& b); +const char* implName(DistributedMatmulImpl impl); +bool isMulticastSupported(int64_t device_id); + +} // namespace nvfuser From 93271bf589d7ab1273ebdfcce38bdfca9c896a31 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Feb 2026 07:17:57 -0800 Subject: [PATCH 21/23] Use in-place matmulTma --- .../test_multidevice_fused_remote_matmul.cpp | 6 ++++-- .../cpp/test_multidevice_fused_remote_matmul.h | 3 ++- ...t_multidevice_fused_remote_matmul_kernel.cu | 18 ++---------------- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index d4eb8056235..e156cfdbd3d 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -350,7 +350,8 @@ double runImplementation( wu, it, ctx.stream, [&]() { launchThreadloadGather( ctx, static_cast(epoch), false); - ctx.c_out_half = matmulTma( + matmulTma( + ctx.c_out_half, ctx.a_gathered, ctx.b_full_half); ++epoch; }); @@ -370,7 +371,8 @@ double runImplementation( wu, it, ctx.stream, [&]() { launchMultimemGather( ctx, static_cast(epoch), false); - ctx.c_out_half = matmulTma( + matmulTma( + ctx.c_out_half, ctx.a_gathered_multimem, ctx.b_full_half); ++epoch; diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.h b/tests/cpp/test_multidevice_fused_remote_matmul.h index 1404d624af1..847c4cbd62e 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.h +++ b/tests/cpp/test_multidevice_fused_remote_matmul.h @@ -116,7 +116,8 @@ void launchMultimemGather( DistributedMatmulContext& ctx, int32_t epoch, bool compute); -at::Tensor matmulTma( +void matmulTma( + at::Tensor& out, const at::Tensor& a, const at::Tensor& b); bool canRunCutlassCompute( diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index d916aba7e1e..16c37c3962c 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -26,8 +26,6 @@ // - CUTLASS TMA: host-launched Hopper GEMM (in the .cpp file, not // here -- the kernel is launched with n=0 to skip in-kernel compute). // -// The CUTLASS TMA matmul wrapper (matmulTma) is also defined here, -// moved from csrc/runtime/matmul_tma.cu for self-containment. // ========================================================================= #include "test_multidevice_fused_remote_matmul.h" @@ -490,22 +488,11 @@ void launchMultimemGather( ctx.c_out_half.data_ptr())); } -at::Tensor matmulTma( +void matmulTma( + at::Tensor& out, const at::Tensor& a, const at::Tensor& b) { - NVF_CHECK(a.is_cuda() && b.is_cuda()); - NVF_CHECK(a.dim() == 2 && b.dim() == 2); - NVF_CHECK(a.size(1) == b.size(0)); - at::cuda::CUDAGuard guard{a.device()}; - auto* props = - at::cuda::getDeviceProperties(a.get_device()); - NVF_CHECK(props->major >= 9, "Requires Hopper+."); int64_t m = a.size(0), n = b.size(1), k = a.size(1); - at::Tensor out = at::empty( - {m, n}, - at::TensorOptions() - .dtype(a.scalar_type()) - .device(a.device())); cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); #if defined(NVFUSER_ENABLE_CUTLASS) @@ -518,7 +505,6 @@ at::Tensor matmulTma( #else NVF_THROW("CUTLASS support required."); #endif - return out; } bool canRunCutlassCompute( From d61aed56ea164cfe7b1e6f3b9612579c7b6407e3 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Feb 2026 08:02:32 -0800 Subject: [PATCH 22/23] renaming and clarification --- .../test_multidevice_fused_remote_matmul.cpp | 16 +++++++------- .../test_multidevice_fused_remote_matmul.h | 11 +++++----- ..._multidevice_fused_remote_matmul_kernel.cu | 21 ++++++++++++------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index e156cfdbd3d..c505fc35aeb 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -147,19 +147,19 @@ double benchmarkLoopMs( bool needsThreadloadRes(DistributedMatmulImpl impl) { using I = DistributedMatmulImpl; return impl == I::threadloadGatherScalarCompute || - impl == I::threadloadGatherCutlassCompute; + impl == I::threadloadGatherThenCutlass; } bool needsMultimemRes(DistributedMatmulImpl impl) { using I = DistributedMatmulImpl; return impl == I::multimemGatherScalarCompute || - impl == I::multimemGatherCutlassCompute; + impl == I::multimemGatherThenCutlass; } bool needsCutlass(DistributedMatmulImpl impl) { using I = DistributedMatmulImpl; - return impl == I::threadloadGatherCutlassCompute || - impl == I::multimemGatherCutlassCompute; + return impl == I::threadloadGatherThenCutlass || + impl == I::multimemGatherThenCutlass; } struct OwnedResources { @@ -344,7 +344,7 @@ double runImplementation( ++epoch; }); } - case I::threadloadGatherCutlassCompute: { + case I::threadloadGatherThenCutlass: { int64_t epoch = 0; return batchedKernelTimeMs( wu, it, ctx.stream, [&]() { @@ -365,7 +365,7 @@ double runImplementation( ++epoch; }); } - case I::multimemGatherCutlassCompute: { + case I::multimemGatherThenCutlass: { int64_t epoch = 0; return batchedKernelTimeMs( wu, it, ctx.stream, [&]() { @@ -524,8 +524,8 @@ INSTANTIATE_TEST_SUITE_P( DistributedMatmulImpl::naiveRemoteRead, DistributedMatmulImpl::threadloadGatherScalarCompute, DistributedMatmulImpl::multimemGatherScalarCompute, - DistributedMatmulImpl::threadloadGatherCutlassCompute, - DistributedMatmulImpl::multimemGatherCutlassCompute), + DistributedMatmulImpl::threadloadGatherThenCutlass, + DistributedMatmulImpl::multimemGatherThenCutlass), [](const testing::TestParamInfo< DistributedMatmulImpl>& info) { return implName(info.param); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.h b/tests/cpp/test_multidevice_fused_remote_matmul.h index 847c4cbd62e..4b4fd392465 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.h +++ b/tests/cpp/test_multidevice_fused_remote_matmul.h @@ -39,8 +39,8 @@ class SymMemForAllgather; // naiveRemoteRead | 2.67 // threadloadGatherScalarCompute | 4.05 // multimemGatherScalarCompute | 3.69 -// threadloadGatherCutlassCompute | 46.5 -// multimemGatherCutlassCompute | 45.6 +// threadloadGatherThenCutlass | 50.3 +// multimemGatherThenCutlass | 50.1 // ========================================================================= enum class DistributedMatmulImpl { @@ -51,9 +51,10 @@ enum class DistributedMatmulImpl { naiveRemoteRead, threadloadGatherScalarCompute, multimemGatherScalarCompute, - // -- Two-phase: handwritten comm kernel + CUTLASS TMA matmul -- - threadloadGatherCutlassCompute, - multimemGatherCutlassCompute, + // -- Two-kernel: separate comm kernel then CUTLASS TMA matmul + // (NOT truly fused -- two kernel launches on the same stream) -- + threadloadGatherThenCutlass, + multimemGatherThenCutlass, }; enum class TimeMeasurementMode { CudaEvents, CpuClock }; diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 16c37c3962c..6a7e172084c 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -22,10 +22,17 @@ // using multimem.st (Hopper SM90+), synchronized via semaphores. // // Compute strategies: -// - Scalar: each thread accumulates one output element. -// - CUTLASS TMA: host-launched Hopper GEMM (in the .cpp file, not -// here -- the kernel is launched with n=0 to skip in-kernel compute). +// - Scalar: each thread accumulates one output element in-kernel. +// - CUTLASS TMA (two-kernel path): the gather kernel is launched +// with n=0 to skip in-kernel compute; a separate host-launched +// CUTLASS SM90 GEMM then consumes the staged A buffer. These +// are NOT truly fused -- the communication and compute are two +// distinct kernel launches on the same stream. True single-kernel +// fusion with Hopper-native WGMMA would require embedding CUTE +// MMA atoms and TMA pipelines directly inside the comm kernel. // +// The matmulTma() wrapper at the bottom of this file provides the +// CUTLASS GEMM used by the two-kernel path. // ========================================================================= #include "test_multidevice_fused_remote_matmul.h" @@ -533,10 +540,10 @@ const char* implName(DistributedMatmulImpl impl) { return "threadloadGatherScalarCompute"; case DistributedMatmulImpl::multimemGatherScalarCompute: return "multimemGatherScalarCompute"; - case DistributedMatmulImpl::threadloadGatherCutlassCompute: - return "threadloadGatherCutlassCompute"; - case DistributedMatmulImpl::multimemGatherCutlassCompute: - return "multimemGatherCutlassCompute"; + case DistributedMatmulImpl::threadloadGatherThenCutlass: + return "threadloadGatherThenCutlass"; + case DistributedMatmulImpl::multimemGatherThenCutlass: + return "multimemGatherThenCutlass"; } return "unknown"; } From 4b08cd09fcc60b6a38bc3111e28ac788ad20ce50 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Feb 2026 08:33:13 -0800 Subject: [PATCH 23/23] lint --- .../test_multidevice_fused_remote_matmul.cpp | 278 +++++-------- .../test_multidevice_fused_remote_matmul.h | 19 +- ..._multidevice_fused_remote_matmul_kernel.cu | 370 ++++++++++-------- 3 files changed, 321 insertions(+), 346 deletions(-) diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp index c505fc35aeb..cec16c44c17 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.cpp +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -63,17 +63,14 @@ double batchedKernelTimeMs( cudaEvent_t start, stop; NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaEventRecord(start, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); for (int64_t i = 0; i < iters; ++i) run_once(); NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); float total_ms; - NVFUSER_CUDA_RT_SAFE_CALL( - cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); return static_cast(total_ms) / iters; @@ -103,17 +100,13 @@ double benchmarkLoopMs( for (int64_t i = 0; i < config.iters; ++i) { if (config.barrier_at_each_iteration) communicator->barrier(); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaEventRecord(start, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); run_once(); NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaEventRecord(stop, stream)); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaEventSynchronize(stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); float ms; - NVFUSER_CUDA_RT_SAFE_CALL( - cudaEventElapsedTime(&ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&ms, start, stop)); total_ms += ms; } NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); @@ -125,17 +118,13 @@ double benchmarkLoopMs( for (int64_t i = 0; i < config.iters; ++i) { if (config.barrier_at_each_iteration) communicator->barrier(); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaStreamSynchronize(stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); auto t0 = std::chrono::high_resolution_clock::now(); run_once(); NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaStreamSynchronize(stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); auto t1 = std::chrono::high_resolution_clock::now(); - total_ms += - std::chrono::duration(t1 - t0) - .count(); + total_ms += std::chrono::duration(t1 - t0).count(); } return total_ms / config.iters; } @@ -177,7 +166,9 @@ void initResources( DistributedMatmulImpl impl, Communicator* comm, const Team& team, - int64_t ws, int64_t m, int64_t k, + int64_t ws, + int64_t m, + int64_t k, OwnedResources& res, DistributedMatmulContext& ctx) { using I = DistributedMatmulImpl; @@ -185,13 +176,12 @@ void initResources( if (impl == I::baselineNcclAllgatherMatmul) { if (comm->isBackendAvailable(CommunicatorBackend::kNccl)) - ctx.nccl_backend = comm->getBackendForTeam( - team, CommunicatorBackend::kNccl); + ctx.nccl_backend = + comm->getBackendForTeam(team, CommunicatorBackend::kNccl); } if (impl == I::baselineCudaAllgatherMatmul) { - res.cuda_hic = - std::make_unique(); + res.cuda_hic = std::make_unique(); FusionGuard fg(res.cuda_hic.get()); auto* itv = makeContigTensor(2); auto* otv = makeContigTensor(2); @@ -199,28 +189,29 @@ void initResources( itv->setDeviceMesh(mesh); otv->setDeviceMesh(mesh); auto* cir = IrBuilder::create( - CommunicationType::Allgather, otv, itv, - team, -1, RedOpType::UNUSED, + CommunicationType::Allgather, + otv, + itv, + team, + -1, + RedOpType::UNUSED, CommunicatorBackend::kCuda); - ctx.a_allgathered_cuda = SymmetricTensor::allocate( - {m, k}, at::ScalarType::Half, dev); + ctx.a_allgathered_cuda = + SymmetricTensor::allocate({m, k}, at::ScalarType::Half, dev); res.cuda_ag_handle = - std::make_unique( - cir, ctx.a_allgathered_cuda); + std::make_unique(cir, ctx.a_allgathered_cuda); ctx.cuda_comm = cir; ctx.cuda_handle = res.cuda_ag_handle.get(); } if (needsThreadloadRes(impl)) { - ctx.a_gathered = at::empty( - {m, k}, - at::TensorOptions().dtype(at::kHalf).device(dev)); + ctx.a_gathered = + at::empty({m, k}, at::TensorOptions().dtype(at::kHalf).device(dev)); auto make_sem = [&](const char* tag) - -> std::pair> { - at::Tensor t = SymmetricTensor::allocate( - {ws, m, 4}, at::ScalarType::Int, dev); + -> std::pair> { + at::Tensor t = + SymmetricTensor::allocate({ws, m, 4}, at::ScalarType::Int, dev); t.zero_(); auto s = std::make_unique(t); s->setupRemoteHandles(tag); @@ -231,57 +222,46 @@ void initResources( res.ready_t = rt; res.ready_sym = std::move(rs); ctx.ready_sem_remote = - reinterpret_cast( - res.ready_sym->devicePeerPointers()); - ctx.ready_sem_local = reinterpret_cast( - res.ready_sym->localTensor().data_ptr()); + reinterpret_cast(res.ready_sym->devicePeerPointers()); + ctx.ready_sem_local = + reinterpret_cast(res.ready_sym->localTensor().data_ptr()); auto [dt, ds] = make_sem("fused_matmul_done"); res.done_t = dt; res.done_sym = std::move(ds); ctx.done_sem_remote = - reinterpret_cast( - res.done_sym->devicePeerPointers()); - ctx.done_sem_local = reinterpret_cast( - res.done_sym->localTensor().data_ptr()); + reinterpret_cast(res.done_sym->devicePeerPointers()); + ctx.done_sem_local = + reinterpret_cast(res.done_sym->localTensor().data_ptr()); } if (needsMultimemRes(impl)) { - ctx.a_gathered_multimem = SymmetricTensor::allocate( - {m, k}, at::ScalarType::Half, dev); - res.multimem_sym = std::make_unique( - ctx.a_gathered_multimem); - res.multimem_sym->setupMulticast( - 0, "fused_matmul_mc"); - ctx.multicast_ptr = reinterpret_cast<__half*>( - res.multimem_sym->multicastPtr()); - - res.stage_t = SymmetricTensor::allocate( - {ws, m, 4}, at::ScalarType::Int, dev); + ctx.a_gathered_multimem = + SymmetricTensor::allocate({m, k}, at::ScalarType::Half, dev); + res.multimem_sym = + std::make_unique(ctx.a_gathered_multimem); + res.multimem_sym->setupMulticast(0, "fused_matmul_mc"); + ctx.multicast_ptr = + reinterpret_cast<__half*>(res.multimem_sym->multicastPtr()); + + res.stage_t = + SymmetricTensor::allocate({ws, m, 4}, at::ScalarType::Int, dev); res.stage_t.zero_(); - res.stage_sym = - std::make_unique(res.stage_t); - res.stage_sym->setupRemoteHandles( - "fused_matmul_stage"); + res.stage_sym = std::make_unique(res.stage_t); + res.stage_sym->setupRemoteHandles("fused_matmul_stage"); ctx.stage_sem_remote = - reinterpret_cast( - res.stage_sym->devicePeerPointers()); - ctx.stage_sem_local = reinterpret_cast( - res.stage_sym->localTensor().data_ptr()); + reinterpret_cast(res.stage_sym->devicePeerPointers()); + ctx.stage_sem_local = + reinterpret_cast(res.stage_sym->localTensor().data_ptr()); } } -double reduceMaxTimeMs( - Communicator* comm, double local_ms) { +double reduceMaxTimeMs(Communicator* comm, double local_ms) { at::Tensor t = at::tensor( {static_cast(local_ms)}, - at::TensorOptions() - .dtype(at::kFloat) - .device(comm->device())); + at::TensorOptions().dtype(at::kFloat).device(comm->device())); std::vector tv = {t}; - comm->getWorld() - ->allreduce(tv, {c10d::ReduceOp::MAX}) - ->wait(); + comm->getWorld()->allreduce(tv, {c10d::ReduceOp::MAX})->wait(); return static_cast(t.item()); } @@ -301,82 +281,60 @@ double runImplementation( const int64_t it = config.iters; switch (impl) { case I::baselineNcclAllgatherMatmul: { - at::Tensor a_full = at::empty( - {ctx.m, ctx.k}, ctx.a_local_half.options()); + at::Tensor a_full = at::empty({ctx.m, ctx.k}, ctx.a_local_half.options()); auto run = [&]() { - ctx.nccl_backend - ->_allgather_base(a_full, ctx.a_local_half) - ->wait(); - at::matmul_out( - ctx.c_out_half, a_full, ctx.b_full_half); + ctx.nccl_backend->_allgather_base(a_full, ctx.a_local_half)->wait(); + at::matmul_out(ctx.c_out_half, a_full, ctx.b_full_half); }; - return benchmarkLoopMs( - config, ctx.communicator, ctx.stream, run); + return benchmarkLoopMs(config, ctx.communicator, ctx.stream, run); } case I::baselineCudaAllgatherMatmul: { auto run = [&]() { postWithCudaBackend( - ctx.cuda_comm, ctx.a_local_half, + ctx.cuda_comm, + ctx.a_local_half, ctx.cuda_handle, - (CUstream)ctx.stream, -1); + (CUstream)ctx.stream, + -1); waitWithCudaBackend( - ctx.cuda_comm, ctx.cuda_handle, - (CUstream)ctx.stream, -1); - at::matmul_out( - ctx.c_out_half, ctx.a_allgathered_cuda, - ctx.b_full_half); + ctx.cuda_comm, ctx.cuda_handle, (CUstream)ctx.stream, -1); + at::matmul_out(ctx.c_out_half, ctx.a_allgathered_cuda, ctx.b_full_half); }; - return benchmarkLoopMs( - config, ctx.communicator, ctx.stream, run); + return benchmarkLoopMs(config, ctx.communicator, ctx.stream, run); } case I::naiveRemoteRead: { return batchedKernelTimeMs( - wu, it, ctx.stream, [&]() { - launchNaiveRemoteRead(ctx); - }); + wu, it, ctx.stream, [&]() { launchNaiveRemoteRead(ctx); }); } case I::threadloadGatherScalarCompute: { int64_t epoch = 0; - return batchedKernelTimeMs( - wu, it, ctx.stream, [&]() { - launchThreadloadGather( - ctx, static_cast(epoch), true); - ++epoch; - }); + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchThreadloadGather(ctx, static_cast(epoch), true); + ++epoch; + }); } case I::threadloadGatherThenCutlass: { int64_t epoch = 0; - return batchedKernelTimeMs( - wu, it, ctx.stream, [&]() { - launchThreadloadGather( - ctx, static_cast(epoch), false); - matmulTma( - ctx.c_out_half, - ctx.a_gathered, ctx.b_full_half); - ++epoch; - }); + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchThreadloadGather(ctx, static_cast(epoch), false); + matmulTma(ctx.c_out_half, ctx.a_gathered, ctx.b_full_half); + ++epoch; + }); } case I::multimemGatherScalarCompute: { int64_t epoch = 0; - return batchedKernelTimeMs( - wu, it, ctx.stream, [&]() { - launchMultimemGather( - ctx, static_cast(epoch), true); - ++epoch; - }); + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchMultimemGather(ctx, static_cast(epoch), true); + ++epoch; + }); } case I::multimemGatherThenCutlass: { int64_t epoch = 0; - return batchedKernelTimeMs( - wu, it, ctx.stream, [&]() { - launchMultimemGather( - ctx, static_cast(epoch), false); - matmulTma( - ctx.c_out_half, - ctx.a_gathered_multimem, - ctx.b_full_half); - ++epoch; - }); + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchMultimemGather(ctx, static_cast(epoch), false); + matmulTma(ctx.c_out_half, ctx.a_gathered_multimem, ctx.b_full_half); + ++epoch; + }); } } NVF_ERROR(false, "Unknown implementation."); @@ -390,8 +348,7 @@ double runImplementation( class FusedRemoteMatmulTest : public MultiDeviceTest, - public testing::WithParamInterface< - DistributedMatmulImpl> { + public testing::WithParamInterface { protected: static constexpr BenchmarkConfig kConfig = { /*warmup_iters=*/8, @@ -410,8 +367,7 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmul) { const int64_t rank = communicator_->deviceId(); const auto impl = GetParam(); - if (needsMultimemRes(impl) && - !isMulticastSupported(rank)) + if (needsMultimemRes(impl) && !isMulticastSupported(rank)) GTEST_SKIP() << "Multicast unsupported."; // ---- Problem shape ---- @@ -424,19 +380,16 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmul) { // ---- Inputs ---- at::manual_seed(0); auto cpu_f = at::TensorOptions().dtype(at::kFloat); - auto gpu_h = at::TensorOptions() - .dtype(at::kHalf) - .device(communicator_->device()); + auto gpu_h = + at::TensorOptions().dtype(at::kHalf).device(communicator_->device()); at::Tensor a_full = at::randn({m, k}, cpu_f); at::Tensor b_full = at::randn({k, n}, cpu_f); - at::Tensor a_local = - a_full.slice(0, rank * mpr, (rank + 1) * mpr) - .to(gpu_h.device(), at::kHalf); + at::Tensor a_local = a_full.slice(0, rank * mpr, (rank + 1) * mpr) + .to(gpu_h.device(), at::kHalf); at::Tensor b_gpu = b_full.to(gpu_h.device(), at::kHalf); at::Tensor a_sym = SymmetricTensor::allocate( - {mpr, k}, at::ScalarType::Half, - communicator_->device()); + {mpr, k}, at::ScalarType::Half, communicator_->device()); a_sym.copy_(a_local); OwnedResources res; res.a_sym = std::make_unique(a_sym); @@ -451,67 +404,51 @@ TEST_P(FusedRemoteMatmulTest, DistributedMatmul) { ctx.my_rank = rank; ctx.world_size = ws; ctx.device_remote_ptrs = - reinterpret_cast( - res.a_sym->devicePeerPointers()); + reinterpret_cast(res.a_sym->devicePeerPointers()); ctx.a_local_half = a_local; ctx.b_full_half = b_gpu; ctx.c_out_half = at::zeros({m, n}, gpu_h); ctx.communicator = communicator_; - c10::cuda::CUDAStream test_stream = - c10::cuda::getStreamFromPool( - false, - static_cast( - communicator_->device().index())); + c10::cuda::CUDAStream test_stream = c10::cuda::getStreamFromPool( + false, static_cast(communicator_->device().index())); c10::cuda::CUDAStreamGuard guard(test_stream); ctx.stream = test_stream.stream(); - initResources( - impl, communicator_, team, ws, m, k, res, ctx); + initResources(impl, communicator_, team, ws, m, k, res, ctx); // ---- Capability gates ---- - if (impl == - DistributedMatmulImpl:: - baselineNcclAllgatherMatmul && + if (impl == DistributedMatmulImpl::baselineNcclAllgatherMatmul && ctx.nccl_backend == nullptr) GTEST_SKIP() << "NCCL backend unavailable."; if (needsCutlass(impl)) { - at::Tensor ref = ctx.a_gathered.defined() - ? ctx.a_gathered - : ctx.a_gathered_multimem; + at::Tensor ref = + ctx.a_gathered.defined() ? ctx.a_gathered : ctx.a_gathered_multimem; if (!canRunCutlassCompute(ref, b_gpu)) GTEST_SKIP() << "CUTLASS needs Hopper SM90."; } // ---- Correctness (1 iteration, no warmup) ---- (void)runImplementation( - impl, ctx, - {0, 1, TimeMeasurementMode::CpuClock, false}); + impl, ctx, {0, 1, TimeMeasurementMode::CpuClock, false}); at::Tensor c_ref = at::matmul(a_full, b_full); - EXPECT_TRUE( - ctx.c_out_half.cpu().to(at::kFloat).allclose( - c_ref, 2e-1, 2e-1)) + EXPECT_TRUE(ctx.c_out_half.cpu().to(at::kFloat).allclose(c_ref, 2e-1, 2e-1)) << "Mismatch for " << implName(impl); // ---- Benchmark ---- communicator_->barrier(); - double local_ms = - runImplementation(impl, ctx, kConfig); + double local_ms = runImplementation(impl, ctx, kConfig); communicator_->barrier(); - double global_ms = - reduceMaxTimeMs(communicator_, local_ms); + double global_ms = reduceMaxTimeMs(communicator_, local_ms); // ---- Report ---- - double tflops = - 2.0 * m * n * k / (global_ms * 1e9); + double tflops = 2.0 * m * n * k / (global_ms * 1e9); if (rank == 0) { std::cout << "[perf] fused_remote_matmul" - << " impl=" << implName(impl) - << " M=" << m << " N=" << n << " K=" << k - << " world_size=" << ws << " : " - << global_ms << " ms/iter, " << tflops - << " TFLOP/s" << std::endl; + << " impl=" << implName(impl) << " M=" << m << " N=" << n + << " K=" << k << " world_size=" << ws << " : " << global_ms + << " ms/iter, " << tflops << " TFLOP/s" << std::endl; } } @@ -526,8 +463,7 @@ INSTANTIATE_TEST_SUITE_P( DistributedMatmulImpl::multimemGatherScalarCompute, DistributedMatmulImpl::threadloadGatherThenCutlass, DistributedMatmulImpl::multimemGatherThenCutlass), - [](const testing::TestParamInfo< - DistributedMatmulImpl>& info) { + [](const testing::TestParamInfo& info) { return implName(info.param); }); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.h b/tests/cpp/test_multidevice_fused_remote_matmul.h index 4b4fd392465..81d58661dcd 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul.h +++ b/tests/cpp/test_multidevice_fused_remote_matmul.h @@ -77,13 +77,13 @@ struct DistributedMatmulContext { const __half* const* device_remote_ptrs = nullptr; // Input / output tensors - at::Tensor a_local_half; // [m_per_rank, k] - at::Tensor b_full_half; // [k, n] - at::Tensor c_out_half; // [m, n] + at::Tensor a_local_half; // [m_per_rank, k] + at::Tensor b_full_half; // [k, n] + at::Tensor c_out_half; // [m, n] // Staging buffers for gather-then-compute paths - at::Tensor a_gathered; // [m, k] threadload staging - at::Tensor a_gathered_multimem; // [m, k] multicast-backed + at::Tensor a_gathered; // [m, k] threadload staging + at::Tensor a_gathered_multimem; // [m, k] multicast-backed __half* multicast_ptr = nullptr; // Threadload semaphores (ready / done handshake) @@ -117,13 +117,8 @@ void launchMultimemGather( DistributedMatmulContext& ctx, int32_t epoch, bool compute); -void matmulTma( - at::Tensor& out, - const at::Tensor& a, - const at::Tensor& b); -bool canRunCutlassCompute( - const at::Tensor& a, - const at::Tensor& b); +void matmulTma(at::Tensor& out, const at::Tensor& a, const at::Tensor& b); +bool canRunCutlassCompute(const at::Tensor& a, const at::Tensor& b); const char* implName(DistributedMatmulImpl impl); bool isMulticastSupported(int64_t device_id); diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu index 6a7e172084c..5450a1cd74c 100644 --- a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -72,27 +72,31 @@ namespace { // is self-contained. // ========================================================================= -bool hasValidTmaShape( - const at::Tensor& a, - const at::Tensor& b) { - if (!a.defined() || !b.defined()) return false; - if (!a.is_cuda() || !b.is_cuda()) return false; - if (a.dim() != 2 || b.dim() != 2) return false; - if (a.scalar_type() != b.scalar_type()) return false; +bool hasValidTmaShape(const at::Tensor& a, const at::Tensor& b) { + if (!a.defined() || !b.defined()) + return false; + if (!a.is_cuda() || !b.is_cuda()) + return false; + if (a.dim() != 2 || b.dim() != 2) + return false; + if (a.scalar_type() != b.scalar_type()) + return false; if (!(a.scalar_type() == at::ScalarType::Half || a.scalar_type() == at::ScalarType::BFloat16)) return false; - if (!a.is_contiguous() || !b.is_contiguous()) return false; - if (a.size(1) != b.size(0)) return false; - if (a.get_device() != b.get_device()) return false; + if (!a.is_contiguous() || !b.is_contiguous()) + return false; + if (a.size(1) != b.size(0)) + return false; + if (a.get_device() != b.get_device()) + return false; constexpr int64_t kAlign = 8; if (a.size(1) % kAlign != 0 || b.size(1) % kAlign != 0) return false; return true; } -#if defined(NVFUSER_ENABLE_CUTLASS) && \ - defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(NVFUSER_ENABLE_CUTLASS) && defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) using namespace cute; template @@ -105,14 +109,10 @@ struct TmaSm90Config { using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - static constexpr int kAA = - 128 / cutlass::sizeof_bits::value; - static constexpr int kAB = - 128 / cutlass::sizeof_bits::value; - static constexpr int kAC = - 128 / cutlass::sizeof_bits::value; - static constexpr int kAD = - 128 / cutlass::sizeof_bits::value; + static constexpr int kAA = 128 / cutlass::sizeof_bits::value; + static constexpr int kAB = 128 / cutlass::sizeof_bits::value; + static constexpr int kAC = 128 / cutlass::sizeof_bits::value; + static constexpr int kAD = 128 / cutlass::sizeof_bits::value; using Acc = float; using Arch = cutlass::arch::Sm90; using Op = cutlass::arch::OpClassTensorOp; @@ -120,28 +120,41 @@ struct TmaSm90Config { using Cluster = Shape<_1, _1, _1>; using SmTile = Shape<_128, _128, _64>; - using Epi = - typename cutlass::epilogue::collective::CollectiveBuilder< - Arch, Op, SmTile, Cluster, - cutlass::epilogue::collective::EpilogueTileAuto, - Acc, Acc, EC, LayoutC, kAC, ED, LayoutD, kAD, - cutlass::epilogue::collective:: - EpilogueScheduleAuto>::CollectiveOp; - - using Main = - typename cutlass::gemm::collective::CollectiveBuilder< - Arch, Op, EA, LayoutA, kAA, EB, LayoutB, kAB, - Acc, Tile, Cluster, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast( - sizeof(typename Epi::SharedStorage))>, - cutlass::gemm::collective:: - KernelScheduleAuto>::CollectiveOp; - - using Kernel = cutlass::gemm::kernel::GemmUniversal< - Shape, Main, Epi, void>; - using Gemm = - cutlass::gemm::device::GemmUniversalAdapter; + using Epi = typename cutlass::epilogue::collective::CollectiveBuilder< + Arch, + Op, + SmTile, + Cluster, + cutlass::epilogue::collective::EpilogueTileAuto, + Acc, + Acc, + EC, + LayoutC, + kAC, + ED, + LayoutD, + kAD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using Main = typename cutlass::gemm::collective::CollectiveBuilder< + Arch, + Op, + EA, + LayoutA, + kAA, + EB, + LayoutB, + kAB, + Acc, + Tile, + Cluster, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename Epi::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using Kernel = cutlass::gemm::kernel:: + GemmUniversal, Main, Epi, void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using SA = typename Gemm::GemmKernel::StrideA; using SB = typename Gemm::GemmKernel::StrideB; using SC = typename Gemm::GemmKernel::StrideC; @@ -153,36 +166,36 @@ void runGemmSm90( at::Tensor& out, const at::Tensor& a, const at::Tensor& b, - int64_t m, int64_t n, int64_t k, + int64_t m, + int64_t n, + int64_t k, cudaStream_t stream) { using C = TmaSm90Config; - auto sa = cutlass::make_cute_packed_stride( - typename C::SA{}, {(int)m, (int)k, 1}); - auto sb = cutlass::make_cute_packed_stride( - typename C::SB{}, {(int)k, (int)n, 1}); - auto sc = cutlass::make_cute_packed_stride( - typename C::SC{}, {(int)m, (int)n, 1}); - auto sd = cutlass::make_cute_packed_stride( - typename C::SD{}, {(int)m, (int)n, 1}); + auto sa = + cutlass::make_cute_packed_stride(typename C::SA{}, {(int)m, (int)k, 1}); + auto sb = + cutlass::make_cute_packed_stride(typename C::SB{}, {(int)k, (int)n, 1}); + auto sc = + cutlass::make_cute_packed_stride(typename C::SC{}, {(int)m, (int)n, 1}); + auto sd = + cutlass::make_cute_packed_stride(typename C::SD{}, {(int)m, (int)n, 1}); typename C::Kernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGemm, {(int)m, (int)n, (int)k, 1}, - {static_cast(a.data_ptr()), sa, - static_cast(b.data_ptr()), sb}, - {{}, nullptr, sc, - static_cast(out.data_ptr()), sd}}; + {static_cast(a.data_ptr()), + sa, + static_cast(b.data_ptr()), + sb}, + {{}, nullptr, sc, static_cast(out.data_ptr()), sd}}; typename C::Gemm gemm; size_t ws = C::Gemm::get_workspace_size(args); auto wt = at::empty( - {(int64_t)ws}, - at::TensorOptions().dtype(at::kByte).device( - a.device())); + {(int64_t)ws}, at::TensorOptions().dtype(at::kByte).device(a.device())); NVF_CHECK( gemm.can_implement(args) == cutlass::Status::kSuccess, "CUTLASS cannot implement this GEMM."); NVF_CHECK( - gemm.initialize(args, wt.data_ptr(), stream) == - cutlass::Status::kSuccess, + gemm.initialize(args, wt.data_ptr(), stream) == cutlass::Status::kSuccess, "CUTLASS init failed."); NVF_CHECK( gemm.run(args, wt.data_ptr(), stream, nullptr, true) == @@ -194,8 +207,13 @@ void runGemmSm90( template void runGemmSm90( - at::Tensor&, const at::Tensor&, const at::Tensor&, - int64_t, int64_t, int64_t, cudaStream_t) { + at::Tensor&, + const at::Tensor&, + const at::Tensor&, + int64_t, + int64_t, + int64_t, + cudaStream_t) { NVF_THROW("CUTLASS SM90 support required for TMA matmul."); } @@ -212,54 +230,74 @@ constexpr int64_t kVecW = 4; constexpr int64_t kMaxPoll = 1LL << 26; __device__ inline void publishToAll( - int32_t* const* remote, int32_t* local, - int64_t writer, int64_t row, int64_t m, - int64_t ws, int32_t epoch) { + int32_t* const* remote, + int32_t* local, + int64_t writer, + int64_t row, + int64_t m, + int64_t ws, + int32_t epoch) { int32_t* my = local + (writer * m + row) * kVecW; - for (int64_t i = 0; i < kVecW; ++i) my[i] = epoch; + for (int64_t i = 0; i < kVecW; ++i) + my[i] = epoch; __threadfence_system(); for (int64_t p = 0; p < ws; ++p) { int32_t* d = remote[p] + (writer * m + row) * kVecW; - for (int64_t i = 0; i < kVecW; ++i) d[i] = epoch; + for (int64_t i = 0; i < kVecW; ++i) + d[i] = epoch; } __threadfence_system(); } __device__ inline void publishToOne( - int32_t* target, int64_t writer, - int64_t row, int64_t m, int32_t epoch) { + int32_t* target, + int64_t writer, + int64_t row, + int64_t m, + int32_t epoch) { int32_t* d = target + (writer * m + row) * kVecW; - for (int64_t i = 0; i < kVecW; ++i) d[i] = epoch; + for (int64_t i = 0; i < kVecW; ++i) + d[i] = epoch; __threadfence_system(); } __device__ inline void setLocal( - int32_t* local, int64_t writer, - int64_t row, int64_t m, int32_t epoch) { + int32_t* local, + int64_t writer, + int64_t row, + int64_t m, + int32_t epoch) { int32_t* d = local + (writer * m + row) * kVecW; - for (int64_t i = 0; i < kVecW; ++i) d[i] = epoch; + for (int64_t i = 0; i < kVecW; ++i) + d[i] = epoch; __threadfence_system(); } __device__ inline void waitOne( - int32_t* local, int64_t row, int64_t m, - int64_t writer, int32_t epoch) { - auto* p = reinterpret_cast( - local + (writer * m + row) * kVecW); + int32_t* local, + int64_t row, + int64_t m, + int64_t writer, + int32_t epoch) { + auto* p = reinterpret_cast(local + (writer * m + row) * kVecW); int64_t s = 0; while (atomicAdd(p, 0U) < (unsigned)epoch) - if (++s > kMaxPoll) asm volatile("trap;"); + if (++s > kMaxPoll) + asm volatile("trap;"); } __device__ inline void waitAll( - int32_t* local, int64_t row, int64_t m, - int64_t ws, int32_t epoch) { + int32_t* local, + int64_t row, + int64_t m, + int64_t ws, + int32_t epoch) { for (int64_t r = 0; r < ws; ++r) { - auto* p = reinterpret_cast( - local + (r * m + row) * kVecW); + auto* p = reinterpret_cast(local + (r * m + row) * kVecW); int64_t s = 0; while (atomicAdd(p, 0U) < (unsigned)epoch) - if (++s > kMaxPoll) asm volatile("trap;"); + if (++s > kMaxPoll) + asm volatile("trap;"); } } @@ -272,19 +310,22 @@ __device__ inline void waitAll( // owner rank's shard via remote pointers -- no staging, no gather. __global__ void naiveRemoteReadKernel( const __half* const* a_shards, - const __half* b, __half* c, - int64_t m, int64_t n, int64_t k, + const __half* b, + __half* c, + int64_t m, + int64_t n, + int64_t k, int64_t m_per_rank) { int64_t row = blockIdx.y * blockDim.y + threadIdx.y; int64_t col = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= m || col >= n) return; + if (row >= m || col >= n) + return; int64_t owner = row / m_per_rank; int64_t lr = row - owner * m_per_rank; const __half* a = a_shards[owner]; float acc = 0.f; for (int64_t kk = 0; kk < k; ++kk) - acc += __half2float(a[lr * k + kk]) * - __half2float(b[kk * n + col]); + acc += __half2float(a[lr * k + kk]) * __half2float(b[kk * n + col]); c[row * n + col] = __float2half(acc); } @@ -299,37 +340,40 @@ __global__ void naiveRemoteReadKernel( __global__ void threadloadGatherKernel( const __half* const* a_shards, __half* a_gathered, - int32_t* const* ready_r, int32_t* ready_l, - int32_t* const* done_r, int32_t* done_l, - int64_t rank, int64_t ws, int32_t epoch_base, - int64_t m, int64_t n, int64_t k, + int32_t* const* ready_r, + int32_t* ready_l, + int32_t* const* done_r, + int32_t* done_l, + int64_t rank, + int64_t ws, + int32_t epoch_base, + int64_t m, + int64_t n, + int64_t k, int64_t m_per_rank, - const __half* b, __half* c) { + const __half* b, + __half* c) { const int32_t epoch = epoch_base + 1; - for (int64_t row = blockIdx.x; row < m; - row += gridDim.x) { + for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { int64_t owner = row / m_per_rank; int64_t lr = row - owner * m_per_rank; const __half* a = a_shards[owner]; // --- Semaphore: owner signals readiness --- if (threadIdx.x == 0 && rank == owner) - publishToAll( - ready_r, ready_l, rank, row, m, ws, epoch); + publishToAll(ready_r, ready_l, rank, row, m, ws, epoch); __syncthreads(); if (threadIdx.x == 0 && rank != owner) waitOne(ready_l, row, m, owner, epoch); __syncthreads(); // --- Stage 1: P2P gather via thread loads --- - for (int64_t kk = threadIdx.x; kk < k; - kk += blockDim.x) + for (int64_t kk = threadIdx.x; kk < k; kk += blockDim.x) a_gathered[row * k + kk] = a[lr * k + kk]; __syncthreads(); // --- Stage 2: scalar matmul (skip when n==0) --- - for (int64_t col = threadIdx.x; col < n; - col += blockDim.x) { + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { float acc = 0.f; for (int64_t kk = 0; kk < k; ++kk) acc += __half2float(a_gathered[row * k + kk]) * @@ -361,13 +405,18 @@ __global__ void threadloadGatherKernel( __global__ void multimemGatherKernel( const __half* const* a_shards, __half* a_mc, - int32_t* const* sem_r, int32_t* sem_l, - int64_t rank, int64_t ws, int32_t epoch_base, - int64_t m, int64_t n, int64_t k, + int32_t* const* sem_r, + int32_t* sem_l, + int64_t rank, + int64_t ws, + int32_t epoch_base, + int64_t m, + int64_t n, + int64_t k, int64_t m_per_rank, - const __half* b, __half* c) { - for (int64_t row = blockIdx.x; row < m; - row += gridDim.x) { + const __half* b, + __half* c) { + for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { int64_t owner = row / m_per_rank; int64_t lr = row - owner * m_per_rank; const __half* a = a_shards[owner]; @@ -377,10 +426,8 @@ __global__ void multimemGatherKernel( constexpr int64_t kVec = 8; int64_t nvec = k / kVec; if (rank == owner) { - for (int64_t vi = threadIdx.x; vi < nvec; - vi += blockDim.x) { - uint4 val = reinterpret_cast( - a + lr * k)[vi]; + for (int64_t vi = threadIdx.x; vi < nvec; vi += blockDim.x) { + uint4 val = reinterpret_cast(a + lr * k)[vi]; #if __CUDA_ARCH__ >= 900 asm volatile( "multimem.st.global.v4.f32 [%0]," @@ -397,8 +444,7 @@ __global__ void multimemGatherKernel( asm volatile("trap;"); #endif } - for (int64_t kk = nvec * kVec + threadIdx.x; - kk < k; kk += blockDim.x) + for (int64_t kk = nvec * kVec + threadIdx.x; kk < k; kk += blockDim.x) arow[kk] = a[lr * k + kk]; } __syncthreads(); @@ -407,25 +453,25 @@ __global__ void multimemGatherKernel( #if __CUDA_ARCH__ >= 900 const int32_t epoch = epoch_base + 1; if (threadIdx.x == 0 && rank == owner) - publishToAll( - sem_r, sem_l, rank, row, m, ws, epoch); + publishToAll(sem_r, sem_l, rank, row, m, ws, epoch); __syncthreads(); if (threadIdx.x == 0 && rank != owner) waitOne(sem_l, row, m, owner, epoch); __syncthreads(); #else - (void)sem_r; (void)sem_l; - (void)rank; (void)ws; (void)epoch_base; + (void)sem_r; + (void)sem_l; + (void)rank; + (void)ws; + (void)epoch_base; asm volatile("trap;"); #endif // --- Stage 2: scalar matmul (skip when n==0) --- - for (int64_t col = threadIdx.x; col < n; - col += blockDim.x) { + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { float acc = 0.f; for (int64_t kk = 0; kk < k; ++kk) - acc += __half2float(arow[kk]) * - __half2float(b[kk * n + col]); + acc += __half2float(arow[kk]) * __half2float(b[kk * n + col]); c[row * n + col] = __float2half(acc); } __syncthreads(); @@ -444,15 +490,15 @@ __global__ void multimemGatherKernel( void launchNaiveRemoteRead(DistributedMatmulContext& ctx) { constexpr int64_t kB = 16; dim3 block(kB, kB); - dim3 grid( - (ctx.n + kB - 1) / kB, (ctx.m + kB - 1) / kB); + dim3 grid((ctx.n + kB - 1) / kB, (ctx.m + kB - 1) / kB); naiveRemoteReadKernel<<>>( ctx.device_remote_ptrs, - reinterpret_cast( - ctx.b_full_half.data_ptr()), - reinterpret_cast<__half*>( - ctx.c_out_half.data_ptr()), - ctx.m, ctx.n, ctx.k, ctx.m_per_rank); + reinterpret_cast(ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>(ctx.c_out_half.data_ptr()), + ctx.m, + ctx.n, + ctx.k, + ctx.m_per_rank); } void launchThreadloadGather( @@ -463,17 +509,20 @@ void launchThreadloadGather( dim3 grid(ctx.m); threadloadGatherKernel<<>>( ctx.device_remote_ptrs, - reinterpret_cast<__half*>( - ctx.a_gathered.data_ptr()), - ctx.ready_sem_remote, ctx.ready_sem_local, - ctx.done_sem_remote, ctx.done_sem_local, - ctx.my_rank, ctx.world_size, epoch, - ctx.m, compute ? ctx.n : 0, - ctx.k, ctx.m_per_rank, - reinterpret_cast( - ctx.b_full_half.data_ptr()), - reinterpret_cast<__half*>( - ctx.c_out_half.data_ptr())); + reinterpret_cast<__half*>(ctx.a_gathered.data_ptr()), + ctx.ready_sem_remote, + ctx.ready_sem_local, + ctx.done_sem_remote, + ctx.done_sem_local, + ctx.my_rank, + ctx.world_size, + epoch, + ctx.m, + compute ? ctx.n : 0, + ctx.k, + ctx.m_per_rank, + reinterpret_cast(ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>(ctx.c_out_half.data_ptr())); } void launchMultimemGather( @@ -485,45 +534,40 @@ void launchMultimemGather( multimemGatherKernel<<>>( ctx.device_remote_ptrs, ctx.multicast_ptr, - ctx.stage_sem_remote, ctx.stage_sem_local, - ctx.my_rank, ctx.world_size, epoch, - ctx.m, compute ? ctx.n : 0, - ctx.k, ctx.m_per_rank, - reinterpret_cast( - ctx.b_full_half.data_ptr()), - reinterpret_cast<__half*>( - ctx.c_out_half.data_ptr())); + ctx.stage_sem_remote, + ctx.stage_sem_local, + ctx.my_rank, + ctx.world_size, + epoch, + ctx.m, + compute ? ctx.n : 0, + ctx.k, + ctx.m_per_rank, + reinterpret_cast(ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>(ctx.c_out_half.data_ptr())); } -void matmulTma( - at::Tensor& out, - const at::Tensor& a, - const at::Tensor& b) { +void matmulTma(at::Tensor& out, const at::Tensor& a, const at::Tensor& b) { int64_t m = a.size(0), n = b.size(1), k = a.size(1); - cudaStream_t stream = - at::cuda::getCurrentCUDAStream(a.get_device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); #if defined(NVFUSER_ENABLE_CUTLASS) if (a.scalar_type() == at::ScalarType::Half) - runGemmSm90( - out, a, b, m, n, k, stream); + runGemmSm90(out, a, b, m, n, k, stream); else - runGemmSm90( - out, a, b, m, n, k, stream); + runGemmSm90(out, a, b, m, n, k, stream); #else NVF_THROW("CUTLASS support required."); #endif } -bool canRunCutlassCompute( - const at::Tensor& a, - const at::Tensor& b) { - if (!hasValidTmaShape(a, b)) return false; +bool canRunCutlassCompute(const at::Tensor& a, const at::Tensor& b) { + if (!hasValidTmaShape(a, b)) + return false; #if !defined(NVFUSER_ENABLE_CUTLASS) || \ !defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) return false; #else - auto* props = - at::cuda::getDeviceProperties(a.get_device()); + auto* props = at::cuda::getDeviceProperties(a.get_device()); return props->major == 9 && props->minor == 0; #endif }