diff --git a/CMakeLists.txt b/CMakeLists.txt index 44d9c964342..296be226ae6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,7 @@ set(NVFUSER_CUTLASS "${NVFUSER_ROOT}/cutlass") set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party") option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) +option(NVFUSER_STANDALONE_BUILD_WITH_NIXL "" OFF) option(NVFUSER_EXPLICIT_ERROR_CHECK "" OFF) option(NVFUSER_ENABLE_DEPENDENCY_REPORT "Enable Python-based dependency reporting and log capture" ON) @@ -248,6 +249,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp ${NVFUSER_SRCS_DIR}/multidevice/executor.cpp + ${NVFUSER_SRCS_DIR}/multidevice/nixl.cpp ${NVFUSER_SRCS_DIR}/multidevice/execution_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/propagation.cpp ${NVFUSER_SRCS_DIR}/multidevice/resharding.cpp @@ -583,6 +585,37 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) target_compile_definitions(codegen_internal PRIVATE NVFUSER_BUILD_WITH_UCC) endif() +if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) + # User may need to set NIXL_PREFIX to the NIXL install directory. + find_path(NIXL_INCLUDE_DIR nixl.h + HINTS $ENV{NIXL_PREFIX}/include ENV CPATH + ) + find_library(NIXL_LIBRARY nixl + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + find_library(NIXL_BUILD_LIBRARY nixl_build + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + + if(NOT NIXL_INCLUDE_DIR OR NOT NIXL_LIBRARY) + message(FATAL_ERROR "NIXL not found. 
Set NIXL_PREFIX to the NIXL install directory.") + endif() + + message(STATUS "Found NIXL: ${NIXL_LIBRARY} (include: ${NIXL_INCLUDE_DIR})") + if(NIXL_BUILD_LIBRARY) + message(STATUS "Found NIXL build lib: ${NIXL_BUILD_LIBRARY}") + endif() + + add_library(__nvfuser_nixl INTERFACE) + target_include_directories(__nvfuser_nixl INTERFACE ${NIXL_INCLUDE_DIR}) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_LIBRARY}) + if(NIXL_BUILD_LIBRARY) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_BUILD_LIBRARY}) + endif() + target_link_libraries(codegen_internal PRIVATE __nvfuser_nixl) + target_compile_definitions(codegen_internal PRIVATE USE_NIXL) +endif() + add_dependencies(codegen_internal flatc build_flatbuffer_config) # installing nvfuser headers @@ -1031,6 +1064,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_nixl.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_stream_parallel_type.cpp @@ -1332,6 +1366,11 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) message(STATUS " UCX_DIR : $ENV{UCX_DIR}") endif() message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") +message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_NIXL : ${NVFUSER_STANDALONE_BUILD_WITH_NIXL}") +if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) + message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") + message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") +endif() message(STATUS " NVFUSER_BUILD_WITH_ASAN : ${NVFUSER_BUILD_WITH_ASAN}") message(STATUS " NVFUSER_DISTRIBUTED : ${NVFUSER_DISTRIBUTED}") message(STATUS " NVFUSER_CPP_STANDARD : ${NVFUSER_CPP_STANDARD}") diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index 
dbd65ba4610..c021a129670 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -41,6 +41,9 @@ std::ostream& operator<<(std::ostream& out, const CommunicatorBackend& cb) { case CommunicatorBackend::kCuda: out << "CUDA"; break; + case CommunicatorBackend::kNixl: + out << "NIXL"; + break; } return out; } @@ -183,7 +186,8 @@ Communicator::Communicator( master_port_( c10d::TCPStoreOptions::kDefaultPort + 42), // to avoid collision ucc_available_(false), - nccl_available_(false) { + nccl_available_(false), + nixl_available_(false) { if (isOptionDisabled(DisableOption::Multidevice)) { TORCH_WARN( "Multi-device support is disabled. All communication operations will " @@ -236,6 +240,10 @@ Communicator::Communicator( #ifdef USE_C10D_NCCL nccl_available_ = true; #endif + +#ifdef USE_NIXL + nixl_available_ = true; +#endif } namespace { diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index b56e6fee3aa..25fed9eebfc 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -11,6 +11,8 @@ #include #include +#include + #ifdef NVFUSER_DISTRIBUTED #include #include @@ -116,13 +118,15 @@ class NVF_API Communicator { return ucc_available_; } else if (backend == CommunicatorBackend::kNccl) { return nccl_available_; + } else if (backend == CommunicatorBackend::kNixl) { + return nixl_available_; } return false; } c10d::TCPStore* getTcpStore() { return store_.get(); - } +} private: Communicator( @@ -149,10 +153,12 @@ class NVF_API Communicator { int master_port_; bool ucc_available_; bool nccl_available_; + bool nixl_available_; // stores the world's store used for the backend init c10::intrusive_ptr store_; // cache for the created backends. 
The keys are strings generated from Teams std::unordered_map> backends_; }; + } // namespace nvfuser diff --git a/csrc/multidevice/multidevice.h b/csrc/multidevice/multidevice.h index 288a89fe952..7915f5e3d92 100644 --- a/csrc/multidevice/multidevice.h +++ b/csrc/multidevice/multidevice.h @@ -19,5 +19,5 @@ using DeviceType = c10::Device; using Team = std::vector; // Supported backends. -enum class CommunicatorBackend { kNccl, kUcc, kCuda }; +enum class CommunicatorBackend { kNccl, kUcc, kCuda, kNixl }; } // namespace nvfuser diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp new file mode 100644 index 00000000000..253298bc9bc --- /dev/null +++ b/csrc/multidevice/nixl.cpp @@ -0,0 +1,474 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include "multidevice/nixl.h" +#include "exceptions.h" + +#include +#include +#include +#include +#include + +#ifdef USE_NIXL +#include +#endif +namespace nvfuser { + +// =================================================================== +// NixlTransferHandle +// =================================================================== + +class NixlTransferHandleImpl { + public: +#ifdef USE_NIXL + explicit NixlTransferHandleImpl(nixlAgent* agent) : agent(agent) {} + nixlAgent* agent; + nixlXferReqH* xfer_handle = nullptr; + + ~NixlTransferHandleImpl() { + if (xfer_handle) { + agent->releaseXferReq(xfer_handle); + } + } +#endif + bool prepared = false; + bool posted = false; +}; + +NixlTransferHandle::NixlTransferHandle() = default; +NixlTransferHandle::~NixlTransferHandle() = default; +NixlTransferHandle::NixlTransferHandle(NixlTransferHandle&&) noexcept = + default; +NixlTransferHandle& NixlTransferHandle::operator=( + NixlTransferHandle&&) noexcept = default; + +bool NixlTransferHandle::isValid() const { + if (!impl_) { + return false; + } +#ifdef USE_NIXL + return 
impl_->prepared; +#else + return false; +#endif +} + +// =================================================================== +// Tensor validation and descriptor helpers +// =================================================================== + +namespace { + +void validateCudaTensors(const std::vector& tensors) { + NVF_ERROR(!tensors.empty(), "Tensor list must not be empty"); + for (const auto& t : tensors) { + NVF_ERROR(t.is_cuda(), "All tensors must be CUDA tensors"); + NVF_ERROR(t.is_contiguous(), "All tensors must be contiguous"); + } +} + +#ifdef USE_NIXL + +nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { + nixl_reg_dlist_t dlist(VRAM_SEG); + for (const auto& t : tensors) { + dlist.addDesc( + {reinterpret_cast(t.data_ptr()), + static_cast(t.numel()) * t.element_size(), + static_cast(t.device().index())}); + } + return dlist; +} + +nixl_xfer_dlist_t buildXferDlist(const std::vector& descs) { + nixl_xfer_dlist_t dlist(VRAM_SEG); + for (const auto& desc : descs) { + dlist.addDesc({desc.addr, desc.size, desc.dev}); + } + return dlist; +} + +nixl_xfer_op_t toNixlXferOp(NixlXferOp op) { + switch (op) { + case NixlXferOp::kRead: + return NIXL_READ; + case NixlXferOp::kWrite: + return NIXL_WRITE; + } + NVF_THROW("Invalid NIXL transfer operation: ", static_cast(op)); +} + +#endif + +} // namespace + +// =================================================================== +// NixlBackend::Impl +// =================================================================== + +#ifdef USE_NIXL + +class NixlBackend::Impl { + public: + static std::unique_ptr create(Communicator& communicator); + ~Impl(); + + void registerTensors(const std::vector& tensors); + void deregisterTensors(const std::vector& tensors); + void exchangeMetadata(); + + NixlTransferHandle prepareTransfer( + const std::vector& local_descs, + const std::vector& remote_descs, + int64_t remote_rank, + NixlXferOp op); + + void postTransfer(NixlTransferHandle& handle); + NixlXferStatus 
getTransferStatus(const NixlTransferHandle& handle) const; + void waitTransfer(NixlTransferHandle& handle); + + private: + explicit Impl(Communicator& communicator); + inline std::string getAgentName(int64_t rank); + + std::unique_ptr agent_; + nixlBackendH* backend_ = nullptr; + Communicator& communicator_; + bool metadata_exchanged_ = false; +}; + +// ------------------------------------------------------------------- +// Construction / destruction +// ------------------------------------------------------------------- + +NixlBackend::Impl::Impl(Communicator& communicator) + : communicator_(communicator) {} + +std::unique_ptr NixlBackend::Impl::create( + Communicator& communicator) { + std::unique_ptr impl(new Impl(communicator)); + + std::string agent_name = impl->getAgentName(communicator.deviceId()); + nixlAgentConfig cfg(false); + impl->agent_ = std::make_unique(agent_name, cfg); + + nixl_b_params_t params; + nixl_status_t status = + impl->agent_->createBackend("UCX", params, impl->backend_); + if (status != NIXL_SUCCESS) { + impl->agent_.reset(); + NVF_THROW("Failed to create UCX backend for NIXL agent"); + } + + // Probe: verify that VRAM (CUDA GPU memory) is actually usable with + // the UCX backend. Some UCX installations lack CUDA support, causing + // registerMem to silently misclassify VRAM as host memory. We detect + // this by registering a small buffer and asking NIXL to prepare a + // local descriptor list for VRAM -- if no backend claims VRAM, the + // probe fails and we mark the backend as unavailable. 
+ { + constexpr int64_t kProbeBytes = 1; + auto probe = at::empty( + {kProbeBytes}, + at::TensorOptions().dtype(at::kByte).device( + at::kCUDA, communicator.deviceId())); + size_t nbytes = static_cast(probe.nbytes()); + uintptr_t addr = reinterpret_cast(probe.data_ptr()); + uint32_t dev_idx = static_cast(probe.device().index()); + + NVF_ERROR(nbytes > 0, "NIXL probe: unexpected zero-byte tensor"); + NVF_ERROR(addr != 0, "NIXL probe: null data pointer"); + + nixl_reg_dlist_t reg_dlist(VRAM_SEG); + reg_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); + + nixl_status_t reg_status = impl->agent_->registerMem(reg_dlist); + if (reg_status != NIXL_SUCCESS) { + return nullptr; + } + + nixl_xfer_dlist_t xfer_dlist(VRAM_SEG); + xfer_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); + + nixlDlistH* dlist_handle = nullptr; + nixl_status_t prep_status = + impl->agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); + + if (dlist_handle) { + impl->agent_->releasedDlistH(dlist_handle); + } + impl->agent_->deregisterMem(reg_dlist); + + if (prep_status != NIXL_SUCCESS) { + return nullptr; + } + } + + return impl; +} + +NixlBackend::Impl::~Impl() = default; + +std::string NixlBackend::Impl::getAgentName(int64_t rank) { + return "rank_" + std::to_string(rank); +} + +// ------------------------------------------------------------------- +// Memory registration +// ------------------------------------------------------------------- + +// TODO - consider adding RAII wrapper +void NixlBackend::Impl::registerTensors( + const std::vector& tensors) { + validateCudaTensors(tensors); + + nixl_reg_dlist_t dlist = buildRegDlist(tensors); + nixl_status_t status = agent_->registerMem(dlist); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL registerMem failed with status ", + static_cast(status)); + + metadata_exchanged_ = false; +} + +void NixlBackend::Impl::deregisterTensors( + const std::vector& tensors) { + validateCudaTensors(tensors); + + nixl_reg_dlist_t dlist = 
buildRegDlist(tensors); + nixl_status_t status = agent_->deregisterMem(dlist); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL deregisterMem failed with status ", + static_cast(status)); + + metadata_exchanged_ = false; +} + +// ------------------------------------------------------------------- +// Metadata exchange +// ------------------------------------------------------------------- + +void NixlBackend::Impl::exchangeMetadata() { + nixl_blob_t local_md; + nixl_status_t md_status = agent_->getLocalMD(local_md); + NVF_ERROR( + md_status == NIXL_SUCCESS, + "NIXL getLocalMD failed with status ", + static_cast(md_status)); + + auto* store = communicator_.getTcpStore(); + const auto my_rank = communicator_.deviceId(); + const auto world_size = communicator_.size(); + + std::string md_key_prefix = "nixl_agent_md_rank_"; + store->set( + md_key_prefix + std::to_string(my_rank), + std::vector(local_md.begin(), local_md.end())); + + for (int64_t rank = 0; rank < world_size; ++rank) { + if (rank == my_rank) { + continue; + } + // Fetch & load MD + auto bytes = store->get(md_key_prefix + std::to_string(rank)); + nixl_blob_t remote_md(bytes.begin(), bytes.end()); + std::string remote_agent_name; + nixl_status_t status = agent_->loadRemoteMD(remote_md, remote_agent_name); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL loadRemoteMD failed for rank ", + rank, + " with status ", + static_cast(status)); + } + + // Barrier before deleting keys so no rank reads a deleted key. + communicator_.barrier(); + + store->deleteKey(md_key_prefix + std::to_string(my_rank)); + metadata_exchanged_ = true; +} + +// ------------------------------------------------------------------- +// Transfer preparation +// ------------------------------------------------------------------- + +// Prepare a transfer between local and remote tensor pairs. +// +// NIXL pairs local_tensors[i] with remote_tensors[i]. 
The direction +// depends on `op`: +// kRead -- data flows from remote into local +// kWrite -- data flows from local into remote +// +NixlTransferHandle NixlBackend::Impl::prepareTransfer( + const std::vector& local_descs, + const std::vector& remote_descs, + int64_t remote_rank, + NixlXferOp op) { + NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); + NVF_ERROR( + local_descs.size() == remote_descs.size(), + "Local and remote tensor lists must have the same size. Got ", + local_descs.size(), + " vs ", + remote_descs.size()); + + std::string remote_agent_name = getAgentName(remote_rank); + + nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); + nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); + + auto impl = std::make_unique(agent_.get()); + nixl_status_t status = agent_->createXferReq( + toNixlXferOp(op), + local_dlist, + remote_dlist, + remote_agent_name, + impl->xfer_handle); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL createXferReq failed with status ", + static_cast(status)); + + impl->prepared = true; + NixlTransferHandle handle; + handle.impl_ = std::move(impl); + return handle; +} + +// ------------------------------------------------------------------- +// Transfer posting +// ------------------------------------------------------------------- + +void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { + NVF_ERROR(handle.isValid(), "Cannot post an invalid transfer handle"); + NVF_ERROR( + !handle.impl_->posted, + "Transfer already posted. 
Wait for completion before re-posting."); + + nixl_status_t status = agent_->postXferReq(handle.impl_->xfer_handle); + NVF_ERROR( + status == NIXL_SUCCESS || status == NIXL_IN_PROG, + "NIXL postXferReq failed with status ", + static_cast(status)); + + handle.impl_->posted = true; +} + +// ------------------------------------------------------------------- +// Transfer status / wait +// ------------------------------------------------------------------- + +NixlXferStatus NixlBackend::Impl::getTransferStatus( + const NixlTransferHandle& handle) const { + NVF_ERROR(handle.isValid(), "Cannot query status of an invalid handle"); + NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + + nixl_status_t status = agent_->getXferStatus(handle.impl_->xfer_handle); + switch (status) { + case NIXL_SUCCESS: + return NixlXferStatus::kDone; + case NIXL_IN_PROG: + return NixlXferStatus::kInProgress; + default: + return NixlXferStatus::kError; + } +} + +void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { + NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); + NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + + // TODO - check this spin loop + NixlXferStatus xfer_status; + do { + xfer_status = getTransferStatus(handle); + NVF_ERROR( + xfer_status != NixlXferStatus::kError, + "NIXL transfer completed with an error"); + } while (xfer_status == NixlXferStatus::kInProgress); + + handle.impl_->posted = false; +} + +#else // !USE_NIXL + +class NixlBackend::Impl {}; + +#endif // USE_NIXL + +// =================================================================== +// NixlBackend singleton + public API +// =================================================================== + +NixlBackend::NixlBackend() { +#ifdef USE_NIXL + impl_ = Impl::create(Communicator::getInstance()); +#endif +} + +NixlBackend& NixlBackend::getInstance() { + static auto* instance = new NixlBackend(); + NVF_CHECK(!instance->cleaned_up_, "NIXL backend has been 
cleaned up"); + return *instance; +} + +void NixlBackend::cleanup() { + cleaned_up_ = true; + impl_.reset(); +} + +bool NixlBackend::isAvailable() const { + return impl_ != nullptr; +} + +void NixlBackend::registerTensors(const std::vector& tensors) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); + impl_->registerTensors(tensors); +} + +void NixlBackend::deregisterTensors(const std::vector& tensors) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); + impl_->deregisterTensors(tensors); +} + +void NixlBackend::exchangeMetadata() { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); + impl_->exchangeMetadata(); +} + +NixlTransferHandle NixlBackend::prepareTransfer( + const std::vector& local_descs, + const std::vector& remote_descs, + int64_t remote_rank, + NixlXferOp op) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); + return impl_->prepareTransfer( + local_descs, remote_descs, remote_rank, op); +} + +void NixlBackend::postTransfer(NixlTransferHandle& handle) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); + impl_->postTransfer(handle); +} + +NixlXferStatus NixlBackend::getTransferStatus( + const NixlTransferHandle& handle) const { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); + return impl_->getTransferStatus(handle); +} + +void NixlBackend::waitTransfer(NixlTransferHandle& handle) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); + impl_->waitTransfer(handle); +} + + +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h new file mode 100644 index 00000000000..e9de9c8384b --- /dev/null +++ b/csrc/multidevice/nixl.h @@ -0,0 +1,221 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include +#include +#include + +#include "exceptions.h" +#include "multidevice/communicator.h" +#include "visibility.h" + +namespace nvfuser { + +// Transfer direction. NIXL uses a one-sided model: +// Read = pull remote data into local buffers +// Write = push local data into remote buffers +enum class NixlXferOp { + kRead, + kWrite, +}; + +enum class NixlXferStatus { + kDone, + kInProgress, + kError, +}; + +// ------------------------------------------------------------------ +// Todo - those functions should be moved to a more global file +// Helper functions for serializing and deserializing tensors descriptors for TCP store +struct TensorDesc { + uintptr_t addr; + size_t size; + uint32_t dev; +}; +static_assert(std::is_trivially_copyable_v, + "TensorDesc must be trivially copyable for serialization"); + +inline TensorDesc toTensorDesc(const at::Tensor& tensor) { + return { + .addr = reinterpret_cast(tensor.data_ptr()), + .size = static_cast(tensor.numel()) * tensor.element_size(), + .dev = static_cast(tensor.device().index()) + }; +} + +inline std::vector serializeTensorsDescs( + const std::vector& descs) { + size_t count = descs.size(); + std::vector buf(sizeof(count) + count * sizeof(TensorDesc)); + std::memcpy(buf.data(), &count, sizeof(count)); + if (count == 0) + return buf; + + std::memcpy( + buf.data() + sizeof(count), + descs.data(), + descs.size() * sizeof(TensorDesc)); + return buf; +} + +inline std::vector deserializeTensorsDescs( + const std::vector& buf) { + NVF_ERROR(buf.size() >= sizeof(size_t), "Invalid serialized descriptor data"); + size_t count; + std::memcpy(&count, buf.data(), sizeof(count)); + NVF_ERROR( + buf.size() == sizeof(count) + count * sizeof(TensorDesc), + "Corrupted serialized descriptor data"); + + std::vector descs(count); + if (count > 0) { + std::memcpy( + descs.data(), + buf.data() + sizeof(count), + count * 
sizeof(TensorDesc)); + } + return descs; +} + +inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& descs) { + NVF_CHECK(communicator.is_available(), "Communicator is not available"); + communicator.getTcpStore()->set(key, serializeTensorsDescs(descs)); +} + +inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& tensors) { + std::vector descs; + descs.reserve(tensors.size()); + for (const auto& tensor : tensors) { + descs.push_back(toTensorDesc(tensor)); + } + storeTensorDescs(communicator, key, descs); +} + +inline std::vector fetchTensorDescs(Communicator& communicator, const std::string& key) { + NVF_CHECK(communicator.is_available(), "Communicator is not available"); + auto bytes = communicator.getTcpStore()->get(key); + return deserializeTensorsDescs(bytes); +} + +// End of Todo - those functions should be moved to a more global file +// ------------------------------------------------------------------ + +// ------------------------------------------------------------------- +// NixlTransferHandle: opaque handle for a prepared transfer +// ------------------------------------------------------------------- +// Returned by NixlBackend::prepareTransfer(). Callers hold this handle +// and pass it to postTransfer() / waitTransfer(). The actual NIXL +// transfer handle lives inside the impl; this is just an owning wrapper. 
+class NixlTransferHandleImpl; + +class NVF_API NixlTransferHandle { + public: + NixlTransferHandle(); + ~NixlTransferHandle(); + NixlTransferHandle(NixlTransferHandle&&) noexcept; + NixlTransferHandle& operator=(NixlTransferHandle&&) noexcept; + + NixlTransferHandle(const NixlTransferHandle&) = delete; + NixlTransferHandle& operator=(const NixlTransferHandle&) = delete; + + [[nodiscard]] bool isValid() const; + + private: + friend class NixlBackend; + std::unique_ptr impl_; +}; + +// ------------------------------------------------------------------- +// NixlBackend: singleton NIXL backend over UCX for GPU tensors +// ------------------------------------------------------------------- +// Singleton - Wraps a nixlAgent with the UCX backend and provides a tensor-level +// API for registering GPU memory and performing RDMA transfers. +// +// Lifecycle: +// 1. getInstance() - creates agent, loads UCX backend +// 2. registerTensors() - register GPU tensors for RDMA access +// 3. exchangeMetadata() - all ranks share their registration info +// 4. prepareTransfer() - expensive one-time setup per transfer pattern +// 5. postTransfer() - cheap, non-blocking data movement +// 6. waitTransfer() - block until complete +// +// Thread safety: methods are NOT thread-safe. The caller must +// synchronize if the same NixlBackend is used from multiple threads. +class NixlBackend { + public: + static NixlBackend& getInstance(); + + NixlBackend(const NixlBackend&) = delete; + NixlBackend& operator=(const NixlBackend&) = delete; + ~NixlBackend() = delete; + + // Explicitly tear down the singleton. Must be called before program + // exit (same pattern as Communicator::cleanup). 
+ void cleanup(); + + [[nodiscard]] bool isAvailable() const; + + // ------------------------------------------------------------------ + // Memory registration + // ------------------------------------------------------------------ + + // Register CUDA tensors with the NIXL agent so they can participate + // in RDMA transfers. Tensors must be contiguous and remain alive + // until deregisterTensors() is called. + void registerTensors(const std::vector& tensors); + + void deregisterTensors(const std::vector& tensors); + + // ------------------------------------------------------------------ + // Metadata exchange + // ------------------------------------------------------------------ + // Exchange local agent metadata with all peers through the TCPStore. + // Must be called after registerTensors() and before prepareTransfer() + // whenever the set of registered tensors changes. + void exchangeMetadata(); + + // ------------------------------------------------------------------ + // Transfer lifecycle + // ------------------------------------------------------------------ + + // Prepare a transfer between pairs of tensors. + // local_tensors[i] and remote_tensors[i] must have the same byte size. + // All tensors must be contiguous CUDA tensors and previously registered. + // The returned handle can be posted multiple times (preparation is + // amortized). + [[nodiscard]] NixlTransferHandle prepareTransfer( + const std::vector& local_descs, + const std::vector& remote_descs, + int64_t remote_rank, + NixlXferOp op); + + // Post a previously prepared transfer for execution (non-blocking). + void postTransfer(NixlTransferHandle& handle); + + // Poll the status of a posted transfer without blocking. + [[nodiscard]] NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; + + // Block until the transfer completes (or errors out). 
+ void waitTransfer(NixlTransferHandle& handle); + + private: + NixlBackend(); + bool cleaned_up_ = false; + + class Impl; + std::unique_ptr impl_; +}; + + + + +} // namespace nvfuser diff --git a/python/setup.py b/python/setup.py index c7fe063afa7..7a77fc04fae 100644 --- a/python/setup.py +++ b/python/setup.py @@ -32,6 +32,9 @@ # NVFUSER_BUILD_WITH_UCC # Build nvfuser with UCC support. You may need to specify environment variables of UCC_HOME, UCC_DIR, UCX_HOME, UCX_DIR. # +# NVFUSER_BUILD_WITH_NIXL +# Build nvfuser with NIXL support. You may need to set NIXL_PREFIX to the NIXL install directory. +# # NVFUSER_BUILD_WITHOUT_DISTRIBUTED # Build nvfuser without multidevice support # diff --git a/python/utils.py b/python/utils.py index a94d71ee8c9..5d589ebc868 100644 --- a/python/utils.py +++ b/python/utils.py @@ -21,6 +21,7 @@ class BuildConfig: no_benchmark: bool = False no_ninja: bool = False build_with_ucc: bool = False + build_with_nixl: bool = False build_with_asan: bool = False build_without_distributed: bool = False explicit_error_check: bool = False @@ -70,6 +71,8 @@ def override_build_config_from_env(config): config.no_ninja = get_env_flag_bool("NVFUSER_BUILD_NO_NINJA") if "NVFUSER_BUILD_WITH_UCC" in os.environ: config.build_with_ucc = get_env_flag_bool("NVFUSER_BUILD_WITH_UCC") + if "NVFUSER_BUILD_WITH_NIXL" in os.environ: + config.build_with_nixl = get_env_flag_bool("NVFUSER_BUILD_WITH_NIXL") if "NVFUSER_BUILD_WITH_ASAN" in os.environ: config.build_with_asan = get_env_flag_bool("NVFUSER_BUILD_WITH_ASAN") if "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" in os.environ: @@ -277,6 +280,7 @@ def on_or_off(flag: bool) -> str: f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", f"-DNVFUSER_BUILD_WITH_ASAN={on_or_off(config.build_with_asan)}", f"-DNVFUSER_STANDALONE_BUILD_WITH_UCC={on_or_off(config.build_with_ucc)}", + f"-DNVFUSER_STANDALONE_BUILD_WITH_NIXL={on_or_off(config.build_with_nixl)}", f"-DNVFUSER_EXPLICIT_ERROR_CHECK={on_or_off(config.explicit_error_check)}", 
f"-DBUILD_TEST={on_or_off(not config.no_test)}", f"-DBUILD_PYTHON={on_or_off(not config.no_python)}",
diff --git a/tests/cpp/test_multidevice_nixl.cpp b/tests/cpp/test_multidevice_nixl.cpp
new file mode 100644
index 00000000000..eb8de2ba3b8
--- /dev/null
+++ b/tests/cpp/test_multidevice_nixl.cpp
@@ -0,0 +1,289 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include <gtest/gtest.h>
+
+#include "multidevice/nixl.h"
+#include "tests/cpp/multidevice.h"
+
+namespace nvfuser {
+
+using NixlTest = MultiDeviceTest;
+
+// -------------------------------------------------------------------
+// NixlTransferHandle tests
+// -------------------------------------------------------------------
+
+// A default-constructed handle must report itself as invalid.
+TEST_F(NixlTest, TransferHandleDefaultConstruction) {
+  NixlTransferHandle handle;
+  EXPECT_FALSE(handle.isValid());
+}
+
+// Moving from an invalid handle must leave the destination invalid as well.
+TEST_F(NixlTest, TransferHandleMoveConstruction) {
+  NixlTransferHandle h1;
+  EXPECT_FALSE(h1.isValid());
+
+  NixlTransferHandle h2(std::move(h1));
+  EXPECT_FALSE(h2.isValid());
+}
+
+TEST_F(NixlTest, TransferHandleMoveAssignment) {
+  NixlTransferHandle h1;
+  NixlTransferHandle h2;
+  h2 = std::move(h1);
+  EXPECT_FALSE(h2.isValid());
+}
+
+// -------------------------------------------------------------------
+// NixlBackend singleton tests
+// -------------------------------------------------------------------
+
+TEST_F(NixlTest, SingletonIsAccessible) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  // isAvailable() returns true only when USE_NIXL is defined and the
+  // UCX backend loaded successfully. Either outcome is valid here.
+  (void)backend.isAvailable();
+}
+
+// -------------------------------------------------------------------
+// Input validation tests (these exercise the guards in the impl)
+// -------------------------------------------------------------------
+
+TEST_F(NixlTest, RegisterEmptyTensorListThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  std::vector<at::Tensor> empty;
+  EXPECT_THROW(backend.registerTensors(empty), nvfError);
+}
+
+TEST_F(NixlTest, RegisterCpuTensorThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  // A CPU tensor (no tensor_options_, i.e. not on the test's device) must be
+  // rejected by registration.
+  auto cpu_tensor = at::randn({64});
+  EXPECT_THROW(backend.registerTensors({cpu_tensor}), nvfError);
+}
+
+TEST_F(NixlTest, RegisterNonContiguousTensorThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  auto t = at::randn({8, 8}, tensor_options_);
+  auto non_contig = t.transpose(0, 1);
+  ASSERT_FALSE(non_contig.is_contiguous());
+  EXPECT_THROW(backend.registerTensors({non_contig}), nvfError);
+}
+
+TEST_F(NixlTest, DeregisterEmptyTensorListThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  std::vector<at::Tensor> empty;
+  EXPECT_THROW(backend.deregisterTensors(empty), nvfError);
+}
+
+// -------------------------------------------------------------------
+// Transfer preparation validation
+// -------------------------------------------------------------------
+
+TEST_F(NixlTest, PrepareTransferWithoutMetadataExchangeThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  auto local = at::randn({64}, tensor_options_);
+  auto remote = at::randn({64}, tensor_options_);
+
+  backend.registerTensors({local});
+  backend.registerTensors({remote});
+
+  // prepareTransfer() requires exchangeMetadata() to have been called first.
+  EXPECT_THROW(
+      (void)backend.prepareTransfer(
+          {toTensorDesc(local)}, {toTensorDesc(remote)}, 0, NixlXferOp::kRead),
+      nvfError);
+
+  backend.deregisterTensors({local});
+  backend.deregisterTensors({remote});
+}
+
+TEST_F(NixlTest, PrepareTransferMismatchedSizesThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  auto t1 = at::randn({64}, tensor_options_);
+  auto t2 = at::randn({64}, tensor_options_);
+  auto t3 = at::randn({64}, tensor_options_);
+  backend.registerTensors({t1, t2, t3});
+  backend.exchangeMetadata();
+
+  // Local and remote descriptor lists must have the same length (2 vs 1 here).
+  EXPECT_THROW(
+      (void)backend.prepareTransfer(
+          {toTensorDesc(t1), toTensorDesc(t2)},
+          {toTensorDesc(t3)},
+          0,
+          NixlXferOp::kRead),
+      nvfError);
+
+  backend.deregisterTensors({t1, t2, t3});
+}
+
+// -------------------------------------------------------------------
+// Post / wait on invalid handles
+// -------------------------------------------------------------------
+
+TEST_F(NixlTest, PostInvalidHandleThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  NixlTransferHandle invalid_handle;
+  EXPECT_THROW(backend.postTransfer(invalid_handle), nvfError);
+}
+
+TEST_F(NixlTest, WaitInvalidHandleThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  NixlTransferHandle invalid_handle;
+  EXPECT_THROW(backend.waitTransfer(invalid_handle), nvfError);
+}
+
+TEST_F(NixlTest, GetStatusInvalidHandleThrows) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  NixlTransferHandle invalid_handle;
+  EXPECT_THROW((void)backend.getTransferStatus(invalid_handle), nvfError);
+}
+
+// -------------------------------------------------------------------
+// End-to-end transfer test (requires >= 2 devices with NIXL)
+// -------------------------------------------------------------------
+
+TEST_F(NixlTest, ReadTransferEndToEnd) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+  if (communicator_->size() < 2) {
+    GTEST_SKIP() << "Need at least 2 devices for transfer test";
+  }
+
+  const int64_t rank = communicator_->deviceId();
+  const int64_t world_size = communicator_->size();
+  const int64_t peer_rank = (rank + 1) % world_size;
+  constexpr int64_t kSize = 1024;
+
+  // Ring-style transfer: each rank reads from its peer's remote tensor into
+  // its local dst tensor.
+  auto src = at::full({kSize}, static_cast<float>(rank + 1), tensor_options_);
+  auto dst = at::zeros({kSize}, tensor_options_);
+  cudaDeviceSynchronize();
+
+  backend.registerTensors({src, dst});
+  backend.exchangeMetadata();
+
+  // Fetch the remote tensor descriptor from the peer via the TCP store.
+  std::string src_key_prefix = "nixl_test_read_transfer_src_rank_";
+  storeTensorDescs(*communicator_, src_key_prefix + std::to_string(rank), {src});
+  auto remote_src_descs =
+      fetchTensorDescs(*communicator_, src_key_prefix + std::to_string(peer_rank));
+  communicator_->barrier();
+  communicator_->getTcpStore()->deleteKey(src_key_prefix + std::to_string(rank));
+  auto remote_src_desc = remote_src_descs[0]; // Only one remote tensor is expected
+
+  // Each rank reads from its peer. After the read, local should contain
+  // the values that the peer stored in *its* remote tensor.
+  auto handle = backend.prepareTransfer(
+      {toTensorDesc(dst)}, {remote_src_desc}, peer_rank, NixlXferOp::kRead);
+  ASSERT_TRUE(handle.isValid());
+
+  backend.postTransfer(handle);
+  backend.waitTransfer(handle);
+
+  auto local_cpu = dst.cpu();
+  float expected_val = static_cast<float>(peer_rank + 1);
+  EXPECT_TRUE(at::allclose(local_cpu, at::full({kSize}, expected_val)));
+
+  backend.deregisterTensors({dst, src});
+}
+
+TEST_F(NixlTest, WriteTransferEndToEnd) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+  if (communicator_->size() < 2) {
+    GTEST_SKIP() << "Need at least 2 devices for transfer test";
+  }
+
+  const int64_t rank = communicator_->deviceId();
+  const int64_t world_size = communicator_->size();
+  const int64_t peer_rank = (rank + 1) % world_size;
+  constexpr int64_t kSize = 512;
+
+  // Each rank writes its local src into the remote dst of its peer, ring style.
+  auto src = at::full({kSize}, static_cast<float>(rank + 1), tensor_options_);
+  auto dst = at::zeros({kSize}, tensor_options_);
+  cudaDeviceSynchronize();
+
+  backend.registerTensors({src, dst});
+  backend.exchangeMetadata();
+
+  // Fetch the remote tensor descriptor from the peer via the TCP store.
+  std::string dst_key_prefix = "nixl_test_write_transfer_dst_rank_";
+  storeTensorDescs(*communicator_, dst_key_prefix + std::to_string(rank), {dst});
+  auto remote_dst_descs =
+      fetchTensorDescs(*communicator_, dst_key_prefix + std::to_string(peer_rank));
+  communicator_->barrier();
+  communicator_->getTcpStore()->deleteKey(dst_key_prefix + std::to_string(rank));
+  auto remote_dst_desc = remote_dst_descs[0]; // Only one remote tensor is expected
+
+  // Each rank writes its local tensor into its peer's remote tensor.
+  auto handle = backend.prepareTransfer(
+      {toTensorDesc(src)}, {remote_dst_desc}, peer_rank, NixlXferOp::kWrite);
+  ASSERT_TRUE(handle.isValid());
+
+  backend.postTransfer(handle);
+  backend.waitTransfer(handle);
+
+  // After a barrier, the peer should have written into our remote tensor.
+  communicator_->barrier();
+
+  auto remote_cpu = dst.cpu();
+  int64_t writer_rank = (rank - 1 + world_size) % world_size;
+  float expected_val = static_cast<float>(writer_rank + 1);
+  EXPECT_TRUE(at::allclose(remote_cpu, at::full({kSize}, expected_val)));
+
+  backend.deregisterTensors({src, dst});
+}
+
+TEST_F(NixlTest, RegisterDeregisterRoundTrip) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  auto t1 = at::randn({256}, tensor_options_);
+  auto t2 = at::randn({128}, tensor_options_);
+
+  backend.registerTensors({t1, t2});
+  backend.deregisterTensors({t1, t2});
+
+  // Re-registering the same tensors should succeed.
+  backend.registerTensors({t1, t2});
+  backend.deregisterTensors({t1, t2});
+}
+
+} // namespace nvfuser