Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 142 additions & 12 deletions csrc/multidevice/cuda_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "multidevice/cuda_p2p.h"
#include "nvfuser_resources/alltoallv.h"
#include "nvfuser_resources/multicast.h"
#include "nvfuser_resources/tma_copy.h"

#include "cuda_utils.h"
#include "multidevice/ipc_handle.h"
Expand All @@ -34,6 +35,22 @@ P2pProtocol getP2pProtocol() {
: P2pProtocol::Get;
}

// Writes the enumerator name of `transport` ("CopyEngine" or "Tma") to `os`
// and returns `os` so that insertions can be chained.
std::ostream& operator<<(std::ostream& os, P2pTransport transport) {
  const char* label = nullptr;
  switch (transport) {
    case P2pTransport::CopyEngine:
      label = "CopyEngine";
      break;
    case P2pTransport::Tma:
      label = "Tma";
      break;
  }
  // Every enumerator assigns a label above; any other value is unreachable.
  if (label == nullptr) {
    std::unreachable();
  }
  return os << label;
}

// Picks the P2P data transport from the enable options: Tma when the
// P2pTransport option carries the "tma" argument, CopyEngine otherwise.
P2pTransport getP2pTransport() {
  const bool tma_requested =
      hasEnableOptionArgument(EnableOption::P2pTransport, "tma");
  if (tma_requested) {
    return P2pTransport::Tma;
  }
  return P2pTransport::CopyEngine;
}

namespace {
void launchAlltoallvKernel(
const void* send,
Expand Down Expand Up @@ -335,6 +352,103 @@ void launchMulticastKernel(
kernel, blocks, 1, 1, threads, 1, 1, 0, stream, args_kernel, nullptr));
}

} // anonymous namespace

// Compiles the embedded TMA copy source with NVRTC for the current device,
// loads the resulting PTX and returns the `tma_copy_1d` entry point.
//
// Throws (via the NVF_* / *_SAFE_CALL macros) when the device is older than
// Compute Capability 9.0, when compilation fails, or when the module cannot
// be loaded. The loaded module is intentionally never unloaded: the returned
// function handle is cached for the lifetime of the process by the caller.
static CUfunction compileTmaCopyKernel() {
  int device = 0;
  NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device));
  cudaDeviceProp prop;
  NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device));

  // Checked before the NVRTC program is created so that a failure here does
  // not leak the program object.
  NVF_CHECK(
      prop.major >= 9,
      "TMA transport requires Compute Capability >= 9.0 (Hopper+). "
      "Current device ",
      device,
      " is Compute Capability ",
      prop.major,
      ".",
      prop.minor);

  nvrtcProgram prog;
  NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
      &prog,
      nvfuser_resources::tma_copy_cu,
      "tma_copy.cu",
      0,
      nullptr,
      nullptr));

  const std::string arch_arg = "--gpu-architecture=compute_" +
      std::to_string(prop.major) + std::to_string(prop.minor);
  const std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"};

  const nvrtcResult res =
      nvrtcCompileProgram(prog, static_cast<int>(opts.size()), opts.data());
  if (res != NVRTC_SUCCESS) {
    size_t log_size = 0;
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &log_size));
    // One extra zeroed byte keeps the buffer a valid C string even when the
    // reported log is empty.
    std::vector<char> log(log_size + 1, '\0');
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data()));
    // Best-effort cleanup: the compilation log is the error worth reporting,
    // so a failure to destroy the program is deliberately not escalated.
    (void)nvrtcDestroyProgram(&prog);
    NVF_ERROR(false, "TMA kernel compilation failed:\n", log.data());
  }

  size_t ptx_size = 0;
  NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptx_size));
  std::vector<char> ptx(ptx_size);
  NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
  NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));

  CUmodule module = nullptr;
  NVFUSER_CUDA_SAFE_CALL(cuModuleLoadData(&module, ptx.data()));
  CUfunction kernel = nullptr;
  NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction(&kernel, module, "tma_copy_1d"));
  return kernel;
}

// Copies `size` bytes from `src` to `dst` by launching the `tma_copy_1d`
// kernel on `stream`, one single-block launch per chunk.
//
// Parameters:
//   dst    - destination buffer; must be 16-byte aligned.
//   src    - source buffer; must be 16-byte aligned.
//   size   - number of bytes to copy; must be a multiple of 16. A size of 0
//            launches nothing.
//   stream - stream on which the kernels are enqueued. No synchronization is
//            performed here.
//
// Errors are reported through NVF_CHECK / NVF_ERROR and the *_SAFE_CALL
// macros: bad size or alignment, device older than Compute Capability 9.0,
// NVRTC compilation failure, or a failing driver call.
//
// Thread safety of the lazy initialization: the kernel handle is a
// function-local static, so the language guarantees it is initialized exactly
// once even under concurrent first calls, and initialization is retried on a
// later call if it threw. This replaces a `module == nullptr` guard that was
// racy and could leave a null kernel behind a non-null module.
//
// NOTE(review): the kernel is compiled once, for the device that is current
// on the first call. Calls made later with a different current device reuse
// that handle - confirm the one-device-per-process assumption.
void launchTmaCopy(void* dst, const void* src, size_t size, CUstream stream) {
  // Validate cheap preconditions before paying for a runtime compilation.
  NVF_CHECK(size % 16 == 0, "TMA requires size to be a multiple of 16");
  // Bulk async copies need 16-byte aligned global addresses; reject bad
  // pointers on the host instead of faulting asynchronously on the device.
  NVF_CHECK(
      reinterpret_cast<uintptr_t>(dst) % 16 == 0 &&
          reinterpret_cast<uintptr_t>(src) % 16 == 0,
      "TMA requires src and dst to be 16-byte aligned");

  static const CUfunction kernel = compileTmaCopyKernel();

  // The kernel stages data through shared memory, so we chunk large
  // transfers. 48 KB is the guaranteed default dynamic smem limit
  // (using more would require cuFuncSetAttribute opt-in).
  constexpr size_t kDefaultSmem = 48 * 1024;
  // Bytes reserved next to the staging buffer for the mbarrier. The same
  // constant sizes both the chunk limit and the launch's shared memory.
  constexpr size_t kMbarrierBytes = sizeof(uint64_t);
  // Largest multiple of 16 that still leaves room for the mbarrier.
  constexpr size_t kMaxChunk = ((kDefaultSmem - kMbarrierBytes) / 16) * 16;

  auto* dst_bytes = static_cast<char*>(dst);
  const auto* src_bytes = static_cast<const char*>(src);

  for (size_t remaining = size; remaining > 0;) {
    const size_t step = std::min(remaining, kMaxChunk);
    // The kernel receives the chunk size as an int; step <= kMaxChunk fits.
    int chunk = static_cast<int>(step);
    const auto smem_size = static_cast<unsigned int>(step + kMbarrierBytes);
    void* d = dst_bytes;
    const void* s = src_bytes;
    void* args[] = {&d, &s, &chunk};
    NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel(
        kernel, 1, 1, 1, 32, 1, 1, smem_size, stream, args, nullptr));
    dst_bytes += step;
    src_bytes += step;
    remaining -= step;
  }
}

namespace {

// We choose to duplicate the state of the semaphore on both the local and peer
// devices so cuStreamWaitValue32 never polls a remote buffer and pollutes
// the network. This is a theoretical consideration that we have not proved or
Expand Down Expand Up @@ -676,12 +790,20 @@ void recvPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) {
(cuuint32_t)(IpcSemaphore::kInProgress),
CU_STREAM_WAIT_VALUE_EQ));
// Get the data from the sender
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.local().ptr(),
ipc_handles.peer().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
if (getP2pTransport() == P2pTransport::Tma) {
launchTmaCopy(
ipc_handles.local().ptr(),
ipc_handles.peer().ptr(),
count,
stream);
} else {
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.local().ptr(),
ipc_handles.peer().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
}
// Signals completion
WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle);
break;
Expand Down Expand Up @@ -729,12 +851,20 @@ void sendPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) {
(cuuint32_t)(IpcSemaphore::kInProgress),
CU_STREAM_WAIT_VALUE_EQ));
// Put the data to the receiver
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.peer().ptr(),
ipc_handles.local().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
if (getP2pTransport() == P2pTransport::Tma) {
launchTmaCopy(
ipc_handles.peer().ptr(),
ipc_handles.local().ptr(),
count,
stream);
} else {
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.peer().ptr(),
ipc_handles.local().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
}
WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle);
break;
}
Expand Down
16 changes: 14 additions & 2 deletions csrc/multidevice/cuda_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,20 @@ P2pProtocol getP2pProtocol();

std::ostream& operator<<(std::ostream& os, P2pProtocol protocol);

// Returns the prescribed P2P protocol based on NVFUSER_ENABLE option
P2pProtocol getP2pProtocol();
//! Mechanism used to move the payload in recvPost/sendPost:
//!  - CopyEngine: device-to-device cudaMemcpyAsync.
//!  - Tma: launchTmaCopy.
enum class P2pTransport { CopyEngine, Tma };

//! Returns P2pTransport::Tma when the p2p_transport enable option carries the
//! "tma" argument, and P2pTransport::CopyEngine otherwise.
P2pTransport getP2pTransport();

//! Prints the enumerator name ("CopyEngine" or "Tma").
std::ostream& operator<<(std::ostream& os, P2pTransport transport);

//! TMA 1D bulk copy: GMEM(src) -> SMEM -> GMEM(dst).
//! Compiled at runtime via NVRTC from csrc/multidevice/tma_copy.cu.
//! Handles arbitrarily large sizes by chunking to fit shared memory.
//!
//! \param dst    destination buffer
//! \param src    source buffer
//! \param size   number of bytes to copy; must be a multiple of 16
//! \param stream stream on which the copy kernels are launched
//!
//! The first call compiles and loads the kernel and requires the current
//! device to have Compute Capability >= 9.0. A size that is not a multiple of
//! 16 is rejected. Kernels are only enqueued on `stream`; no synchronization
//! is performed by this function.
void launchTmaCopy(
void* dst,
const void* src,
size_t size,
CUstream stream);

void recvPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream);

Expand Down
1 change: 1 addition & 0 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"insert_resharding_after", EnableOption::InsertReshardingAfter},
{"fast_math", EnableOption::FastMath},
{"p2p_protocol", EnableOption::P2pProtocol},
{"p2p_transport", EnableOption::P2pTransport},
{"multicast_protocol", EnableOption::MulticastProtocol},
{"parallel_serde", EnableOption::ParallelSerde},
};
Expand Down
1 change: 1 addition & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ enum class EnableOption {
InsertReshardingAfter, //! Insert resharding set after the expression
FastMath, //! Enable fast math optimizations (--use_fast_math)
P2pProtocol, //! Prescribe P2P protocol: put|get
P2pTransport, //! Prescribe P2P data transport: ce|tma (default: ce)
MulticastProtocol, //! Prescribe multicast protocol:
//! memcpy|multimem|batch_memcpy
ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
Expand Down
Loading
Loading