Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions csrc/multidevice/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,4 +424,20 @@ void Communicator::barrier(std::optional<CommunicatorBackend> backend) {
getWorld(backend)->barrier(options)->wait();
}

#ifdef NVFUSER_DISTRIBUTED
c10::intrusive_ptr<c10d::Store> Communicator::getStore() const {
  // Upcast the owned store to the generic c10d::Store interface; the
  // returned intrusive_ptr shares ownership with store_ (refcount bump,
  // not a transfer).
  c10::intrusive_ptr<c10d::Store> generic_store(store_);
  return generic_store;
}

// Returns the owning intrusive_ptr for the world backend (all ranks), so it
// can be handed to c10d APIs that require shared ownership (e.g.
// c10d::register_process_group). The backend is created lazily if it does
// not exist yet.
c10::intrusive_ptr<c10d::Backend> Communicator::getWorldBackendIntrusivePtr(
    std::optional<CommunicatorBackend> backend) {
  // The "world" team is all ranks [0, size_).
  std::vector<RankType> all_ranks(size_);
  std::iota(all_ranks.begin(), all_ranks.end(), 0);
  CommunicatorBackend b = backend.value_or(default_backend_);
  std::string team_key = getTeamKey(all_ranks, b);
  // Ensure the world backend exists in backends_; the raw pointer return is
  // deliberately discarded because we need the owning intrusive_ptr below.
  // NOTE(review): this passes the unresolved `backend` while team_key is
  // built from the resolved `b` — presumably getBackendForTeam applies the
  // same value_or(default_backend_) resolution so the keys agree; confirm.
  (void)getBackendForTeam(all_ranks, backend, "");
  return backends_.at(team_key);
}
#endif

} // namespace nvfuser
13 changes: 13 additions & 0 deletions csrc/multidevice/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#else
Expand Down Expand Up @@ -124,6 +125,18 @@ class NVF_API Communicator {
return store_.get();
}

#ifdef NVFUSER_DISTRIBUTED
// Returns the store as an intrusive_ptr for use with PyTorch symmetric
// memory (c10d::symmetric_memory::set_group_info).
c10::intrusive_ptr<c10d::Store> getStore() const;

// Returns the world backend as an intrusive_ptr so it can be registered with
// c10d::register_process_group (e.g. for PyTorch symmetric memory NCCL
// rendezvous, which resolves the group by name).
c10::intrusive_ptr<c10d::Backend> getWorldBackendIntrusivePtr(
std::optional<CommunicatorBackend> backend = std::nullopt);
#endif

private:
Communicator(
CommunicatorBackend backend = comm_backend_default,
Expand Down
18 changes: 18 additions & 0 deletions csrc/multidevice/ipc_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,22 @@ MulticastProtocol getMulticastProtocol() {
return MulticastProtocol::BatchMemcpy;
}

// Resolves the symmetric memory backend from the
// EnableOption::SymmetricMemoryBackend option. Falls back to Native when the
// option is unset or carries no recognized argument.
SymmetricMemoryBackend getSymmetricMemoryBackend() {
  if (isOptionEnabled(EnableOption::SymmetricMemoryBackend)) {
    // Recognized option arguments, checked in this order.
    struct Choice {
      const char* arg;
      SymmetricMemoryBackend backend;
    };
    static constexpr Choice kChoices[] = {
        {"pytorch_nccl", SymmetricMemoryBackend::PyTorchNccl},
        {"pytorch_nvshmem", SymmetricMemoryBackend::PyTorchNvshmem},
        {"pytorch_cuda", SymmetricMemoryBackend::PyTorchCuda},
    };
    for (const Choice& choice : kChoices) {
      if (hasEnableOptionArgument(
              EnableOption::SymmetricMemoryBackend, choice.arg)) {
        return choice.backend;
      }
    }
  }
  return SymmetricMemoryBackend::Native;
}

} // namespace nvfuser
13 changes: 13 additions & 0 deletions csrc/multidevice/ipc_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ enum class MulticastProtocol { Memcpy, Multimem, BatchMemcpy };

MulticastProtocol getMulticastProtocol();

// Backend for symmetric memory allocation and rendezvous.
// - Native: Fuser's own CUDA VMM + IPC implementation (default, maintained).
// - PyTorch*: use PyTorch's symmetric memory
//   (torch.distributed._symmetric_memory) with the given transport backend
//   (Nccl, Nvshmem, or Cuda).
enum class SymmetricMemoryBackend {
  Native,
  PyTorchNccl,
  PyTorchNvshmem,
  PyTorchCuda,
};

// Reads the symmetric_memory_backend enable-option and returns the selected
// backend; Native when the option is unset or has no recognized argument.
SymmetricMemoryBackend getSymmetricMemoryBackend();

// Creates a listening Unix domain socket bound to path.
// If path starts with '@', it uses the abstract namespace (replaced with \0).
// Returns the socket file descriptor.
Expand Down
163 changes: 162 additions & 1 deletion csrc/multidevice/symmetric_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// clang-format on
#include "multidevice/symmetric_tensor.h"

#include <mutex>
#include <numeric>

#include "cuda_utils.h"
Expand All @@ -15,10 +16,63 @@
#include "multidevice/ipc_utils.h"
#include "multidevice/utils.h"

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
#endif

namespace nvfuser {

namespace {

#ifdef NVFUSER_DISTRIBUTED
const char* kPyTorchSymmMemGroupName = "nvfuser_symm";

// Cache: tensor storage data ptr -> PyTorch SymmetricMemory handle from
// rendezvous. Used so SymmetricTensor(tensor) can recover the handle.
std::unordered_map<void*, c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory>>&
getPySymmHandleCache() {
  using HandleMap = std::unordered_map<
      void*,
      c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory>>;
  // Meyers singleton: constructed on first use, shared by all callers.
  static HandleMap handle_cache;
  return handle_cache;
}

// Guards the handle cache against concurrent access.
std::mutex& getPySymmHandleCacheMutex() {
  static std::mutex cache_mutex;
  return cache_mutex;
}

// Configures PyTorch symmetric memory exactly once for this process:
// selects the transport backend and registers this rank's rank/size/store
// under kPyTorchSymmMemGroupName so later rendezvous calls can resolve the
// group. Thread-safe via std::call_once.
//
// Fix: the original silently ignored a later call that requested a
// *different* backend (call_once runs the lambda only once), leaving the
// process configured with a backend the caller did not ask for. We now
// remember the backend that won initialization and fail loudly on mismatch.
void ensurePyTorchSymmMemBackend(SymmetricMemoryBackend backend) {
  static std::once_flag once;
  // Written only inside the call_once lambda; reads below are ordered after
  // the initialization by call_once's synchronization guarantee.
  static SymmetricMemoryBackend initialized_backend;
  std::call_once(once, [backend]() {
    const char* name = nullptr;
    switch (backend) {
      case SymmetricMemoryBackend::PyTorchNccl:
        name = "NCCL";
        break;
      case SymmetricMemoryBackend::PyTorchNvshmem:
        name = "NVSHMEM";
        break;
      case SymmetricMemoryBackend::PyTorchCuda:
        name = "CUDA";
        break;
      default:
        NVF_ERROR(false, "Unexpected PyTorch symmetric memory backend");
    }
    c10d::symmetric_memory::set_backend(name);
    Communicator& comm = Communicator::getInstance();
    NVF_CHECK(
        comm.is_available(), "Communicator not available for symmetric memory");
    c10d::symmetric_memory::set_group_info(
        kPyTorchSymmMemGroupName,
        static_cast<int>(comm.deviceId()),
        static_cast<int>(comm.size()),
        comm.getStore());
    initialized_backend = backend;
  });
  NVF_CHECK(
      initialized_backend == backend,
      "PyTorch symmetric memory backend was already initialized with a "
      "different backend; it cannot be changed within a process.");
}
#endif


// Returns the allocation granularity for symmetric memory.
// - query_mcast_granularity: if true, considers multicast granularity
// - query_mcast_recommended_granularity: if true, uses recommended (larger)
Expand Down Expand Up @@ -88,6 +142,48 @@ at::Tensor SymmetricTensor::allocate(
at::IntArrayRef sizes,
at::ScalarType dtype,
at::Device device) {
SymmetricMemoryBackend backend = getSymmetricMemoryBackend();

#ifdef NVFUSER_DISTRIBUTED
if (backend != SymmetricMemoryBackend::Native) {
ensurePyTorchSymmMemBackend(backend);
std::vector<int64_t> strides(sizes.size());
strides.back() = 1;
for (int64_t i = (int64_t)strides.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * sizes[i + 1];
}
// NCCLSymmetricMemoryAllocator::alloc must not be called with a group_name.
c10::optional<std::string> alloc_group_name =
(backend == SymmetricMemoryBackend::PyTorchNccl)
? c10::nullopt
: c10::optional<std::string>(kPyTorchSymmMemGroupName);
at::Tensor tensor = c10d::symmetric_memory::empty_strided_p2p(
sizes,
strides,
dtype,
device,
alloc_group_name,
c10::nullopt);
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> handle =
c10d::symmetric_memory::rendezvous(
tensor, c10::optional<std::string>(kPyTorchSymmMemGroupName));
void* key = tensor.storage().data_ptr().get();
{
std::lock_guard<std::mutex> lock(getPySymmHandleCacheMutex());
getPySymmHandleCache()[key] = handle;
}
return tensor;
}
#else
if (backend != SymmetricMemoryBackend::Native) {
NVF_ERROR(
false,
"PyTorch symmetric memory backend requires a build with "
"NVFUSER_DISTRIBUTED. Use NVFUSER_ENABLE=symmetric_memory_backend(native) "
"or do not set symmetric_memory_backend.");
}
#endif

int is_vmm_supported;
NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute(
&is_vmm_supported,
Expand Down Expand Up @@ -212,6 +308,28 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor)
"Expected CUDA tensor, got: ",
local_tensor.device());

#ifdef NVFUSER_DISTRIBUTED
{
void* key = local_tensor.storage().data_ptr().get();
std::lock_guard<std::mutex> lock(getPySymmHandleCacheMutex());
auto& cache = getPySymmHandleCache();
auto it = cache.find(key);
if (it != cache.end()) {
py_symm_handle_ = std::move(it->second);
cache.erase(it);
world_size_ = py_symm_handle_->get_world_size();
my_device_id_ = py_symm_handle_->get_rank();
requested_size_ = local_tensor.numel() * local_tensor.element_size();
are_remote_tensors_setup_ = true; // PyTorch rendezvous already set up
if (py_symm_handle_->has_multicast_support()) {
is_multicast_setup_ = true;
mc_ptr_ = py_symm_handle_->get_multicast_ptr();
}
return;
}
}
#endif

std::string error = SymmetricTensor::validate(local_tensor);
NVF_CHECK(error.empty(), "Invalid symmetric allocation: ", error);

Expand Down Expand Up @@ -253,6 +371,11 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor)
}

SymmetricTensor::~SymmetricTensor() {
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return; // PyTorch backend: no native VMM cleanup
}
#endif
#if (CUDA_VERSION >= 13000)
if (is_multicast_setup_) {
if (mc_base_ptr_) {
Expand Down Expand Up @@ -302,6 +425,11 @@ void SymmetricTensor::setupRemoteHandles(const std::string& tag) {
if (are_remote_tensors_setup_ == true) {
return;
}
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return; // PyTorch backend: rendezvous already established remote access
}
#endif
Communicator& comm = Communicator::getInstance();
CUmemGenericAllocationHandle local_handle = alloc_handles_[my_device_id_];
CUdeviceptr local_ptr = remote_ptrs_[my_device_id_];
Expand Down Expand Up @@ -379,6 +507,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const {
return local_tensor_;
}

#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return py_symm_handle_->get_remote_tensor(
rank, local_tensor_.sizes(), local_tensor_.scalar_type());
}
#endif

NVF_CHECK(are_remote_tensors_setup_ == true, "Remote tensors not setup");
return at::from_blob(
reinterpret_cast<void*>(remote_ptrs_[rank]),
Expand All @@ -390,6 +525,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const {
}

void* SymmetricTensor::multicastPtr() const {
#ifdef NVFUSER_DISTRIBUTED
  // PyTorch backend: delegate to the rendezvous handle. Note the contract
  // asymmetry with the native path: this branch yields nullptr when
  // multicast is unsupported, whereas the native path throws when multicast
  // was not set up.
  if (py_symm_handle_) {
    if (!py_symm_handle_->has_multicast_support()) {
      return nullptr;
    }
    return py_symm_handle_->get_multicast_ptr();
  }
#endif
  // Native backend: setupMulticast() must have succeeded first.
  NVF_CHECK(is_multicast_setup_, "Multicast not setup");
  return mc_ptr_;
}
Expand All @@ -398,7 +540,14 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) {
if (is_contiguous_view_setup_) {
return;
}

#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
NVF_ERROR(
false,
"Contiguous view is not yet supported for PyTorch symmetric memory backend. "
"Use native backend for SymmetricContiguousView.");
}
#endif
NVF_CHECK(
are_remote_tensors_setup_ == true,
"Remote tensors must be setup before setupContiguousView");
Expand Down Expand Up @@ -462,13 +611,25 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) {
}

// Returns the rank-spanning contiguous view tensor. Native backend only;
// requires a prior successful setupContiguousView() call.
at::Tensor SymmetricTensor::getContiguousView() const {
#ifdef NVFUSER_DISTRIBUTED
  // The PyTorch symmetric memory backend has no equivalent of the native
  // VMM-mapped contiguous view, so this is unconditionally an error there.
  if (py_symm_handle_) {
    NVF_ERROR(
        false,
        "Contiguous view is not yet supported for PyTorch symmetric memory backend.");
  }
#endif
  NVF_CHECK(is_contiguous_view_setup_, "Contiguous view not setup");
  return contiguous_view_;
}

void SymmetricTensor::setupMulticast(
int64_t exporter_rank,
const std::string& tag) {
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return; // PyTorch backend: multicast handled by backend if supported
}
#endif
#if (CUDA_VERSION >= 13000)
if (is_multicast_setup_) {
return;
Expand Down
21 changes: 15 additions & 6 deletions csrc/multidevice/symmetric_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include <ATen/core/Tensor.h>
#include <cuda.h>

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
#endif

namespace nvfuser {

// SymmetricTensor wraps a local symmetric memory allocation and enables:
Expand All @@ -18,13 +22,14 @@ namespace nvfuser {
// - Contiguous view creation across all ranks
//
// Design: Decouples local allocation from IPC handle exchange for better
// interoperability and support for pre-allocated user buffers
// interoperability and support for pre-allocated user buffers.
//
// TODO: Long term plan is to integrate pytorch's native symmetric memory as a
// possible backend. One important reason to use pytorch's allocator is to use
// pytorch's memory pool to let the framework own the memory stack and not
// further fragment the memory. On the other hand, having our own implementation
// allows us to experiment more advanced features like contigous view creation.
// Backends (see SymmetricMemoryBackend in ipc_utils.h):
// - Native (default): Fuser's own CUDA VMM + IPC implementation; maintained.
// - PyTorch (Nccl, Nvshmem, Cuda): Use PyTorch's symmetric memory
// (torch.distributed._symmetric_memory) with the chosen transport backend.
// Select via NVFUSER_ENABLE=symmetric_memory_backend(pytorch_nccl|pytorch_nvshmem|pytorch_cuda).
// Native remains the default when the option is not set.
class SymmetricTensor {
public:
// Wrap pre-allocated symmetric tensor (must use allocate())
Expand Down Expand Up @@ -79,6 +84,10 @@ class SymmetricTensor {
int peer_fd_{-1};
bool is_contiguous_view_setup_ = false;
at::Tensor contiguous_view_;
#ifdef NVFUSER_DISTRIBUTED
// When set, remote/multicast APIs delegate to PyTorch symmetric memory.
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> py_symm_handle_;
#endif
};

} // namespace nvfuser
1 change: 1 addition & 0 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"fast_math", EnableOption::FastMath},
{"p2p_protocol", EnableOption::P2pProtocol},
{"multicast_protocol", EnableOption::MulticastProtocol},
{"symmetric_memory_backend", EnableOption::SymmetricMemoryBackend},
{"parallel_serde", EnableOption::ParallelSerde},
};
return available_options;
Expand Down
2 changes: 2 additions & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ enum class EnableOption {
P2pProtocol, //! Prescribe P2P protocol: put|get
MulticastProtocol, //! Prescribe multicast protocol:
//! memcpy|multimem|batch_memcpy
SymmetricMemoryBackend, //! Prescribe symmetric memory backend:
//! native|pytorch_nccl|pytorch_nvshmem|pytorch_cuda
ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
EndOfOption //! Placeholder for counting the number of elements
};
Expand Down
Loading
Loading