diff --git a/CMakeLists.txt b/CMakeLists.txt index f5dad5c90d5..4e7accfb4ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -399,19 +399,20 @@ endif() # "private" (not installed) static library. add_library(codegen_internal OBJECT ${NVFUSER_SRCS}) + if(NOT MSVC) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") target_compile_options(codegen_internal PRIVATE - -Wall -Wno-unused-function -Werror + $<$<COMPILE_LANGUAGE:CXX>:-Wall -Wno-unused-function -Werror # These warnings are not treated as errors because of gcc 12.2 used in # manylinux image. consider enable this when we upgrade. # linking comment: # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266 - -Wno-error=restrict -Wno-error=stringop-overflow -Wno-error=maybe-uninitialized) + -Wno-error=restrict -Wno-error=stringop-overflow -Wno-error=maybe-uninitialized>) else() target_compile_options(codegen_internal PRIVATE - -Wall -Wno-unused-function -Werror) + $<$<COMPILE_LANGUAGE:CXX>:-Wall -Wno-unused-function -Werror>) endif() endif() @@ -423,6 +424,9 @@ if (NVMMH_FOUND) endif() target_include_directories(codegen_internal SYSTEM PUBLIC ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include + ${NVFUSER_THIRD_PARTY_DIR}/cutlass/include + ${NVFUSER_THIRD_PARTY_DIR}/cutlass/tools/util/include + /usr/local/cuda/include/cccl PRIVATE ${CUDA_INCLUDE_DIRS} ) @@ -919,7 +923,7 @@ function(add_test_without_main TEST_NAME TEST_SRC ADDITIONAL_LINK) if(NOT MSVC) target_compile_options(${TEST_NAME} PRIVATE - -Wall -Wno-unused-function -Werror + $<$<COMPILE_LANGUAGE:CXX>:-Wall -Wno-unused-function -Werror> ) endif() endfunction() @@ -1019,6 +1023,8 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_fused_remote_matmul.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp 
${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp diff --git a/csrc/multidevice/symmetric_tensor.cpp b/csrc/multidevice/symmetric_tensor.cpp index 902ec5a3a32..a495b7e82dc 100644 --- a/csrc/multidevice/symmetric_tensor.cpp +++ b/csrc/multidevice/symmetric_tensor.cpp @@ -253,6 +253,11 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor) } SymmetricTensor::~SymmetricTensor() { + if (device_peer_ptrs_ != nullptr) { + cudaFree(device_peer_ptrs_); + device_peer_ptrs_ = nullptr; + } + #if (CUDA_VERSION >= 13000) if (is_multicast_setup_) { if (mc_base_ptr_) { @@ -389,6 +394,24 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const { .device(at::kCUDA, rank)); } +void** SymmetricTensor::devicePeerPointers() const { + NVF_CHECK(are_remote_tensors_setup_ == true, "Remote tensors not setup"); + if (device_peer_ptrs_ == nullptr) { + std::vector<void*> host_peer_ptrs(world_size_); + for (int64_t rank = 0; rank < world_size_; ++rank) { + host_peer_ptrs[rank] = reinterpret_cast<void*>(remote_ptrs_[rank]); + } + NVFUSER_CUDA_RT_SAFE_CALL( + cudaMalloc(&device_peer_ptrs_, world_size_ * sizeof(void*))); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + device_peer_ptrs_, + host_peer_ptrs.data(), + world_size_ * sizeof(void*), + cudaMemcpyHostToDevice)); + } + return device_peer_ptrs_; +} + void* SymmetricTensor::multicastPtr() const { NVF_CHECK(is_multicast_setup_, "Multicast not setup"); return mc_ptr_; diff --git a/csrc/multidevice/symmetric_tensor.h b/csrc/multidevice/symmetric_tensor.h index 5608153e0ce..46ac7da1c7c 100644 --- a/csrc/multidevice/symmetric_tensor.h +++ b/csrc/multidevice/symmetric_tensor.h @@ -51,6 +51,8 @@ class SymmetricTensor { // Setup remote access (lazy, init-once) void setupRemoteHandles(const std::string& tag = ""); at::Tensor remoteTensor(int64_t rank) const; + // Returns a device pointer table of peer pointers (void** on device). 
+ void** devicePeerPointers() const; // Setup multicast (CUDA 13.0+, init-once) void setupMulticast(int64_t exporter_rank, const std::string& tag = ""); @@ -79,6 +81,7 @@ class SymmetricTensor { int peer_fd_{-1}; bool is_contiguous_view_setup_ = false; at::Tensor contiguous_view_; + mutable void** device_peer_ptrs_ = nullptr; }; } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.cpp b/tests/cpp/test_multidevice_fused_remote_matmul.cpp new file mode 100644 index 00000000000..cec16c44c17 --- /dev/null +++ b/tests/cpp/test_multidevice_fused_remote_matmul.cpp @@ -0,0 +1,470 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +// +// ========================================================================= +// Distributed Matmul Benchmark -- Test Harness +// +// Measures allgather + matmul throughput for C = A * B where A is +// row-sharded on M across ranks and B is replicated. Compares +// baseline (NCCL / CUDA P2P) and fused kernel implementations. +// +// See test_multidevice_fused_remote_matmul.h for the performance +// summary and implementation descriptions. 
+// ========================================================================= + +#include "test_multidevice_fused_remote_matmul.h" + +#include + +#include +#include + +#include +#include +#include + +#include "fusion.h" +#include "host_ir/container.h" +#include "ir/builder.h" +#include "multidevice/communication.h" +#include "multidevice/communicator.h" +#include "multidevice/cuda_p2p.h" +#include "multidevice/ipc_handle.h" +#include "multidevice/symmetric_tensor.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { + +namespace { + +// ========================================================================= +// Timing helpers +// ========================================================================= + +// Batched GPU timing: one cuda-event pair around all iterations. +// No per-iteration host sync, matching the original kernel timing +// methodology. Avoids ~15us cudaStreamSynchronize overhead that +// dominates sub-100us kernels like CUTLASS TMA matmuls. +template <typename Fn> +double batchedKernelTimeMs( + int64_t warmup_iters, + int64_t iters, + cudaStream_t stream, + Fn&& run_once) { + for (int64_t i = 0; i < warmup_iters; ++i) + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + cudaEvent_t start, stop; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + for (int64_t i = 0; i < iters; ++i) + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + float total_ms; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&total_ms, start, stop)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast<double>(total_ms) / iters; +} + +// Per-iteration timing for baselines with host-blocking waits. 
+template <typename Fn> +double benchmarkLoopMs( + const BenchmarkConfig& config, + Communicator* communicator, + cudaStream_t stream, + Fn&& run_once) { + NVF_CHECK(config.iters > 0, "iters must be > 0"); + for (int64_t i = 0; i < config.warmup_iters; ++i) { + if (config.barrier_at_each_iteration) + communicator->barrier(); + run_once(); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + + if (config.time_mode == TimeMeasurementMode::CudaEvents) { + cudaEvent_t start, stop; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&stop)); + float total_ms = 0.f; + for (int64_t i = 0; i < config.iters; ++i) { + if (config.barrier_at_each_iteration) + communicator->barrier(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start, stream)); + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(stop, stream)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(stop)); + float ms; + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventElapsedTime(&ms, start, stop)); + total_ms += ms; + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(stop)); + return static_cast<double>(total_ms) / config.iters; + } + + double total_ms = 0.0; + for (int64_t i = 0; i < config.iters; ++i) { + if (config.barrier_at_each_iteration) + communicator->barrier(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + auto t0 = std::chrono::high_resolution_clock::now(); + run_once(); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError()); + NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream)); + auto t1 = std::chrono::high_resolution_clock::now(); + total_ms += std::chrono::duration<double, std::milli>(t1 - t0).count(); + } + return total_ms / config.iters; +} + +// ========================================================================= +// Resource helpers +// ========================================================================= + +bool 
needsThreadloadRes(DistributedMatmulImpl impl) { + using I = DistributedMatmulImpl; + return impl == I::threadloadGatherScalarCompute || + impl == I::threadloadGatherThenCutlass; +} + +bool needsMultimemRes(DistributedMatmulImpl impl) { + using I = DistributedMatmulImpl; + return impl == I::multimemGatherScalarCompute || + impl == I::multimemGatherThenCutlass; +} + +bool needsCutlass(DistributedMatmulImpl impl) { + using I = DistributedMatmulImpl; + return impl == I::threadloadGatherThenCutlass || + impl == I::multimemGatherThenCutlass; +} + +struct OwnedResources { + std::unique_ptr<SymmetricTensor> a_sym; + std::unique_ptr<SymmetricTensor> ready_sym; + std::unique_ptr<SymmetricTensor> done_sym; + std::unique_ptr<SymmetricTensor> stage_sym; + std::unique_ptr<SymmetricTensor> multimem_sym; + std::unique_ptr<hir::HostIrContainer> cuda_hic; + std::unique_ptr<SymMemForAllgather> cuda_ag_handle; + at::Tensor ready_t, done_t, stage_t; +}; + +void initResources( + DistributedMatmulImpl impl, + Communicator* comm, + const Team& team, + int64_t ws, + int64_t m, + int64_t k, + OwnedResources& res, + DistributedMatmulContext& ctx) { + using I = DistributedMatmulImpl; + auto dev = comm->device(); + + if (impl == I::baselineNcclAllgatherMatmul) { + if (comm->isBackendAvailable(CommunicatorBackend::kNccl)) + ctx.nccl_backend = + comm->getBackendForTeam(team, CommunicatorBackend::kNccl); + } + + if (impl == I::baselineCudaAllgatherMatmul) { + res.cuda_hic = std::make_unique<hir::HostIrContainer>(); + FusionGuard fg(res.cuda_hic.get()); + auto* itv = makeContigTensor(2); + auto* otv = makeContigTensor(2); + auto mesh = DeviceMesh::createForNumDevices(ws); + itv->setDeviceMesh(mesh); + otv->setDeviceMesh(mesh); + auto* cir = IrBuilder::create( + CommunicationType::Allgather, + otv, + itv, + team, + -1, + RedOpType::UNUSED, + CommunicatorBackend::kCuda); + ctx.a_allgathered_cuda = + SymmetricTensor::allocate({m, k}, at::ScalarType::Half, dev); + res.cuda_ag_handle = + std::make_unique<SymMemForAllgather>(cir, ctx.a_allgathered_cuda); + ctx.cuda_comm = cir; + ctx.cuda_handle = res.cuda_ag_handle.get(); + } + + if (needsThreadloadRes(impl)) { + 
ctx.a_gathered = + at::empty({m, k}, at::TensorOptions().dtype(at::kHalf).device(dev)); + + auto make_sem = [&](const char* tag) + -> std::pair<at::Tensor, std::unique_ptr<SymmetricTensor>> { + at::Tensor t = + SymmetricTensor::allocate({ws, m, 4}, at::ScalarType::Int, dev); + t.zero_(); + auto s = std::make_unique<SymmetricTensor>(t); + s->setupRemoteHandles(tag); + return {t, std::move(s)}; + }; + + auto [rt, rs] = make_sem("fused_matmul_ready"); + res.ready_t = rt; + res.ready_sym = std::move(rs); + ctx.ready_sem_remote = + reinterpret_cast<int32_t* const*>(res.ready_sym->devicePeerPointers()); + ctx.ready_sem_local = + reinterpret_cast<int32_t*>(res.ready_sym->localTensor().data_ptr()); + + auto [dt, ds] = make_sem("fused_matmul_done"); + res.done_t = dt; + res.done_sym = std::move(ds); + ctx.done_sem_remote = + reinterpret_cast<int32_t* const*>(res.done_sym->devicePeerPointers()); + ctx.done_sem_local = + reinterpret_cast<int32_t*>(res.done_sym->localTensor().data_ptr()); + } + + if (needsMultimemRes(impl)) { + ctx.a_gathered_multimem = + SymmetricTensor::allocate({m, k}, at::ScalarType::Half, dev); + res.multimem_sym = + std::make_unique<SymmetricTensor>(ctx.a_gathered_multimem); + res.multimem_sym->setupMulticast(0, "fused_matmul_mc"); + ctx.multicast_ptr = + reinterpret_cast<__half*>(res.multimem_sym->multicastPtr()); + + res.stage_t = + SymmetricTensor::allocate({ws, m, 4}, at::ScalarType::Int, dev); + res.stage_t.zero_(); + res.stage_sym = std::make_unique<SymmetricTensor>(res.stage_t); + res.stage_sym->setupRemoteHandles("fused_matmul_stage"); + ctx.stage_sem_remote = + reinterpret_cast<int32_t* const*>(res.stage_sym->devicePeerPointers()); + ctx.stage_sem_local = + reinterpret_cast<int32_t*>(res.stage_sym->localTensor().data_ptr()); + } +} + +double reduceMaxTimeMs(Communicator* comm, double local_ms) { + at::Tensor t = at::tensor( + {static_cast<float>(local_ms)}, + at::TensorOptions().dtype(at::kFloat).device(comm->device())); + std::vector<at::Tensor> tv = {t}; + comm->getWorld()->allreduce(tv, {c10d::ReduceOp::MAX})->wait(); + return static_cast<double>(t.item<float>()); +} + +// ========================================================================= +// 
Implementation dispatcher +// +// Each case builds a run_once lambda and wraps it with +// benchmarkLoopMs. Kernel launchers live in the .cu file. +// ========================================================================= + +double runImplementation( + DistributedMatmulImpl impl, + DistributedMatmulContext& ctx, + const BenchmarkConfig& config) { + using I = DistributedMatmulImpl; + const int64_t wu = config.warmup_iters; + const int64_t it = config.iters; + switch (impl) { + case I::baselineNcclAllgatherMatmul: { + at::Tensor a_full = at::empty({ctx.m, ctx.k}, ctx.a_local_half.options()); + auto run = [&]() { + ctx.nccl_backend->_allgather_base(a_full, ctx.a_local_half)->wait(); + at::matmul_out(ctx.c_out_half, a_full, ctx.b_full_half); + }; + return benchmarkLoopMs(config, ctx.communicator, ctx.stream, run); + } + case I::baselineCudaAllgatherMatmul: { + auto run = [&]() { + postWithCudaBackend( + ctx.cuda_comm, + ctx.a_local_half, + ctx.cuda_handle, + (CUstream)ctx.stream, + -1); + waitWithCudaBackend( + ctx.cuda_comm, ctx.cuda_handle, (CUstream)ctx.stream, -1); + at::matmul_out(ctx.c_out_half, ctx.a_allgathered_cuda, ctx.b_full_half); + }; + return benchmarkLoopMs(config, ctx.communicator, ctx.stream, run); + } + case I::naiveRemoteRead: { + return batchedKernelTimeMs( + wu, it, ctx.stream, [&]() { launchNaiveRemoteRead(ctx); }); + } + case I::threadloadGatherScalarCompute: { + int64_t epoch = 0; + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchThreadloadGather(ctx, static_cast<int32_t>(epoch), true); + ++epoch; + }); + } + case I::threadloadGatherThenCutlass: { + int64_t epoch = 0; + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchThreadloadGather(ctx, static_cast<int32_t>(epoch), false); + matmulTma(ctx.c_out_half, ctx.a_gathered, ctx.b_full_half); + ++epoch; + }); + } + case I::multimemGatherScalarCompute: { + int64_t epoch = 0; + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchMultimemGather(ctx, static_cast<int32_t>(epoch), true); + 
++epoch; + }); + } + case I::multimemGatherThenCutlass: { + int64_t epoch = 0; + return batchedKernelTimeMs(wu, it, ctx.stream, [&]() { + launchMultimemGather(ctx, static_cast<int32_t>(epoch), false); + matmulTma(ctx.c_out_half, ctx.a_gathered_multimem, ctx.b_full_half); + ++epoch; + }); + } + } + NVF_ERROR(false, "Unknown implementation."); +} + +} // anonymous namespace + +// ========================================================================= +// Test fixture +// ========================================================================= + +class FusedRemoteMatmulTest + : public MultiDeviceTest, + public testing::WithParamInterface<DistributedMatmulImpl> { + protected: + static constexpr BenchmarkConfig kConfig = { + /*warmup_iters=*/8, + /*iters=*/30, + /*time_mode=*/TimeMeasurementMode::CpuClock, + /*barrier_at_each_iteration=*/false}; +}; + +TEST_P(FusedRemoteMatmulTest, DistributedMatmul) { + if (!communicator_->is_available()) + GTEST_SKIP() << "Communicator unavailable."; + if (communicator_->size() == 1) + GTEST_SKIP() << "Needs >= 2 devices."; + + const int64_t ws = communicator_->size(); + const int64_t rank = communicator_->deviceId(); + const auto impl = GetParam(); + + if (needsMultimemRes(impl) && !isMulticastSupported(rank)) + GTEST_SKIP() << "Multicast unsupported."; + + // ---- Problem shape ---- + constexpr int64_t m = 1024, k = 1024, n = 1024; + NVF_ERROR(m % ws == 0); + const int64_t mpr = m / ws; + Team team(ws); + std::iota(team.begin(), team.end(), 0); + + // ---- Inputs ---- + at::manual_seed(0); + auto cpu_f = at::TensorOptions().dtype(at::kFloat); + auto gpu_h = + at::TensorOptions().dtype(at::kHalf).device(communicator_->device()); + at::Tensor a_full = at::randn({m, k}, cpu_f); + at::Tensor b_full = at::randn({k, n}, cpu_f); + at::Tensor a_local = a_full.slice(0, rank * mpr, (rank + 1) * mpr) + .to(gpu_h.device(), at::kHalf); + at::Tensor b_gpu = b_full.to(gpu_h.device(), at::kHalf); + + at::Tensor a_sym = SymmetricTensor::allocate( + {mpr, k}, 
at::ScalarType::Half, communicator_->device()); + a_sym.copy_(a_local); + OwnedResources res; + res.a_sym = std::make_unique<SymmetricTensor>(a_sym); + res.a_sym->setupRemoteHandles("fused_matmul_a"); + + // ---- Build context ---- + DistributedMatmulContext ctx; + ctx.m = m; + ctx.n = n; + ctx.k = k; + ctx.m_per_rank = mpr; + ctx.my_rank = rank; + ctx.world_size = ws; + ctx.device_remote_ptrs = + reinterpret_cast<const __half* const*>(res.a_sym->devicePeerPointers()); + ctx.a_local_half = a_local; + ctx.b_full_half = b_gpu; + ctx.c_out_half = at::zeros({m, n}, gpu_h); + ctx.communicator = communicator_; + + c10::cuda::CUDAStream test_stream = c10::cuda::getStreamFromPool( + false, static_cast<c10::DeviceIndex>(communicator_->device().index())); + c10::cuda::CUDAStreamGuard guard(test_stream); + ctx.stream = test_stream.stream(); + + initResources(impl, communicator_, team, ws, m, k, res, ctx); + + // ---- Capability gates ---- + if (impl == DistributedMatmulImpl::baselineNcclAllgatherMatmul && + ctx.nccl_backend == nullptr) + GTEST_SKIP() << "NCCL backend unavailable."; + + if (needsCutlass(impl)) { + at::Tensor ref = + ctx.a_gathered.defined() ? 
ctx.a_gathered : ctx.a_gathered_multimem; + if (!canRunCutlassCompute(ref, b_gpu)) + GTEST_SKIP() << "CUTLASS needs Hopper SM90."; + } + + // ---- Correctness (1 iteration, no warmup) ---- + (void)runImplementation( + impl, ctx, {0, 1, TimeMeasurementMode::CpuClock, false}); + at::Tensor c_ref = at::matmul(a_full, b_full); + EXPECT_TRUE(ctx.c_out_half.cpu().to(at::kFloat).allclose(c_ref, 2e-1, 2e-1)) + << "Mismatch for " << implName(impl); + + // ---- Benchmark ---- + communicator_->barrier(); + double local_ms = runImplementation(impl, ctx, kConfig); + communicator_->barrier(); + double global_ms = reduceMaxTimeMs(communicator_, local_ms); + + // ---- Report ---- + double tflops = 2.0 * m * n * k / (global_ms * 1e9); + if (rank == 0) { + std::cout << "[perf] fused_remote_matmul" + << " impl=" << implName(impl) << " M=" << m << " N=" << n + << " K=" << k << " world_size=" << ws << " : " << global_ms + << " ms/iter, " << tflops << " TFLOP/s" << std::endl; + } +} + +INSTANTIATE_TEST_SUITE_P( + , + FusedRemoteMatmulTest, + testing::Values( + DistributedMatmulImpl::baselineNcclAllgatherMatmul, + DistributedMatmulImpl::baselineCudaAllgatherMatmul, + DistributedMatmulImpl::naiveRemoteRead, + DistributedMatmulImpl::threadloadGatherScalarCompute, + DistributedMatmulImpl::multimemGatherScalarCompute, + DistributedMatmulImpl::threadloadGatherThenCutlass, + DistributedMatmulImpl::multimemGatherThenCutlass), + [](const testing::TestParamInfo<DistributedMatmulImpl>& info) { + return implName(info.param); + }); + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_fused_remote_matmul.h b/tests/cpp/test_multidevice_fused_remote_matmul.h new file mode 100644 index 00000000000..81d58661dcd --- /dev/null +++ b/tests/cpp/test_multidevice_fused_remote_matmul.h @@ -0,0 +1,125 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +#include + +namespace c10d { +class Backend; +} + +namespace nvfuser { + +class Communicator; +class Communication; +class SymMemForAllgather; + +// ========================================================================= +// Distributed Matmul Benchmark -- shared types +// +// Computes C[M,N] = A[M,K] x B[K,N] where A is row-sharded across ranks +// on axis M and B is replicated. Each implementation varies the +// communication strategy (how A shards are gathered) and the compute +// strategy (how the matmul is performed). +// +// Performance on 8xH100 DGX (M=N=K=1024, half precision): +// +// Implementation | TFLOP/s +// ---------------------------------------+-------- +// baselineNcclAllgatherMatmul | 39.9 +// baselineCudaAllgatherMatmul | 24.1 +// naiveRemoteRead | 2.67 +// threadloadGatherScalarCompute | 4.05 +// multimemGatherScalarCompute | 3.69 +// threadloadGatherThenCutlass | 50.3 +// multimemGatherThenCutlass | 50.1 +// ========================================================================= + +enum class DistributedMatmulImpl { + // -- Baselines: separate allgather then PyTorch eager matmul -- + baselineNcclAllgatherMatmul, + baselineCudaAllgatherMatmul, + // -- Fused kernels: comm + scalar matmul in one kernel -- + naiveRemoteRead, + threadloadGatherScalarCompute, + multimemGatherScalarCompute, + // -- Two-kernel: separate comm kernel then CUTLASS TMA matmul + // (NOT truly fused -- two kernel launches on the same stream) -- + threadloadGatherThenCutlass, + multimemGatherThenCutlass, +}; + +enum class TimeMeasurementMode { CudaEvents, CpuClock }; + +struct BenchmarkConfig { + int64_t warmup_iters; + int64_t iters; + TimeMeasurementMode time_mode; + bool barrier_at_each_iteration; +}; + +// All data any implementation may need. Unused fields are +// null/undefined for a given implementation. 
+struct DistributedMatmulContext { + // Problem dimensions + int64_t m = 0, n = 0, k = 0, m_per_rank = 0; + int64_t my_rank = 0, world_size = 0; + + // Remote A shard pointers (device array of const __half*) + const __half* const* device_remote_ptrs = nullptr; + + // Input / output tensors + at::Tensor a_local_half; // [m_per_rank, k] + at::Tensor b_full_half; // [k, n] + at::Tensor c_out_half; // [m, n] + + // Staging buffers for gather-then-compute paths + at::Tensor a_gathered; // [m, k] threadload staging + at::Tensor a_gathered_multimem; // [m, k] multicast-backed + __half* multicast_ptr = nullptr; + + // Threadload semaphores (ready / done handshake) + int32_t* const* ready_sem_remote = nullptr; + int32_t* ready_sem_local = nullptr; + int32_t* const* done_sem_remote = nullptr; + int32_t* done_sem_local = nullptr; + + // Multimem semaphores (stage barrier) + int32_t* const* stage_sem_remote = nullptr; + int32_t* stage_sem_local = nullptr; + + // Baseline-only resources + c10d::Backend* nccl_backend = nullptr; + Communication* cuda_comm = nullptr; + SymMemForAllgather* cuda_handle = nullptr; + at::Tensor a_allgathered_cuda; + + // Runtime + Communicator* communicator = nullptr; + cudaStream_t stream = nullptr; +}; + +// --- Defined in .cu (kernel launchers, CUTLASS wrapper) --- +void launchNaiveRemoteRead(DistributedMatmulContext& ctx); +void launchThreadloadGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute); +void launchMultimemGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute); +void matmulTma(at::Tensor& out, const at::Tensor& a, const at::Tensor& b); +bool canRunCutlassCompute(const at::Tensor& a, const at::Tensor& b); +const char* implName(DistributedMatmulImpl impl); +bool isMulticastSupported(int64_t device_id); + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu new file mode 100644 index 
00000000000..5450a1cd74c --- /dev/null +++ b/tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu @@ -0,0 +1,604 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +// +// ========================================================================= +// Distributed Matmul Kernels +// +// This file contains CUDA kernels and host-side launcher functions for +// the distributed matmul benchmark. Each kernel combines a +// communication strategy with a compute strategy: +// +// Communication strategies: +// - Naive remote read: each thread reads A directly from the owner +// rank via remote pointers. +// - Threadload gather: cooperative thread loads stage A rows into a +// local buffer, synchronized via ready/done semaphores. +// - Multimem gather: owner rank writes A rows to a multicast buffer +// using multimem.st (Hopper SM90+), synchronized via semaphores. +// +// Compute strategies: +// - Scalar: each thread accumulates one output element in-kernel. +// - CUTLASS TMA (two-kernel path): the gather kernel is launched +// with n=0 to skip in-kernel compute; a separate host-launched +// CUTLASS SM90 GEMM then consumes the staged A buffer. These +// are NOT truly fused -- the communication and compute are two +// distinct kernel launches on the same stream. True single-kernel +// fusion with Hopper-native WGMMA would require embedding CUTE +// MMA atoms and TMA pipelines directly inside the comm kernel. +// +// The matmulTma() wrapper at the bottom of this file provides the +// CUTLASS GEMM used by the two-kernel path. 
+// ========================================================================= + +#include "test_multidevice_fused_remote_matmul.h" + +#include + +#include +#include +#include + +#include + +// CUTLASS TMA matmul (Hopper SM90) +#if defined(NVFUSER_ENABLE_CUTLASS) +#if !defined(__CUDACC_VER_MAJOR__) +#define __CUDACC_VER_MAJOR__ 13 +#define __CUDACC_VER_MINOR__ 0 +#endif +#include "cutlass/arch/config.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +#endif + +namespace nvfuser { + +namespace { + +// ========================================================================= +// Section 1: CUTLASS TMA matmul wrapper +// +// Provides matmulTma() -- a Hopper SM90 GEMM using CUTLASS 3.x with +// TMA loads. Moved from csrc/runtime/matmul_tma.cu so the benchmark +// is self-contained. 
+// ========================================================================= + +bool hasValidTmaShape(const at::Tensor& a, const at::Tensor& b) { + if (!a.defined() || !b.defined()) + return false; + if (!a.is_cuda() || !b.is_cuda()) + return false; + if (a.dim() != 2 || b.dim() != 2) + return false; + if (a.scalar_type() != b.scalar_type()) + return false; + if (!(a.scalar_type() == at::ScalarType::Half || + a.scalar_type() == at::ScalarType::BFloat16)) + return false; + if (!a.is_contiguous() || !b.is_contiguous()) + return false; + if (a.size(1) != b.size(0)) + return false; + if (a.get_device() != b.get_device()) + return false; + constexpr int64_t kAlign = 8; + if (a.size(1) % kAlign != 0 || b.size(1) % kAlign != 0) + return false; + return true; +} + +#if defined(NVFUSER_ENABLE_CUTLASS) && defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +using namespace cute; + +template <typename ElementT> +struct TmaSm90Config { + using EA = ElementT; + using EB = ElementT; + using EC = ElementT; + using ED = ElementT; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int kAA = 128 / cutlass::sizeof_bits<EA>::value; + static constexpr int kAB = 128 / cutlass::sizeof_bits<EB>::value; + static constexpr int kAC = 128 / cutlass::sizeof_bits<EC>::value; + static constexpr int kAD = 128 / cutlass::sizeof_bits<ED>::value; + using Acc = float; + using Arch = cutlass::arch::Sm90; + using Op = cutlass::arch::OpClassTensorOp; + using Tile = Shape<_128, _128, _64>; + using Cluster = Shape<_1, _1, _1>; + using SmTile = Shape<_128, _128, _64>; + + using Epi = typename cutlass::epilogue::collective::CollectiveBuilder< + Arch, + Op, + SmTile, + Cluster, + cutlass::epilogue::collective::EpilogueTileAuto, + Acc, + Acc, + EC, + LayoutC, + kAC, + ED, + LayoutD, + kAD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using Main = typename 
cutlass::gemm::collective::CollectiveBuilder< + Arch, + Op, + EA, + LayoutA, + kAA, + EB, + LayoutB, + kAB, + Acc, + Tile, + Cluster, + cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( + sizeof(typename Epi::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using Kernel = cutlass::gemm::kernel:: + GemmUniversal<Shape<int, int, int, int>, Main, Epi, void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>; + using SA = typename Gemm::GemmKernel::StrideA; + using SB = typename Gemm::GemmKernel::StrideB; + using SC = typename Gemm::GemmKernel::StrideC; + using SD = typename Gemm::GemmKernel::StrideD; +}; + +template <typename ElementT> +void runGemmSm90( + at::Tensor& out, + const at::Tensor& a, + const at::Tensor& b, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + using C = TmaSm90Config<ElementT>; + auto sa = + cutlass::make_cute_packed_stride(typename C::SA{}, {(int)m, (int)k, 1}); + auto sb = + cutlass::make_cute_packed_stride(typename C::SB{}, {(int)k, (int)n, 1}); + auto sc = + cutlass::make_cute_packed_stride(typename C::SC{}, {(int)m, (int)n, 1}); + auto sd = + cutlass::make_cute_packed_stride(typename C::SD{}, {(int)m, (int)n, 1}); + typename C::Kernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + {(int)m, (int)n, (int)k, 1}, + {static_cast<const ElementT*>(a.data_ptr()), + sa, + static_cast<const ElementT*>(b.data_ptr()), + sb}, + {{}, nullptr, sc, static_cast<ElementT*>(out.data_ptr()), sd}}; + typename C::Gemm gemm; + size_t ws = C::Gemm::get_workspace_size(args); + auto wt = at::empty( + {(int64_t)ws}, at::TensorOptions().dtype(at::kByte).device(a.device())); + NVF_CHECK( + gemm.can_implement(args) == cutlass::Status::kSuccess, + "CUTLASS cannot implement this GEMM."); + NVF_CHECK( + gemm.initialize(args, wt.data_ptr(), stream) == cutlass::Status::kSuccess, + "CUTLASS init failed."); + NVF_CHECK( + gemm.run(args, wt.data_ptr(), stream, nullptr, true) == + cutlass::Status::kSuccess, + "CUTLASS run failed."); +} + +#else + +template <typename ElementT> +void runGemmSm90( + at::Tensor&, + const 
at::Tensor&, + const at::Tensor&, + int64_t, + int64_t, + int64_t, + cudaStream_t) { + NVF_THROW("CUTLASS SM90 support required for TMA matmul."); +} + +#endif + +// ========================================================================= +// Section 2: Remote semaphore helpers +// +// Device-side helpers for inter-rank synchronization. Each semaphore +// row stores kVecW int32 epochs. Owner publishes; readers poll. +// ========================================================================= + +constexpr int64_t kVecW = 4; +constexpr int64_t kMaxPoll = 1LL << 26; + +__device__ inline void publishToAll( + int32_t* const* remote, + int32_t* local, + int64_t writer, + int64_t row, + int64_t m, + int64_t ws, + int32_t epoch) { + int32_t* my = local + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) + my[i] = epoch; + __threadfence_system(); + for (int64_t p = 0; p < ws; ++p) { + int32_t* d = remote[p] + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) + d[i] = epoch; + } + __threadfence_system(); +} + +__device__ inline void publishToOne( + int32_t* target, + int64_t writer, + int64_t row, + int64_t m, + int32_t epoch) { + int32_t* d = target + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) + d[i] = epoch; + __threadfence_system(); +} + +__device__ inline void setLocal( + int32_t* local, + int64_t writer, + int64_t row, + int64_t m, + int32_t epoch) { + int32_t* d = local + (writer * m + row) * kVecW; + for (int64_t i = 0; i < kVecW; ++i) + d[i] = epoch; + __threadfence_system(); +} + +__device__ inline void waitOne( + int32_t* local, + int64_t row, + int64_t m, + int64_t writer, + int32_t epoch) { + auto* p = reinterpret_cast<unsigned int*>(local + (writer * m + row) * kVecW); + int64_t s = 0; + while (atomicAdd(p, 0U) < (unsigned)epoch) + if (++s > kMaxPoll) + asm volatile("trap;"); +} + +__device__ inline void waitAll( + int32_t* local, + int64_t row, + int64_t m, + int64_t ws, + int32_t epoch) { + for (int64_t r = 0; r < ws; 
++r) { + auto* p = reinterpret_cast(local + (r * m + row) * kVecW); + int64_t s = 0; + while (atomicAdd(p, 0U) < (unsigned)epoch) + if (++s > kMaxPoll) + asm volatile("trap;"); + } +} + +// ========================================================================= +// Section 3: Kernel definitions +// ========================================================================= + +// --- 3a. naiveRemoteReadKernel --- +// Each thread computes one C[row,col]. A is read directly from the +// owner rank's shard via remote pointers -- no staging, no gather. +__global__ void naiveRemoteReadKernel( + const __half* const* a_shards, + const __half* b, + __half* c, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank) { + int64_t row = blockIdx.y * blockDim.y + threadIdx.y; + int64_t col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= m || col >= n) + return; + int64_t owner = row / m_per_rank; + int64_t lr = row - owner * m_per_rank; + const __half* a = a_shards[owner]; + float acc = 0.f; + for (int64_t kk = 0; kk < k; ++kk) + acc += __half2float(a[lr * k + kk]) * __half2float(b[kk * n + col]); + c[row * n + col] = __float2half(acc); +} + +// --- 3b. threadloadGatherKernel --- +// Two-stage fused kernel with synchronized P2P gather: +// Stage 1: cooperative thread loads copy one A row from the owner +// rank's remote shard into a local staging buffer. +// Stage 2: scalar matmul from staged A. Skipped when n==0 +// (CUTLASS variants launch host-side CUTLASS instead). +// Owner signals readiness; non-owners wait. After compute, readers +// ack completion; owner waits for all readers. 
+__global__ void threadloadGatherKernel( + const __half* const* a_shards, + __half* a_gathered, + int32_t* const* ready_r, + int32_t* ready_l, + int32_t* const* done_r, + int32_t* done_l, + int64_t rank, + int64_t ws, + int32_t epoch_base, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + const __half* b, + __half* c) { + const int32_t epoch = epoch_base + 1; + for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { + int64_t owner = row / m_per_rank; + int64_t lr = row - owner * m_per_rank; + const __half* a = a_shards[owner]; + + // --- Semaphore: owner signals readiness --- + if (threadIdx.x == 0 && rank == owner) + publishToAll(ready_r, ready_l, rank, row, m, ws, epoch); + __syncthreads(); + if (threadIdx.x == 0 && rank != owner) + waitOne(ready_l, row, m, owner, epoch); + __syncthreads(); + + // --- Stage 1: P2P gather via thread loads --- + for (int64_t kk = threadIdx.x; kk < k; kk += blockDim.x) + a_gathered[row * k + kk] = a[lr * k + kk]; + __syncthreads(); + + // --- Stage 2: scalar matmul (skip when n==0) --- + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { + float acc = 0.f; + for (int64_t kk = 0; kk < k; ++kk) + acc += __half2float(a_gathered[row * k + kk]) * + __half2float(b[kk * n + col]); + c[row * n + col] = __float2half(acc); + } + __syncthreads(); + + // --- Semaphore: readers ack, owner waits --- + if (threadIdx.x == 0) { + if (rank == owner) + setLocal(done_l, rank, row, m, epoch); + else + publishToOne(done_r[owner], rank, row, m, epoch); + } + __syncthreads(); + if (threadIdx.x == 0 && rank == owner) + waitAll(done_l, row, m, ws, epoch); + __syncthreads(); + } +} + +// --- 3c. multimemGatherKernel --- +// Two-stage fused kernel using Hopper multimem stores: +// Stage 1: the owner rank writes each A row to a multicast buffer +// via multimem.st.global.v4.f32 (hardware broadcast to all peers). +// Stage 2: scalar matmul from multicast buffer. Skipped when n==0. 
+// Requires SM90+ and multicast-capable symmetric memory. +__global__ void multimemGatherKernel( + const __half* const* a_shards, + __half* a_mc, + int32_t* const* sem_r, + int32_t* sem_l, + int64_t rank, + int64_t ws, + int32_t epoch_base, + int64_t m, + int64_t n, + int64_t k, + int64_t m_per_rank, + const __half* b, + __half* c) { + for (int64_t row = blockIdx.x; row < m; row += gridDim.x) { + int64_t owner = row / m_per_rank; + int64_t lr = row - owner * m_per_rank; + const __half* a = a_shards[owner]; + __half* arow = a_mc + row * k; + + // --- Stage 1: multimem store (owner only) --- + constexpr int64_t kVec = 8; + int64_t nvec = k / kVec; + if (rank == owner) { + for (int64_t vi = threadIdx.x; vi < nvec; vi += blockDim.x) { + uint4 val = reinterpret_cast(a + lr * k)[vi]; +#if __CUDA_ARCH__ >= 900 + asm volatile( + "multimem.st.global.v4.f32 [%0]," + " {%1, %2, %3, %4};" + : + : "l"((void*)(arow + vi * kVec)), + "f"(__int_as_float((int)val.x)), + "f"(__int_as_float((int)val.y)), + "f"(__int_as_float((int)val.z)), + "f"(__int_as_float((int)val.w)) + : "memory"); +#else + (void)val; + asm volatile("trap;"); +#endif + } + for (int64_t kk = nvec * kVec + threadIdx.x; kk < k; kk += blockDim.x) + arow[kk] = a[lr * k + kk]; + } + __syncthreads(); + + // --- Semaphore barrier --- +#if __CUDA_ARCH__ >= 900 + const int32_t epoch = epoch_base + 1; + if (threadIdx.x == 0 && rank == owner) + publishToAll(sem_r, sem_l, rank, row, m, ws, epoch); + __syncthreads(); + if (threadIdx.x == 0 && rank != owner) + waitOne(sem_l, row, m, owner, epoch); + __syncthreads(); +#else + (void)sem_r; + (void)sem_l; + (void)rank; + (void)ws; + (void)epoch_base; + asm volatile("trap;"); +#endif + + // --- Stage 2: scalar matmul (skip when n==0) --- + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) { + float acc = 0.f; + for (int64_t kk = 0; kk < k; ++kk) + acc += __half2float(arow[kk]) * __half2float(b[kk * n + col]); + c[row * n + col] = __float2half(acc); + } + __syncthreads(); 
+ } +} + +} // anonymous namespace + +// ========================================================================= +// Section 4: Public launcher functions +// +// Thin wrappers that set up grid/block dims and launch kernels once. +// Timing and iteration loops live in the .cpp file. +// ========================================================================= + +void launchNaiveRemoteRead(DistributedMatmulContext& ctx) { + constexpr int64_t kB = 16; + dim3 block(kB, kB); + dim3 grid((ctx.n + kB - 1) / kB, (ctx.m + kB - 1) / kB); + naiveRemoteReadKernel<<>>( + ctx.device_remote_ptrs, + reinterpret_cast(ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>(ctx.c_out_half.data_ptr()), + ctx.m, + ctx.n, + ctx.k, + ctx.m_per_rank); +} + +void launchThreadloadGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute) { + dim3 block(256); + dim3 grid(ctx.m); + threadloadGatherKernel<<>>( + ctx.device_remote_ptrs, + reinterpret_cast<__half*>(ctx.a_gathered.data_ptr()), + ctx.ready_sem_remote, + ctx.ready_sem_local, + ctx.done_sem_remote, + ctx.done_sem_local, + ctx.my_rank, + ctx.world_size, + epoch, + ctx.m, + compute ? ctx.n : 0, + ctx.k, + ctx.m_per_rank, + reinterpret_cast(ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>(ctx.c_out_half.data_ptr())); +} + +void launchMultimemGather( + DistributedMatmulContext& ctx, + int32_t epoch, + bool compute) { + dim3 block(256); + dim3 grid(ctx.m); + multimemGatherKernel<<>>( + ctx.device_remote_ptrs, + ctx.multicast_ptr, + ctx.stage_sem_remote, + ctx.stage_sem_local, + ctx.my_rank, + ctx.world_size, + epoch, + ctx.m, + compute ? 
ctx.n : 0, + ctx.k, + ctx.m_per_rank, + reinterpret_cast(ctx.b_full_half.data_ptr()), + reinterpret_cast<__half*>(ctx.c_out_half.data_ptr())); +} + +void matmulTma(at::Tensor& out, const at::Tensor& a, const at::Tensor& b) { + int64_t m = a.size(0), n = b.size(1), k = a.size(1); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); +#if defined(NVFUSER_ENABLE_CUTLASS) + if (a.scalar_type() == at::ScalarType::Half) + runGemmSm90(out, a, b, m, n, k, stream); + else + runGemmSm90(out, a, b, m, n, k, stream); +#else + NVF_THROW("CUTLASS support required."); +#endif +} + +bool canRunCutlassCompute(const at::Tensor& a, const at::Tensor& b) { + if (!hasValidTmaShape(a, b)) + return false; +#if !defined(NVFUSER_ENABLE_CUTLASS) || \ + !defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + return false; +#else + auto* props = at::cuda::getDeviceProperties(a.get_device()); + return props->major == 9 && props->minor == 0; +#endif +} + +const char* implName(DistributedMatmulImpl impl) { + switch (impl) { + case DistributedMatmulImpl::baselineNcclAllgatherMatmul: + return "baselineNcclAllgatherMatmul"; + case DistributedMatmulImpl::baselineCudaAllgatherMatmul: + return "baselineCudaAllgatherMatmul"; + case DistributedMatmulImpl::naiveRemoteRead: + return "naiveRemoteRead"; + case DistributedMatmulImpl::threadloadGatherScalarCompute: + return "threadloadGatherScalarCompute"; + case DistributedMatmulImpl::multimemGatherScalarCompute: + return "multimemGatherScalarCompute"; + case DistributedMatmulImpl::threadloadGatherThenCutlass: + return "threadloadGatherThenCutlass"; + case DistributedMatmulImpl::multimemGatherThenCutlass: + return "multimemGatherThenCutlass"; + } + return "unknown"; +} + +bool isMulticastSupported(int64_t device_id) { + int val = 0; + auto r = cuDeviceGetAttribute( + &val, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, + static_cast(device_id)); + return r == CUDA_SUCCESS && val != 0; +} + +} // namespace nvfuser