Add a Hopper TMA (cp.async.bulk) copy kernel (csrc/multidevice/tma_copy.cu) compiled at runtime via NVRTC, and wire it as an alternative P2P data transport alongside the existing copy-engine (cudaMemcpyAsync) path.
Add P2pTransport option (NVFUSER_ENABLE=p2p_transport(tma)) that switches sendPost/recvPost in cuda_p2p.cpp between copy-engine (default) and TMA.
The static module and kernel variables in launchTmaCopy() are not thread-safe. Multiple threads could simultaneously enter the initialization block, potentially causing race conditions during NVRTC compilation and CUDA module loading. Consider adding mutex protection or std::call_once for thread-safe initialization.
If NVRTC compilation fails (lines 398-407), the nvrtcProgram is destroyed but there's no cleanup path if cuModuleLoadData or cuModuleGetFunction fail. Consider adding proper error handling with nvrtcDestroyProgram in all error paths to prevent resource leaks.
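One way to cover every error path is to tie the program's lifetime to an RAII object, so cleanup runs whether the function returns normally or throws. The sketch below is a minimal illustration only: `ProgramHandle`, `livePrograms`, and `compileAndLoad` are hypothetical stand-ins, and the real code would call `nvrtcCreateProgram` and `nvrtcDestroyProgram` where the counter is touched.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical counter standing in for "number of live nvrtcProgram objects".
inline int& livePrograms() {
  static int n = 0;
  return n;
}

// RAII wrapper (hypothetical): acquires in the constructor, releases in the
// destructor. Real code would create and destroy the NVRTC program here.
class ProgramHandle {
 public:
  ProgramHandle() { ++livePrograms(); }
  ~ProgramHandle() {
    assert(livePrograms() > 0);
    --livePrograms();
  }
  ProgramHandle(const ProgramHandle&) = delete;
  ProgramHandle& operator=(const ProgramHandle&) = delete;
};

// Simulates the compile-then-load sequence. 'fail_at_module_load' stands in
// for a failure in a later step such as module loading.
inline void compileAndLoad(bool fail_at_module_load) {
  ProgramHandle program;  // released on every exit path below
  if (fail_at_module_load) {
    throw std::runtime_error("module load failed");
  }
  // Success path: 'program' is still released when the scope ends.
}
```

With this shape, no individual error branch needs to remember the cleanup call.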
The PR lacks performance comparison data between TMA and copy engine transports. Consider adding benchmarking results or performance metrics to validate that TMA provides expected benefits over the default copy engine, especially for different transfer sizes and patterns.
if (getP2pTransport() == P2pTransport::Tma) {
launchTmaCopy(
ipc_handles.local().ptr(),
ipc_handles.peer().ptr(),
count,
stream);
} else {
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.local().ptr(),
ipc_handles.peer().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
}
// Signals completion
WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle);
break;
}
case P2pProtocol::Put: {
WriteValue32ToLocalAndPeer(
stream, ipc_handles, IpcSemaphore::kInProgress);
break;
}
default:
NVF_ERROR("Invalid P2P protocol: ", protocol);
}
}
void recvWait(const P2pIpcHandle& ipc_handles, CUstream stream) {
P2pProtocol protocol = getP2pProtocol();
switch (protocol) {
case P2pProtocol::Put:
NVFUSER_CUDA_SAFE_CALL(cuStreamWaitValue32(
stream,
reinterpret_cast<CUdeviceptr>(ipc_handles.local().semaphore()),
(cuuint32_t)(IpcSemaphore::kIdle),
CU_STREAM_WAIT_VALUE_EQ));
break;
case P2pProtocol::Get:
break;
default:
NVF_ERROR("Invalid P2P protocol: ", protocol);
}
}
void sendPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) {
P2pProtocol protocol = getP2pProtocol();
switch (protocol) {
case P2pProtocol::Get:
// signal to self and peer that transfer is in progress
WriteValue32ToLocalAndPeer(
stream, ipc_handles, IpcSemaphore::kInProgress);
break;
case P2pProtocol::Put: {
// wait for receiver to be ready
NVFUSER_CUDA_SAFE_CALL(cuStreamWaitValue32(
stream,
reinterpret_cast<CUdeviceptr>(ipc_handles.local().semaphore()),
(cuuint32_t)(IpcSemaphore::kInProgress),
CU_STREAM_WAIT_VALUE_EQ));
// Put the data to the receiver
if (getP2pTransport() == P2pTransport::Tma) {
launchTmaCopy(
ipc_handles.peer().ptr(),
ipc_handles.local().ptr(),
count,
stream);
} else {
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.peer().ptr(),
ipc_handles.local().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
}
Adds Hopper TMA (cp.async.bulk) as an alternative P2P transport alongside the existing copy-engine path, controlled by NVFUSER_ENABLE=p2p_transport(tma).
Key changes:
Implements launchTmaCopy() that compiles tma_copy.cu at runtime via NVRTC and handles arbitrary sizes via chunking
Adds P2pTransport enum (CopyEngine, Tma) and integrates transport selection into sendPost/recvPost
Enforces Compute Capability >= 9.0 check and 16-byte alignment requirement
Simplifies test suite by removing duplicate test-only TMA implementation
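The chunking and alignment checks listed above can be sketched as a small host-side planner. This is an illustration under stated assumptions, not the PR's implementation: `planChunks`, `Chunk`, and the constants are hypothetical names, the 48 KB chunk size is taken from the flowchart in this review, and the sketch assumes the 16-byte requirement applies to both pointers and the byte count.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Assumed limits: 48 KB per chunk (from the flowchart), 16-byte alignment.
constexpr std::size_t kMaxChunkBytes = 48 * 1024;
constexpr std::size_t kAlignment = 16;
static_assert(kMaxChunkBytes % kAlignment == 0, "chunks must stay aligned");

// One kernel launch worth of work: byte offset into the buffers and length.
struct Chunk {
  std::size_t offset;
  std::size_t bytes;
};

// Validates alignment, then splits 'count' bytes into chunks of at most
// kMaxChunkBytes. Each chunk would map to one TMA kernel launch.
inline std::vector<Chunk> planChunks(
    std::uintptr_t dst, std::uintptr_t src, std::size_t count) {
  if (dst % kAlignment != 0 || src % kAlignment != 0 ||
      count % kAlignment != 0) {
    throw std::invalid_argument("pointers and size must be 16-byte aligned");
  }
  std::vector<Chunk> chunks;
  for (std::size_t offset = 0; offset < count; offset += kMaxChunkBytes) {
    chunks.push_back({offset, std::min(kMaxChunkBytes, count - offset)});
  }
  // Post-condition: the chunks cover exactly [0, count).
  assert(chunks.empty() ||
         chunks.back().offset + chunks.back().bytes == count);
  return chunks;
}
```

Because the chunk size is itself a multiple of 16, every chunk offset preserves the alignment of the base pointers.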
Issues found:
Thread-safety issue in static initialization of TMA kernel module (also affects existing kernels)
Confidence Score: 3/5
Safe for merging with awareness of pre-existing thread-safety pattern
The thread-safety issue in static initialization is a real concern, but it follows the same pattern as existing kernels in the file (launchAlltoallvKernel, launchMulticastKernel). If those haven't caused issues in practice, this likely won't either. The implementation is well-structured with proper error handling and alignment checks.
Pay attention to csrc/multidevice/cuda_p2p.cpp due to the static initialization pattern
Important Files Changed
| Filename | Overview |
| --- | --- |
| csrc/multidevice/cuda_p2p.cpp | Adds TMA copy kernel compilation and transport switching logic; potential thread-safety issue in static initialization |
| csrc/multidevice/cuda_p2p.h | Adds P2pTransport enum and launchTmaCopy declaration; clean interface additions |
| csrc/options.cpp | Adds p2p_transport option to enable map; straightforward configuration change |
| csrc/options.h | Adds P2pTransport enum value with documentation; clean enum addition |
| tests/cpp/test_multidevice_tma.cpp | Simplifies tests to use production launchTmaCopy; removes duplicate test-only implementation |
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[P2P Send/Recv Operation] --> B{getP2pTransport}
B -->|CopyEngine default| C[cudaMemcpyAsync]
B -->|Tma NVFUSER_ENABLE=p2p_transport tma| D[launchTmaCopy]
D --> E{Module initialized?}
E -->|No| F[Compile tma_copy.cu via NVRTC]
F --> G[Check SM >= 9.0]
G --> H[Cache CUmodule & CUfunction]
E -->|Yes| I[Use cached kernel]
H --> I
I --> J{Size > 48KB chunk?}
J -->|Yes| K[Split into chunks]
J -->|No| L[Single launch]
K --> M[Launch multiple TMA kernels]
L --> N[Launch TMA kernel]
M --> O[GMEM -> SMEM -> GMEM]
N --> O
C --> P[Direct device-to-device copy]
O --> Q[Complete]
P --> Q
Static initialization lacks thread-safety protection. Multiple threads calling launchTmaCopy concurrently could race on the module == nullptr check (line 365), causing duplicate compilations or accessing partially-initialized state.
Other kernels in this file (launchAlltoallvKernel, launchMulticastKernel) have the same pattern. Consider adding mutex protection or using std::call_once for thread-safe lazy initialization if concurrent calls are possible.
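A minimal sketch of the mutex-protected variant follows. It is not the PR's code: `CompiledKernel`, `compileKernelOnce`, `compileCount`, and `getTmaCopyKernel` are hypothetical names, and the "compile" step is a placeholder for NVRTC compilation plus module loading.

```cpp
#include <cassert>
#include <mutex>
#include <string>

// Hypothetical stand-in for the cached CUmodule/CUfunction pair.
struct CompiledKernel {
  std::string name;
  bool ready;
};

// Counts how many times the expensive step ran (for illustration only).
inline int& compileCount() {
  static int count = 0;
  return count;
}

// Hypothetical placeholder for NVRTC compilation and module loading.
inline CompiledKernel compileKernelOnce(const std::string& name) {
  ++compileCount();
  return CompiledKernel{name, true};
}

// Mutex-guarded lazy initialization: the check and the assignment happen
// under the same lock, so concurrent callers cannot both compile and cannot
// observe a partially initialized kernel.
inline const CompiledKernel& getTmaCopyKernel() {
  static std::mutex init_mutex;
  static CompiledKernel kernel{"", false};
  std::lock_guard<std::mutex> lock(init_mutex);
  if (!kernel.ready) {
    kernel = compileKernelOnce("tma_copy");
  }
  assert(kernel.ready);
  return kernel;
}
```

This version takes the lock on every call. `std::call_once` or a function-local static initialized from a helper function would avoid that cost after the first call while giving the same once-only guarantee.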