From 13f72a5fb5c098cad5200d7fc1c6d8b237991432 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 25 Feb 2026 07:54:13 -0800 Subject: [PATCH] Add TMA bulk copy kernel and P2P transport option --- csrc/multidevice/cuda_p2p.cpp | 154 ++++++++++++++++++++++++++--- csrc/multidevice/cuda_p2p.h | 16 ++- csrc/options.cpp | 1 + csrc/options.h | 1 + tests/cpp/test_multidevice_tma.cpp | 132 ++----------------------- 5 files changed, 168 insertions(+), 136 deletions(-) diff --git a/csrc/multidevice/cuda_p2p.cpp b/csrc/multidevice/cuda_p2p.cpp index 8804c1a7a79..870802e1901 100644 --- a/csrc/multidevice/cuda_p2p.cpp +++ b/csrc/multidevice/cuda_p2p.cpp @@ -8,6 +8,7 @@ #include "multidevice/cuda_p2p.h" #include "nvfuser_resources/alltoallv.h" #include "nvfuser_resources/multicast.h" +#include "nvfuser_resources/tma_copy.h" #include "cuda_utils.h" #include "multidevice/ipc_handle.h" @@ -34,6 +35,22 @@ P2pProtocol getP2pProtocol() { : P2pProtocol::Get; } +std::ostream& operator<<(std::ostream& os, P2pTransport transport) { + switch (transport) { + case P2pTransport::CopyEngine: + return os << "CopyEngine"; + case P2pTransport::Tma: + return os << "Tma"; + } + std::unreachable(); +} + +P2pTransport getP2pTransport() { + return hasEnableOptionArgument(EnableOption::P2pTransport, "tma") + ? P2pTransport::Tma + : P2pTransport::CopyEngine; +} + namespace { void launchAlltoallvKernel( const void* send, @@ -335,6 +352,103 @@ void launchMulticastKernel( kernel, blocks, 1, 1, threads, 1, 1, 0, stream, args_kernel, nullptr)); } +} // anonymous namespace + +void launchTmaCopy( + void* dst, + const void* src, + size_t size, + CUstream stream) { + static CUmodule module = nullptr; + static CUfunction kernel = nullptr; + + if (module == nullptr) { + nvrtcProgram prog; + NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( + &prog, + nvfuser_resources::tma_copy_cu, + "tma_copy.cu", + 0, + nullptr, + nullptr)); + + int device = 0; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); + cudaDeviceProp prop; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaGetDeviceProperties(&prop, device)); + + NVF_CHECK( + prop.major >= 9, + "TMA transport requires Compute Capability >= 9.0 (Hopper+). " + "Current device ", + device, + " is Compute Capability ", + prop.major, + ".", + prop.minor); + + std::string arch_arg = "--gpu-architecture=compute_" + + std::to_string(prop.major) + std::to_string(prop.minor); + std::vector opts = { + arch_arg.c_str(), "--std=c++17"}; + + nvrtcResult res = + nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); + if (res != NVRTC_SUCCESS) { + size_t logSize; + NVFUSER_NVRTC_SAFE_CALL( + nvrtcGetProgramLogSize(prog, &logSize)); + std::vector log(logSize); + NVFUSER_NVRTC_SAFE_CALL( + nvrtcGetProgramLog(prog, log.data())); + NVF_ERROR( + false, "TMA kernel compilation failed:\n", log.data()); + } + + size_t ptxSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); + std::vector ptx(ptxSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); + NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); + + NVFUSER_CUDA_SAFE_CALL( + cuModuleLoadData(&module, ptx.data())); + NVFUSER_CUDA_SAFE_CALL( + cuModuleGetFunction(&kernel, module, "tma_copy_1d")); + } + + NVF_CHECK(size % 16 == 0, "TMA requires size to be a multiple of 16"); + + // The kernel stages data through shared memory, so we chunk large + // transfers. 48 KB is the guaranteed default dynamic smem limit + // (using more would require cuFuncSetAttribute opt-in). + constexpr int kDefaultSmem = 48 * 1024; + constexpr int kMbarrierBytes = 8; + constexpr int max_chunk = + ((kDefaultSmem - kMbarrierBytes) / 16) * 16; + + auto* dst_bytes = static_cast(dst); + auto* src_bytes = static_cast(src); + size_t remaining = size; + + while (remaining > 0) { + int chunk = static_cast( + std::min(remaining, static_cast(max_chunk))); + int smem_size = chunk + static_cast(sizeof(uint64_t)); + void* d = dst_bytes; + const void* s = src_bytes; + void* args[] = {&d, &s, &chunk}; + NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( + kernel, 1, 1, 1, 32, 1, 1, smem_size, stream, args, nullptr)); + dst_bytes += chunk; + src_bytes += chunk; + remaining -= chunk; + } +} + +namespace { + // We choose duplicate the state of the semaphore on both the local and peer // devices to avoid cuStreamWaitValue32 to poll on a remote buffer and pollutes // the network. This is a theoretical consideration that we have not proved or @@ -676,12 +790,20 @@ void recvPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) { (cuuint32_t)(IpcSemaphore::kInProgress), CU_STREAM_WAIT_VALUE_EQ)); // Get the data from the sender - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( - ipc_handles.local().ptr(), - ipc_handles.peer().ptr(), - count, - cudaMemcpyDeviceToDevice, - stream)); + if (getP2pTransport() == P2pTransport::Tma) { + launchTmaCopy( + ipc_handles.local().ptr(), + ipc_handles.peer().ptr(), + count, + stream); + } else { + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( + ipc_handles.local().ptr(), + ipc_handles.peer().ptr(), + count, + cudaMemcpyDeviceToDevice, + stream)); + } // Signals completion WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle); break; @@ -729,12 +851,20 @@ void sendPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) { (cuuint32_t)(IpcSemaphore::kInProgress), CU_STREAM_WAIT_VALUE_EQ)); // Put the data to the receiver - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( - ipc_handles.peer().ptr(), - ipc_handles.local().ptr(), - count, - cudaMemcpyDeviceToDevice, - stream)); + if (getP2pTransport() == P2pTransport::Tma) { + launchTmaCopy( + ipc_handles.peer().ptr(), + ipc_handles.local().ptr(), + count, + stream); + } else { + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( + ipc_handles.peer().ptr(), + ipc_handles.local().ptr(), + count, + cudaMemcpyDeviceToDevice, + stream)); + } WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle); break; } diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index 514195c0746..234054ed8eb 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -20,8 +20,20 @@ P2pProtocol getP2pProtocol(); std::ostream& operator<<(std::ostream& os, P2pProtocol protocol); -// Returns the prescribed P2P protocol based on NVFUSER_ENABLE option -P2pProtocol getP2pProtocol(); +enum class P2pTransport { CopyEngine, Tma }; + +P2pTransport getP2pTransport(); + +std::ostream& operator<<(std::ostream& os, P2pTransport transport); + +//! TMA 1D bulk copy: GMEM(src) -> SMEM -> GMEM(dst). +//! Compiled at runtime via NVRTC from csrc/multidevice/tma_copy.cu. +//! Handles arbitrarily large sizes by chunking to fit shared memory. +void launchTmaCopy( + void* dst, + const void* src, + size_t size, + CUstream stream); void recvPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream); diff --git a/csrc/options.cpp b/csrc/options.cpp index 14dddd89eec..7c7c0e28bff 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -180,6 +180,7 @@ const std::unordered_map& getEnableOptions() { {"insert_resharding_after", EnableOption::InsertReshardingAfter}, {"fast_math", EnableOption::FastMath}, {"p2p_protocol", EnableOption::P2pProtocol}, + {"p2p_transport", EnableOption::P2pTransport}, {"multicast_protocol", EnableOption::MulticastProtocol}, {"parallel_serde", EnableOption::ParallelSerde}, }; diff --git a/csrc/options.h b/csrc/options.h index 6dad3909997..dc908044e31 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -124,6 +124,7 @@ enum class EnableOption { InsertReshardingAfter, //! Insert resharding set after the expression FastMath, //! Enable fast math optimizations (--use_fast_math) P2pProtocol, //! Prescribe P2P protocol: put|get + P2pTransport, //! Prescribe P2P data transport: ce|tma (default: ce) MulticastProtocol, //! Prescribe multicast protocol: //! memcpy|multimem|batch_memcpy ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel diff --git a/tests/cpp/test_multidevice_tma.cpp b/tests/cpp/test_multidevice_tma.cpp index 0a5363eae3f..5c13c648c85 100644 --- a/tests/cpp/test_multidevice_tma.cpp +++ b/tests/cpp/test_multidevice_tma.cpp @@ -6,128 +6,28 @@ */ // clang-format on // -// Unit tests for Hopper TMA (Tensor Memory Accelerator) 1D bulk copy +// Unit tests for TMA (Tensor Memory Accelerator) 1D bulk copy // (cp.async.bulk) across different memory sources: -// 1. Local device memory (cudaMalloc) +// 1. Local device memory // 2. VMM-mapped peer device memory (inter-device P2P) -// 3. NVLS multicast unicast pointers +// 3. NVLS multicast pointers // -// The kernel source lives in csrc/multidevice/tma_copy.cu and is -// stringified at build time. It is compiled at runtime via NVRTC, -// same pattern as csrc/multidevice/cuda_p2p.cpp. +// Uses the production launchTmaCopy() from cuda_p2p.cpp, which +// compiles csrc/multidevice/tma_copy.cu at runtime via NVRTC. #include -#include - -#include -#include #include "cuda_utils.h" #include "driver_api.h" -#include "exceptions.h" +#include "multidevice/cuda_p2p.h" #include "multidevice/symmetric_tensor.h" #include "multidevice/utils.h" -#include "nvfuser_resources/tma_copy.h" #include "tests/cpp/multidevice.h" namespace nvfuser { -// ============================================================================ -// NVRTC helper: compile kernel source at runtime, cache the result. -// ============================================================================ - -namespace { - -CUfunction compileAndGetKernel( - CUmodule& module, - CUfunction& function, - const char* source, - const char* source_name, - const char* kernel_name) { - if (function != nullptr) { - return function; - } - - nvrtcProgram prog; - NVFUSER_NVRTC_SAFE_CALL( - nvrtcCreateProgram(&prog, source, source_name, 0, nullptr, nullptr)); - - int device = 0; - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); - cudaDeviceProp prop; - NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); - - std::string arch_arg = "--gpu-architecture=compute_" + - std::to_string(prop.major) + std::to_string(prop.minor); - std::vector opts = {arch_arg.c_str(), "--std=c++17"}; - - nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); - if (res != NVRTC_SUCCESS) { - size_t logSize; - NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); - std::vector log(logSize); - NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); - NVF_ERROR( - false, - "NVRTC compilation of '", - source_name, - "' failed:\n", - log.data()); - } - - size_t ptxSize; - NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); - std::vector ptx(ptxSize); - NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); - NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); - - NVFUSER_CUDA_SAFE_CALL(cuModuleLoadData(&module, ptx.data())); - NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction(&function, module, kernel_name)); - - return function; -} - -//! Return the NVRTC-compiled tma_copy_1d CUfunction (cached after -//! first call). The kernel uses cp.async.bulk to perform -//! GMEM(src) -> SMEM -> GMEM(dst) -//! and requires dynamic shared memory of num_bytes + 8 (mbarrier). -CUfunction getTmaCopy1dKernel() { - static CUmodule module = nullptr; - static CUfunction kernel = nullptr; - return compileAndGetKernel( - module, - kernel, - nvfuser_resources::tma_copy_cu, - "tma_copy.cu", - "tma_copy_1d"); -} - -//! Launch the TMA 1D bulk copy kernel: GMEM(src) -> SMEM -> GMEM(dst). -//! num_bytes must be > 0 and a multiple of 16. -void launchTmaCopy1D( - void* dst, - const void* src, - int num_bytes, - CUstream stream = nullptr) { - NVF_CHECK(num_bytes > 0 && num_bytes % 16 == 0); - CUfunction tma_kernel = getTmaCopy1dKernel(); - int smem_size = num_bytes + static_cast(sizeof(uint64_t)); - void* args[] = {&dst, &src, &num_bytes}; - NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( - tma_kernel, 1, 1, 1, 32, 1, 1, smem_size, stream, args, nullptr)); -} - -} // anonymous namespace - -// ============================================================================ -// Tests -// ============================================================================ - using TmaTest = MultiDeviceTest; -// Verify TMA 1D bulk copy on local device memory. -// The kernel uses cp.async.bulk (GMEM->SMEM) + cp.async.bulk (SMEM->GMEM) -// with mbarrier synchronization between the two phases. TEST_F(TmaTest, TmaLocalCopy) { const int64_t local_rank = communicator_->local_rank(); @@ -149,15 +49,12 @@ TEST_F(TmaTest, TmaLocalCopy) { at::Tensor src = at::arange(kNumElems, options); at::Tensor dst = at::zeros({kNumElems}, options); - launchTmaCopy1D(dst.data_ptr(), src.data_ptr(), kSizeBytes); + launchTmaCopy(dst.data_ptr(), src.data_ptr(), kSizeBytes, nullptr); NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); EXPECT_TRUE(dst.equal(src)); } -// Verify TMA 1D bulk copy reading from a VMM-mapped peer device -// buffer. SymmetricTensor handles the VMM allocation and IPC handle -// exchange; the test focuses on the TMA transfer itself. TEST_F(TmaTest, TmaInterDeviceCopy) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; @@ -191,7 +88,7 @@ TEST_F(TmaTest, TmaInterDeviceCopy) { {kNumElems}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank)); - launchTmaCopy1D(output.data_ptr(), peer.data_ptr(), kSizeBytes); + launchTmaCopy(output.data_ptr(), peer.data_ptr(), kSizeBytes, nullptr); NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); at::Tensor expected = at::full( @@ -205,10 +102,6 @@ TEST_F(TmaTest, TmaInterDeviceCopy) { #if (CUDA_VERSION >= 13000) -// Verify TMA 1D bulk copy writing TO an NVLS multicast pointer. -// Root uses TMA to write data to the MC pointer, which broadcasts -// via NVLS hardware. All ranks then verify the data arrived by -// reading from their local UC view with a normal copy. TEST_F(TmaTest, TmaMulticastWrite) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; @@ -233,11 +126,8 @@ TEST_F(TmaTest, TmaMulticastWrite) { GTEST_SKIP() << "Device does not support Multicast Objects; skipping."; } - constexpr int64_t kNumElems = 524288; // 2 MB / sizeof(int32_t) + constexpr int64_t kNumElems = 524288; constexpr int64_t root = 0; - - // cp.async.bulk transfer size is limited by shared memory, - // so we broadcast a 4 KB slice via TMA. constexpr int kTmaBytes = 4096; static_assert(kTmaBytes % 16 == 0); constexpr int kTmaElems = kTmaBytes / sizeof(int32_t); @@ -250,16 +140,14 @@ TEST_F(TmaTest, TmaMulticastWrite) { auto opts = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank); - // Root: TMA-write source data to MC pointer (NVLS broadcasts it) if (rank == root) { at::Tensor src = at::arange(kTmaElems, opts); - launchTmaCopy1D(sym.multicastPtr(), src.data_ptr(), kTmaBytes); + launchTmaCopy(sym.multicastPtr(), src.data_ptr(), kTmaBytes, nullptr); NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); } communicator_->barrier(); - // All ranks: verify data arrived via normal read of local UC tensor at::Tensor readback = sym.localTensor().slice(0, 0, kTmaElems).clone(); at::Tensor expected = at::arange(kTmaElems, opts); EXPECT_TRUE(readback.equal(expected))