From 14fd212fa6ed76fa541f2e4d405675d6dc61e353 Mon Sep 17 00:00:00 2001 From: Sai Vishal Pothula Date: Fri, 27 Feb 2026 04:40:02 +0200 Subject: [PATCH] Initial implementation of symmetric memory backend for PyTorch --- csrc/multidevice/communicator.cpp | 16 ++ csrc/multidevice/communicator.h | 13 ++ csrc/multidevice/ipc_utils.cpp | 18 ++ csrc/multidevice/ipc_utils.h | 13 ++ csrc/multidevice/symmetric_tensor.cpp | 163 +++++++++++++++++- csrc/multidevice/symmetric_tensor.h | 21 ++- csrc/options.cpp | 1 + csrc/options.h | 2 + fbuild.sh | 24 +++ .../cpp/test_multidevice_symmetric_tensor.cpp | 108 ++++++++++++ 10 files changed, 372 insertions(+), 7 deletions(-) create mode 100755 fbuild.sh diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index dbd65ba4610..7531f7ff186 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -424,4 +424,20 @@ void Communicator::barrier(std::optional backend) { getWorld(backend)->barrier(options)->wait(); } +#ifdef NVFUSER_DISTRIBUTED +c10::intrusive_ptr Communicator::getStore() const { + return c10::intrusive_ptr(store_); +} + +c10::intrusive_ptr Communicator::getWorldBackendIntrusivePtr( + std::optional backend) { + std::vector all_ranks(size_); + std::iota(all_ranks.begin(), all_ranks.end(), 0); + CommunicatorBackend b = backend.value_or(default_backend_); + std::string team_key = getTeamKey(all_ranks, b); + (void)getBackendForTeam(all_ranks, backend, ""); + return backends_.at(team_key); +} +#endif + } // namespace nvfuser diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index b56e6fee3aa..45a7f9e4f4f 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -13,6 +13,7 @@ #ifdef NVFUSER_DISTRIBUTED #include +#include #include #include #else @@ -124,6 +125,18 @@ class NVF_API Communicator { return store_.get(); } +#ifdef NVFUSER_DISTRIBUTED + // Returns the store as an intrusive_ptr for use with PyTorch symmetric + // memory (c10d::symmetric_memory::set_group_info). + c10::intrusive_ptr getStore() const; + + // Returns the world backend as an intrusive_ptr so it can be registered with + // c10d::register_process_group (e.g. for PyTorch symmetric memory NCCL + // rendezvous, which resolves the group by name). + c10::intrusive_ptr getWorldBackendIntrusivePtr( + std::optional backend = std::nullopt); +#endif + private: Communicator( CommunicatorBackend backend = comm_backend_default, diff --git a/csrc/multidevice/ipc_utils.cpp b/csrc/multidevice/ipc_utils.cpp index 656b4ee5e24..01bdf949044 100644 --- a/csrc/multidevice/ipc_utils.cpp +++ b/csrc/multidevice/ipc_utils.cpp @@ -191,4 +191,22 @@ MulticastProtocol getMulticastProtocol() { return MulticastProtocol::BatchMemcpy; } +SymmetricMemoryBackend getSymmetricMemoryBackend() { + if (isOptionEnabled(EnableOption::SymmetricMemoryBackend)) { + if (hasEnableOptionArgument( + EnableOption::SymmetricMemoryBackend, "pytorch_nccl")) { + return SymmetricMemoryBackend::PyTorchNccl; + } + if (hasEnableOptionArgument( + EnableOption::SymmetricMemoryBackend, "pytorch_nvshmem")) { + return SymmetricMemoryBackend::PyTorchNvshmem; + } + if (hasEnableOptionArgument( + EnableOption::SymmetricMemoryBackend, "pytorch_cuda")) { + return SymmetricMemoryBackend::PyTorchCuda; + } + } + return SymmetricMemoryBackend::Native; +} + } // namespace nvfuser diff --git a/csrc/multidevice/ipc_utils.h b/csrc/multidevice/ipc_utils.h index bac466d74f8..0cfd6586e47 100644 --- a/csrc/multidevice/ipc_utils.h +++ b/csrc/multidevice/ipc_utils.h @@ -33,6 +33,19 @@ enum class MulticastProtocol { Memcpy, Multimem, BatchMemcpy }; MulticastProtocol getMulticastProtocol(); +// Backend for symmetric memory allocation and rendezvous. +// Native: Fuser's own CUDA VMM + IPC implementation (default, maintained). +// PyTorch*: Use PyTorch's symmetric memory (torch.distributed._symmetric_memory) +// with the given transport backend (Nccl, Nvshmem, or Cuda). +enum class SymmetricMemoryBackend { + Native, + PyTorchNccl, + PyTorchNvshmem, + PyTorchCuda, +}; + +SymmetricMemoryBackend getSymmetricMemoryBackend(); + // Creates a listening Unix domain socket bound to path. // If path starts with '@', it uses the abstract namespace (replaced with \0). // Returns the socket file descriptor. diff --git a/csrc/multidevice/symmetric_tensor.cpp b/csrc/multidevice/symmetric_tensor.cpp index 902ec5a3a32..6f8c3330515 100644 --- a/csrc/multidevice/symmetric_tensor.cpp +++ b/csrc/multidevice/symmetric_tensor.cpp @@ -7,6 +7,7 @@ // clang-format on #include "multidevice/symmetric_tensor.h" +#include #include #include "cuda_utils.h" @@ -15,10 +16,63 @@ #include "multidevice/ipc_utils.h" #include "multidevice/utils.h" +#ifdef NVFUSER_DISTRIBUTED +#include +#endif + namespace nvfuser { namespace { +#ifdef NVFUSER_DISTRIBUTED +const char* kPyTorchSymmMemGroupName = "nvfuser_symm"; + +// Cache: tensor storage data ptr -> PyTorch SymmetricMemory handle from +// rendezvous. Used so SymmetricTensor(tensor) can recover the handle. +std::unordered_map>& +getPySymmHandleCache() { + static std::unordered_map< + void*, + c10::intrusive_ptr> + cache; + return cache; +} + +std::mutex& getPySymmHandleCacheMutex() { + static std::mutex m; + return m; +} + +void ensurePyTorchSymmMemBackend(SymmetricMemoryBackend backend) { + static std::once_flag once; + std::call_once(once, [backend]() { + const char* name = nullptr; + switch (backend) { + case SymmetricMemoryBackend::PyTorchNccl: + name = "NCCL"; + break; + case SymmetricMemoryBackend::PyTorchNvshmem: + name = "NVSHMEM"; + break; + case SymmetricMemoryBackend::PyTorchCuda: + name = "CUDA"; + break; + default: + NVF_ERROR(false, "Unexpected PyTorch symmetric memory backend"); + } + c10d::symmetric_memory::set_backend(name); + Communicator& comm = Communicator::getInstance(); + NVF_CHECK(comm.is_available(), "Communicator not available for symmetric memory"); + c10d::symmetric_memory::set_group_info( + kPyTorchSymmMemGroupName, + static_cast(comm.deviceId()), + static_cast(comm.size()), + comm.getStore()); + }); +} +#endif + + // Returns the allocation granularity for symmetric memory. // - query_mcast_granularity: if true, considers multicast granularity // - query_mcast_recommended_granularity: if true, uses recommended (larger) @@ -88,6 +142,48 @@ at::Tensor SymmetricTensor::allocate( at::IntArrayRef sizes, at::ScalarType dtype, at::Device device) { + SymmetricMemoryBackend backend = getSymmetricMemoryBackend(); + +#ifdef NVFUSER_DISTRIBUTED + if (backend != SymmetricMemoryBackend::Native) { + ensurePyTorchSymmMemBackend(backend); + std::vector strides(sizes.size()); + strides.back() = 1; + for (int64_t i = (int64_t)strides.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + // NCCLSymmetricMemoryAllocator::alloc must not be called with a group_name. + c10::optional alloc_group_name = + (backend == SymmetricMemoryBackend::PyTorchNccl) + ? c10::nullopt + : c10::optional(kPyTorchSymmMemGroupName); + at::Tensor tensor = c10d::symmetric_memory::empty_strided_p2p( + sizes, + strides, + dtype, + device, + alloc_group_name, + c10::nullopt); + c10::intrusive_ptr handle = + c10d::symmetric_memory::rendezvous( + tensor, c10::optional(kPyTorchSymmMemGroupName)); + void* key = tensor.storage().data_ptr().get(); + { + std::lock_guard lock(getPySymmHandleCacheMutex()); + getPySymmHandleCache()[key] = handle; + } + return tensor; + } +#else + if (backend != SymmetricMemoryBackend::Native) { + NVF_ERROR( + false, + "PyTorch symmetric memory backend requires a build with " + "NVFUSER_DISTRIBUTED. Use NVFUSER_ENABLE=symmetric_memory_backend(native) " + "or do not set symmetric_memory_backend."); + } +#endif + int is_vmm_supported; NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( &is_vmm_supported, @@ -212,6 +308,28 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor) "Expected CUDA tensor, got: ", local_tensor.device()); +#ifdef NVFUSER_DISTRIBUTED + { + void* key = local_tensor.storage().data_ptr().get(); + std::lock_guard lock(getPySymmHandleCacheMutex()); + auto& cache = getPySymmHandleCache(); + auto it = cache.find(key); + if (it != cache.end()) { + py_symm_handle_ = std::move(it->second); + cache.erase(it); + world_size_ = py_symm_handle_->get_world_size(); + my_device_id_ = py_symm_handle_->get_rank(); + requested_size_ = local_tensor.numel() * local_tensor.element_size(); + are_remote_tensors_setup_ = true; // PyTorch rendezvous already set up + if (py_symm_handle_->has_multicast_support()) { + is_multicast_setup_ = true; + mc_ptr_ = py_symm_handle_->get_multicast_ptr(); + } + return; + } + } +#endif + std::string error = SymmetricTensor::validate(local_tensor); NVF_CHECK(error.empty(), "Invalid symmetric allocation: ", error); @@ -253,6 +371,11 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor) } SymmetricTensor::~SymmetricTensor() { +#ifdef NVFUSER_DISTRIBUTED + if (py_symm_handle_) { + return; // PyTorch backend: no native VMM cleanup + } +#endif #if (CUDA_VERSION >= 13000) if (is_multicast_setup_) { if (mc_base_ptr_) { @@ -302,6 +425,11 @@ void SymmetricTensor::setupRemoteHandles(const std::string& tag) { if (are_remote_tensors_setup_ == true) { return; } +#ifdef NVFUSER_DISTRIBUTED + if (py_symm_handle_) { + return; // PyTorch backend: rendezvous already established remote access + } +#endif Communicator& comm = Communicator::getInstance(); CUmemGenericAllocationHandle local_handle = alloc_handles_[my_device_id_]; CUdeviceptr local_ptr = remote_ptrs_[my_device_id_]; @@ -379,6 +507,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const { return local_tensor_; } +#ifdef NVFUSER_DISTRIBUTED + if (py_symm_handle_) { + return py_symm_handle_->get_remote_tensor( + rank, local_tensor_.sizes(), local_tensor_.scalar_type()); + } +#endif + NVF_CHECK(are_remote_tensors_setup_ == true, "Remote tensors not setup"); return at::from_blob( reinterpret_cast(remote_ptrs_[rank]), @@ -390,6 +525,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const { } void* SymmetricTensor::multicastPtr() const { +#ifdef NVFUSER_DISTRIBUTED + if (py_symm_handle_) { + return py_symm_handle_->has_multicast_support() + ? py_symm_handle_->get_multicast_ptr() + : nullptr; + } +#endif NVF_CHECK(is_multicast_setup_, "Multicast not setup"); return mc_ptr_; } @@ -398,7 +540,14 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) { if (is_contiguous_view_setup_) { return; } - +#ifdef NVFUSER_DISTRIBUTED + if (py_symm_handle_) { + NVF_ERROR( + false, + "Contiguous view is not yet supported for PyTorch symmetric memory backend. " + "Use native backend for SymmetricContiguousView."); + } +#endif NVF_CHECK( are_remote_tensors_setup_ == true, "Remote tensors must be setup before setupContiguousView"); @@ -462,6 +611,13 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) { } at::Tensor SymmetricTensor::getContiguousView() const { +#ifdef NVFUSER_DISTRIBUTED + if (py_symm_handle_) { + NVF_ERROR( + false, + "Contiguous view is not yet supported for PyTorch symmetric memory backend."); + } +#endif NVF_CHECK(is_contiguous_view_setup_, "Contiguous view not setup"); return contiguous_view_; } @@ -469,6 +625,11 @@ at::Tensor SymmetricTensor::getContiguousView() const { void SymmetricTensor::setupMulticast( int64_t exporter_rank, const std::string& tag) { +#ifdef NVFUSER_DISTRIBUTED + if (py_symm_handle_) { + return; // PyTorch backend: multicast handled by backend if supported + } +#endif #if (CUDA_VERSION >= 13000) if (is_multicast_setup_) { return; diff --git a/csrc/multidevice/symmetric_tensor.h b/csrc/multidevice/symmetric_tensor.h index 5608153e0ce..c928a7d5469 100644 --- a/csrc/multidevice/symmetric_tensor.h +++ b/csrc/multidevice/symmetric_tensor.h @@ -10,6 +10,10 @@ #include #include +#ifdef NVFUSER_DISTRIBUTED +#include +#endif + namespace nvfuser { // SymmetricTensor wraps a local symmetric memory allocation and enables: @@ -18,13 +22,14 @@ namespace nvfuser { // - Contiguous view creation across all ranks // // Design: Decouples local allocation from IPC handle exchange for better -// interoperability and support for pre-allocated user buffers +// interoperability and support for pre-allocated user buffers. // -// TODO: Long term plan is to integrate pytorch's native symmetric memory as a -// possible backend. One important reason to use pytorch's allocator is to use -// pytorch's memory pool to let the framework own the memory stack and not -// further fragment the memory. On the other hand, having our own implementation -// allows us to experiment more advanced features like contigous view creation. +// Backends (see SymmetricMemoryBackend in ipc_utils.h): +// - Native (default): Fuser's own CUDA VMM + IPC implementation; maintained. +// - PyTorch (Nccl, Nvshmem, Cuda): Use PyTorch's symmetric memory +// (torch.distributed._symmetric_memory) with the chosen transport backend. +// Select via NVFUSER_ENABLE=symmetric_memory_backend(pytorch_nccl|pytorch_nvshmem|pytorch_cuda). +// Native remains the default when the option is not set. class SymmetricTensor { public: // Wrap pre-allocated symmetric tensor (must use allocate()) @@ -79,6 +84,10 @@ class SymmetricTensor { int peer_fd_{-1}; bool is_contiguous_view_setup_ = false; at::Tensor contiguous_view_; +#ifdef NVFUSER_DISTRIBUTED + // When set, remote/multicast APIs delegate to PyTorch symmetric memory. + c10::intrusive_ptr py_symm_handle_; +#endif }; } // namespace nvfuser diff --git a/csrc/options.cpp b/csrc/options.cpp index 6d587e35afd..9197b43ee3e 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -183,6 +183,7 @@ const std::unordered_map& getEnableOptions() { {"fast_math", EnableOption::FastMath}, {"p2p_protocol", EnableOption::P2pProtocol}, {"multicast_protocol", EnableOption::MulticastProtocol}, + {"symmetric_memory_backend", EnableOption::SymmetricMemoryBackend}, {"parallel_serde", EnableOption::ParallelSerde}, }; return available_options; diff --git a/csrc/options.h b/csrc/options.h index 4c72c757460..bda66b1f526 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -128,6 +128,8 @@ enum class EnableOption { P2pProtocol, //! Prescribe P2P protocol: put|get MulticastProtocol, //! Prescribe multicast protocol: //! memcpy|multimem|batch_memcpy + SymmetricMemoryBackend, //! Prescribe symmetric memory backend: + //! native|pytorch_nccl|pytorch_nvshmem|pytorch_cuda ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel EndOfOption //! Placeholder for counting the number of elements }; diff --git a/fbuild.sh b/fbuild.sh new file mode 100755 index 00000000000..e16a2e2cdd9 --- /dev/null +++ b/fbuild.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +export CC=clang-20 +export CXX=clang++-20 +export LDFLAGS="-fuse-ld=mold" + +export NVFUSER_BUILD_ENABLE_PCH + +export UCC_HOME="/opt/hpcx/ucc" +export UCC_DIR="/opt/hpcx/ucc/lib/cmake/ucc" +export UCX_HOME="/opt/hpcx/ucx" +export UCX_DIR="/opt/hpcx/ucx/lib/cmake/ucx" + +# export TORCH_CUDA_ARCH_LIST="9.0" + +export NVFUSER_BUILD_WITH_UCC=1 +export NVFUSER_BUILD_INSTALL_DIR=$BUILD_DIRECTORY/nvfuser +export NVFUSER_BUILD_DIR=$BUILD_DIRECTORY + +# Enable debug mode, leave empty for non-debug compilation +export NVFUSER_BUILD_BUILD_TYPE=Debug +export RUN_CMAKE="" + +pip install -v -e ./python --no-build-isolation diff --git a/tests/cpp/test_multidevice_symmetric_tensor.cpp b/tests/cpp/test_multidevice_symmetric_tensor.cpp index 2e4b5e66767..98af2bd6fbd 100644 --- a/tests/cpp/test_multidevice_symmetric_tensor.cpp +++ b/tests/cpp/test_multidevice_symmetric_tensor.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include "multidevice/ipc_utils.h" #include "multidevice/symmetric_tensor.h" #include "tests/cpp/multidevice.h" @@ -12,6 +13,68 @@ namespace nvfuser { using SymmetricTensorTest = MultiDeviceTest; +// ----------------------------------------------------------------------------- +// Symmetric memory backend and option tests +// ----------------------------------------------------------------------------- + +TEST_F(SymmetricTensorTest, GetSymmetricMemoryBackend_ReturnsValidBackend) { + SymmetricMemoryBackend backend = getSymmetricMemoryBackend(); + EXPECT_TRUE( + backend == SymmetricMemoryBackend::Native || + backend == SymmetricMemoryBackend::PyTorchNccl || + backend == SymmetricMemoryBackend::PyTorchNvshmem || + backend == SymmetricMemoryBackend::PyTorchCuda) + << "getSymmetricMemoryBackend() returned an invalid backend value"; +} + +TEST_F(SymmetricTensorTest, Validate_RejectsNormalCudaTensor) { + if (communicator_->size() == 1) { + GTEST_SKIP() << "Skipping test for single device (Native allocate needs VMM)"; + } + // With Native backend, allocate() uses VMM; with PyTorch backend we skip + // because validate() is only used for Native path (PyTorch tensors come from + // cache in constructor). + if (getSymmetricMemoryBackend() != SymmetricMemoryBackend::Native) { + GTEST_SKIP() << "Validate test applies to Native backend only"; + } + // Allocate a normal (non-symmetric) CUDA tensor + at::Tensor normal_tensor = at::empty( + {64, 64}, + at::TensorOptions().dtype(at::kFloat).device(communicator_->device())); + std::string error = SymmetricTensor::validate(normal_tensor); + EXPECT_FALSE(error.empty()) + << "SymmetricTensor::validate() should reject a normal CUDA tensor"; +} + +TEST_F(SymmetricTensorTest, Validate_AcceptsSymmetricAllocation) { + if (communicator_->size() == 1) { + GTEST_SKIP() << "Skipping test for single device"; + } + if (getSymmetricMemoryBackend() != SymmetricMemoryBackend::Native) { + GTEST_SKIP() << "validate() for allocate() output is defined for Native backend only"; + } + at::Tensor sym_tensor = SymmetricTensor::allocate( + {128, 128}, at::ScalarType::Float, communicator_->device()); + std::string error = SymmetricTensor::validate(sym_tensor); + EXPECT_TRUE(error.empty()) + << "SymmetricTensor::validate() should accept tensor from allocate(); got: " + << error; +} + +TEST_F(SymmetricTensorTest, Constructor_ThrowsOnInvalidTensor) { + at::Tensor normal_tensor = at::empty( + {8, 8}, + at::TensorOptions().dtype(at::kFloat).device(communicator_->device())); + EXPECT_THROW( + { SymmetricTensor sym_tensor(normal_tensor); }, + c10::Error) + << "SymmetricTensor constructor should throw when given a non-symmetric tensor"; +} + +// ----------------------------------------------------------------------------- +// Backend-agnostic and Native backend correctness (allocate + remote access) +// ----------------------------------------------------------------------------- + TEST_F(SymmetricTensorTest, BasicAllocation) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; @@ -54,6 +117,51 @@ TEST_F(SymmetricTensorTest, BasicAllocation) { } } +// Same remote-access correctness as BasicAllocation but only runs when +// PyTorch symmetric memory backend is selected (NVFUSER_ENABLE= +// symmetric_memory_backend(pytorch_nccl|pytorch_nvshmem|pytorch_cuda)). +// Run with e.g. NVFUSER_ENABLE=symmetric_memory_backend(pytorch_nccl) to +// exercise the PyTorch path. +TEST_F(SymmetricTensorTest, PyTorchBackend_RemoteAccessCorrectness) { + if (communicator_->size() == 1) { + GTEST_SKIP() << "Skipping test for single device"; + } + SymmetricMemoryBackend backend = getSymmetricMemoryBackend(); + if (backend == SymmetricMemoryBackend::Native) { + GTEST_SKIP() + << "PyTorch backend not selected; set NVFUSER_ENABLE=symmetric_memory_backend(pytorch_nccl) to run"; + } + + const int64_t rank = communicator_->deviceId(); + const int64_t world_size = communicator_->size(); + + at::Tensor local_tensor = SymmetricTensor::allocate( + {256, 512}, at::ScalarType::Float, communicator_->device()); + SymmetricTensor sym_tensor(local_tensor); + + EXPECT_TRUE(local_tensor.is_cuda()); + EXPECT_EQ(local_tensor.numel(), 256 * 512); + + float local_value = static_cast(rank + 200); + local_tensor.fill_(local_value); + + sym_tensor.setupRemoteHandles(); + + for (int64_t peer_rank = 0; peer_rank < world_size; ++peer_rank) { + void* peer_ptr = sym_tensor.remoteTensor(peer_rank).data_ptr(); + EXPECT_NE(peer_ptr, nullptr); + + float peer_value; + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + &peer_value, peer_ptr, sizeof(float), cudaMemcpyDeviceToHost)); + + float expected_value = static_cast(peer_rank + 200); + EXPECT_FLOAT_EQ(peer_value, expected_value) + << "Rank " << rank << " reading from rank " << peer_rank + << " (PyTorch backend)"; + } +} + TEST_F(SymmetricTensorTest, PreallocatedTensor) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device";