From cf77bdbd23a43303e0a4d3234c2d062d14a9a2aa Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Jan 2026 02:08:34 -0800 Subject: [PATCH 01/15] first working dispatch and combine primitive for k=1 --- CMakeLists.txt | 2 + csrc/dispatch.h | 2 + csrc/host_ir/evaluator.cpp | 63 +++++ csrc/host_ir/evaluator.h | 2 + csrc/multidevice/communication.cpp | 161 +++++++++++ csrc/multidevice/communication.h | 162 +++++++++++ csrc/multidevice/dispatch_combine.cpp | 267 ++++++++++++++++++ csrc/multidevice/dispatch_combine.h | 51 ++++ .../cpp/test_multidevice_dispatch_combine.cpp | 121 ++++++++ 9 files changed, 831 insertions(+) create mode 100644 csrc/multidevice/dispatch_combine.cpp create mode 100644 csrc/multidevice/dispatch_combine.h create mode 100644 tests/cpp/test_multidevice_dispatch_combine.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 13dd918282b..b325b325d9c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,6 +235,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/communication.cpp ${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp ${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp + ${NVFUSER_SRCS_DIR}/multidevice/dispatch_combine.cpp ${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp ${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp @@ -1143,6 +1144,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 3bf3b8350ff..01aa278af71 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -118,6 +118,8 @@ class Val; f(Merge); \ f(Partition); \ f(Combine); \ + f(MoEDispatch); \ + f(MoECombine); \ 
f(Swizzle); \ f(Swizzle2D); \ f(Resize); \ diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 2ceedfddc40..a847a9d5f99 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -25,6 +25,7 @@ #include "multidevice/allocation_utils.h" #include "multidevice/communication.h" #include "multidevice/cuda_p2p.h" +#include "multidevice/dispatch_combine.h" #include "multidevice/execution_utils.h" #include "multidevice/symmetric_tensor.h" #include "multidevice/utils.h" @@ -386,6 +387,68 @@ void HostIrEvaluator::handle(P2PCommunication* communication) { } } +void HostIrEvaluator::handle(MoEDispatch* dispatch) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + auto x = getKnownConcreteValue(dispatch->inX()).as(); + auto topk_idx = + getKnownConcreteValue(dispatch->inTopkIdx()).as(); + auto topk_weights = + getKnownConcreteValue(dispatch->inTopkWeights()).as(); + auto is_token_in_rank = + getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); + + auto result = dispatchWithCudaBackend( + x, + topk_idx, + topk_weights, + is_token_in_rank, + dispatch->numExperts(), + communicator_, + dispatch->backend()); + + expr_evaluator_.bind(dispatch->outX(), result.recv_x); + expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx); + expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights); + expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx); + expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank); + expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank); + expr_evaluator_.bind( + dispatch->outTokensFromRank(), result.n_tokens_from_rank); +} + +void HostIrEvaluator::handle(MoECombine* combine) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + auto x = getKnownConcreteValue(combine->inX()).as(); + auto topk_weights = + 
getKnownConcreteValue(combine->inTopkWeights()).as(); + auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as(); + auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as(); + auto n_tokens_to_rank = + getKnownConcreteValue(combine->inTokensToRank()).as(); + auto n_tokens_from_rank = + getKnownConcreteValue(combine->inTokensFromRank()).as(); + + auto result = combineWithCudaBackend( + x, + topk_weights, + src_idx, + src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + communicator_, + combine->backend()); + + expr_evaluator_.bind(combine->outX(), result.combined_x); + expr_evaluator_.bind( + combine->outTopkWeights(), result.combined_topk_weights); +} + void HostIrEvaluator::handle(Wait* wait) { Expr* expr = wait->communication(); auto* p2p_comm = dynamic_cast(expr); diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index 22833156cab..c1b0a70ef78 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -98,6 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(LaunchKernel*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; + void handle(MoEDispatch*) override; + void handle(MoECombine*) override; void handle(Wait*) override; void handle(kir::ForLoop*) override; void handle(hir::ForLoop*) override; diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 06b4ffa426c..febbd519d10 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -321,6 +321,167 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } +MoEDispatch::MoEDispatch( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_idx, + TensorView* out_topk_weights, + TensorView* out_src_idx, + TensorView* out_src_rank, + TensorView* out_n_tokens_to_rank, + TensorView* out_n_tokens_from_rank, + TensorView* in_x, + TensorView* in_topk_idx, + TensorView* 
in_topk_weights, + TensorView* in_is_token_in_rank, + int64_t num_experts, + CommunicatorBackend backend) + : Expr(passkey) { + addInput(in_x); + addInput(in_topk_idx); + addInput(in_topk_weights); + addInput(in_is_token_in_rank); + addOutput(out_x); + addOutput(out_topk_idx); + addOutput(out_topk_weights); + addOutput(out_src_idx); + addOutput(out_src_rank); + addOutput(out_n_tokens_to_rank); + addOutput(out_n_tokens_from_rank); + addDataAttribute(num_experts); + addDataAttribute(backend); + validate(); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MoEDispatch) + +std::string MoEDispatch::toInlineString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "Dispatch " << name() << " (" + << "num_experts=" << numExperts() << ", " + << "backend=" << backend() << ", " + << "in=" << inX() << ", " + << "topk_idx=" << inTopkIdx() << ", " + << "topk_weights=" << inTopkWeights() << ", " + << "is_token_in_rank=" << inIsTokenInRank() << ", " + << "out=" << outX() << ")"; + return ss.str(); +} + +std::string MoEDispatch::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + +void MoEDispatch::validate() { + NVF_CHECK(numExperts() > 0, "num_experts must be positive."); + NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); + NVF_CHECK(inTopkIdx()->isA(), "topk_idx must be a TensorView."); + NVF_CHECK( + inTopkIdx()->getDataType().has_value() && + isIntegralType(*inTopkIdx()->getDataType()), + "topk_idx must be integral."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + isFloatingPointType(*inTopkWeights()->getDataType()), + "topk_weights must be floating point."); + NVF_CHECK( + inIsTokenInRank()->getDataType() == DataType::Bool, + "is_token_in_rank must be Bool."); + NVF_CHECK( + outTopkIdx()->getDataType().has_value() && + isIntegralType(*outTopkIdx()->getDataType()), + "out_topk_idx must be integral."); + NVF_CHECK( + outTopkWeights()->getDataType().has_value() && + 
isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); + NVF_CHECK( + outSrcIdx()->getDataType().has_value() && + isIntegralType(*outSrcIdx()->getDataType()), + "out_src_idx must be integral."); + NVF_CHECK( + outSrcRank()->getDataType().has_value() && + isIntegralType(*outSrcRank()->getDataType()), + "out_src_rank must be integral."); + NVF_CHECK( + outTokensToRank()->getDataType().has_value() && + isIntegralType(*outTokensToRank()->getDataType()), + "out_n_tokens_to_rank must be integral."); + NVF_CHECK( + outTokensFromRank()->getDataType().has_value() && + isIntegralType(*outTokensFromRank()->getDataType()), + "out_n_tokens_from_rank must be integral."); +} + +MoECombine::MoECombine( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_weights, + TensorView* in_x, + TensorView* in_topk_weights, + TensorView* in_src_idx, + TensorView* in_src_rank, + TensorView* in_n_tokens_to_rank, + TensorView* in_n_tokens_from_rank, + CommunicatorBackend backend) + : Expr(passkey) { + addInput(in_x); + addInput(in_topk_weights); + addInput(in_src_idx); + addInput(in_src_rank); + addInput(in_n_tokens_to_rank); + addInput(in_n_tokens_from_rank); + addOutput(out_x); + addOutput(out_topk_weights); + addDataAttribute(backend); + validate(); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MoECombine) + +std::string MoECombine::toInlineString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "Combine " << name() << " (" + << "backend=" << backend() << ", " + << "in=" << inX() << ", " + << "src_idx=" << inSrcIdx() << ", " + << "src_rank=" << inSrcRank() << ", " + << "out=" << outX() << ")"; + return ss.str(); +} + +std::string MoECombine::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + +void MoECombine::validate() { + NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + 
isFloatingPointType(*inTopkWeights()->getDataType()), + "in_topk_weights must be floating point."); + NVF_CHECK( + inSrcIdx()->getDataType().has_value() && + isIntegralType(*inSrcIdx()->getDataType()), + "in_src_idx must be integral."); + NVF_CHECK( + inSrcRank()->getDataType().has_value() && + isIntegralType(*inSrcRank()->getDataType()), + "in_src_rank must be integral."); + NVF_CHECK( + inTokensToRank()->getDataType().has_value() && + isIntegralType(*inTokensToRank()->getDataType()), + "in_n_tokens_to_rank must be integral."); + NVF_CHECK( + inTokensFromRank()->getDataType().has_value() && + isIntegralType(*inTokensFromRank()->getDataType()), + "in_n_tokens_from_rank must be integral."); + NVF_CHECK( + outTopkWeights()->getDataType().has_value() && + isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); +} + namespace { c10::intrusive_ptr postBroadcast( Communication* communication, diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 1a7f1a1cc4c..9c880110b5e 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -174,6 +174,168 @@ class P2PCommunication : public Expr { } }; +// Dispatch represents intra-node MoE token dispatch. It shuffles tokens from +// the local rank to destination ranks based on `is_token_in_rank`. 
+class MoEDispatch : public Expr { + public: + using Expr::Expr; + + MoEDispatch( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_idx, + TensorView* out_topk_weights, + TensorView* out_src_idx, + TensorView* out_src_rank, + TensorView* out_n_tokens_to_rank, + TensorView* out_n_tokens_from_rank, + TensorView* in_x, + TensorView* in_topk_idx, + TensorView* in_topk_weights, + TensorView* in_is_token_in_rank, + int64_t num_experts, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + MoEDispatch(const MoEDispatch& other) = delete; + MoEDispatch& operator=(const MoEDispatch& other) = delete; + MoEDispatch(MoEDispatch&& other) = delete; + MoEDispatch& operator=(MoEDispatch&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "MoEDispatch"; + } + + TensorView* outX() const { + return output(0)->as(); + } + + TensorView* outTopkIdx() const { + return output(1)->as(); + } + + TensorView* outTopkWeights() const { + return output(2)->as(); + } + + TensorView* outSrcIdx() const { + return output(3)->as(); + } + + TensorView* outSrcRank() const { + return output(4)->as(); + } + + TensorView* outTokensToRank() const { + return output(5)->as(); + } + + TensorView* outTokensFromRank() const { + return output(6)->as(); + } + + TensorView* inX() const { + return input(0)->as(); + } + + TensorView* inTopkIdx() const { + return input(1)->as(); + } + + TensorView* inTopkWeights() const { + return input(2)->as(); + } + + TensorView* inIsTokenInRank() const { + return input(3)->as(); + } + + int64_t numExperts() const { + return attribute(0); + } + + CommunicatorBackend backend() const { + return attribute(1); + } + + private: + void validate(); +}; + +// Combine represents intra-node MoE token combine. 
It shuffles tokens back to +// their source ranks using `src_rank` and `src_idx`. +class MoECombine : public Expr { + public: + using Expr::Expr; + + MoECombine( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_weights, + TensorView* in_x, + TensorView* in_topk_weights, + TensorView* in_src_idx, + TensorView* in_src_rank, + TensorView* in_n_tokens_to_rank, + TensorView* in_n_tokens_from_rank, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + MoECombine(const MoECombine& other) = delete; + MoECombine& operator=(const MoECombine& other) = delete; + MoECombine(MoECombine&& other) = delete; + MoECombine& operator=(MoECombine&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "MoECombine"; + } + + TensorView* outX() const { + return output(0)->as(); + } + + TensorView* outTopkWeights() const { + return output(1)->as(); + } + + TensorView* inX() const { + return input(0)->as(); + } + + TensorView* inTopkWeights() const { + return input(1)->as(); + } + + TensorView* inSrcIdx() const { + return input(2)->as(); + } + + TensorView* inSrcRank() const { + return input(3)->as(); + } + + TensorView* inTokensToRank() const { + return input(4)->as(); + } + + TensorView* inTokensFromRank() const { + return input(5)->as(); + } + + CommunicatorBackend backend() const { + return attribute(0); + } + + private: + void validate(); +}; + // The method "post" triggers the execution of the communication. This call is // non-blocking. The communication can be posted multiple times. 
// It is assumed that the current device_index (given by diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp new file mode 100644 index 00000000000..7ac888c539a --- /dev/null +++ b/csrc/multidevice/dispatch_combine.cpp @@ -0,0 +1,267 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include "multidevice/dispatch_combine.h" + +#include +#include + +#include + +#include "multidevice/communicator.h" +#include "utils.h" + +namespace nvfuser { +namespace { + +CommunicatorBackend getBackendForDispatch(CommunicatorBackend backend) { + if (backend == CommunicatorBackend::kCuda) { + return CommunicatorBackend::kNccl; + } + return backend; +} + +std::vector toSplitSizes(const at::Tensor& sizes_tensor) { + auto cpu_sizes = sizes_tensor.to(at::kCPU); + auto* ptr = cpu_sizes.data_ptr(); + return std::vector(ptr, ptr + cpu_sizes.numel()); +} + +int64_t sumSplitSizes(const std::vector& splits) { + int64_t total = 0; + for (auto value : splits) { + total += value; + } + return total; +} + +at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { + if (topk.numel() == num_tokens) { + return topk.reshape({num_tokens}); + } + if (topk.dim() == 2 && topk.size(0) == num_tokens && + topk.size(1) == 1) { + return topk.reshape({num_tokens}); + } + NVF_CHECK( + false, + "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], got: ", + topk.sizes()); +} + +void ensureTopk1Assignment(const at::Tensor& is_token_in_rank) { + auto token_counts = is_token_in_rank.to(at::kLong).sum(1); + auto min_val = token_counts.min().item(); + auto max_val = token_counts.max().item(); + NVF_CHECK( + min_val == 1 && max_val == 1, + "Only topk=1 is supported. 
Each token must be assigned to exactly one rank."); +} + +} // namespace + +DispatchResult dispatchWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_idx, + const at::Tensor& topk_weights, + const at::Tensor& is_token_in_rank, + int64_t num_experts, + Communicator* communicator, + CommunicatorBackend backend) { + NVF_CHECK(communicator != nullptr, "Dispatch requires a valid communicator."); + NVF_CHECK(x.is_cuda(), "Dispatch input x must be on CUDA."); + NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA."); + NVF_CHECK( + is_token_in_rank.is_cuda(), + "Dispatch is_token_in_rank must be on CUDA."); + NVF_CHECK( + is_token_in_rank.dim() == 2, + "is_token_in_rank must be 2D [tokens, ranks], got: ", + is_token_in_rank.sizes()); + NVF_CHECK( + x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); + + const int64_t num_tokens = x.size(0); + const int64_t hidden = x.size(1); + const int64_t world_size = communicator->size(); + const int64_t my_rank = communicator->deviceId(); + NVF_CHECK( + is_token_in_rank.size(1) == world_size, + "is_token_in_rank second dim must match world size."); + NVF_CHECK(num_experts % world_size == 0, "num_experts must be divisible."); + + c10::cuda::CUDAGuard device_guard(x.device()); + ensureTopk1Assignment(is_token_in_rank); + + auto topk_idx_flat = flattenTopk(topk_idx, num_tokens); + auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); + + auto rank_for_token = + is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + auto sorted = rank_for_token.sort(); + auto sorted_indices = std::get<1>(sorted); + + auto send_x = x.index_select(0, sorted_indices); + auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + auto send_src_idx = sorted_indices.to(at::kLong); + auto send_src_rank = at::full( + {num_tokens}, + my_rank, + 
at::TensorOptions().dtype(at::kLong).device(x.device())); + send_src_rank = send_src_rank.index_select(0, sorted_indices); + + auto rank_for_token_cpu = rank_for_token.to(at::kCPU); + auto n_tokens_to_rank_cpu = + at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); + auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); + auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + + CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + communicator->isBackendAvailable(actual_backend), + "Backend not available for dispatch: ", + actual_backend); + auto* pg = communicator->getWorld(actual_backend); + NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + + std::vector one_split(world_size, 1); + if (auto work = pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)) { + work->wait(); + } + + auto input_splits = toSplitSizes(n_tokens_to_rank); + auto output_splits = toSplitSizes(n_tokens_from_rank); + auto total_recv = sumSplitSizes(output_splits); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); + auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); + auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); + + if (auto work = + pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)) { + work->wait(); + } 
+ + const int64_t experts_per_rank = num_experts / world_size; + auto local_expert = recv_topk_idx - my_rank * experts_per_rank; + auto expert_sorted = local_expert.sort(); + auto expert_order = std::get<1>(expert_sorted); + recv_x = recv_x.index_select(0, expert_order); + recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); + recv_src_idx = recv_src_idx.index_select(0, expert_order); + recv_src_rank = recv_src_rank.index_select(0, expert_order); + + return DispatchResult{ + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank}; +} + +CombineResult combineWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_weights, + const at::Tensor& src_idx, + const at::Tensor& src_rank, + const at::Tensor& n_tokens_to_rank, + const at::Tensor& n_tokens_from_rank, + Communicator* communicator, + CommunicatorBackend backend) { + NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator."); + NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); + NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); + NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); + NVF_CHECK(n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); + NVF_CHECK( + n_tokens_from_rank.is_cuda(), + "Combine n_tokens_from_rank must be CUDA."); + NVF_CHECK(x.dim() == 2, "Combine expects x to be 2D [tokens, hidden]."); + NVF_CHECK( + src_idx.dim() == 1 && src_rank.dim() == 1, + "src_idx and src_rank must be 1D."); + NVF_CHECK( + n_tokens_to_rank.numel() == communicator->size(), + "n_tokens_to_rank must match world size."); + NVF_CHECK( + n_tokens_from_rank.numel() == communicator->size(), + "n_tokens_from_rank must match world size."); + + c10::cuda::CUDAGuard device_guard(x.device()); + + auto sorted = src_rank.sort(); 
+ auto sorted_indices = std::get<1>(sorted); + auto send_x = x.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights.index_select(0, sorted_indices); + auto send_src_idx = src_idx.index_select(0, sorted_indices); + + auto input_splits = toSplitSizes(n_tokens_from_rank); + auto output_splits = toSplitSizes(n_tokens_to_rank); + auto total_recv = sumSplitSizes(output_splits); + auto hidden = x.size(1); + + CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + communicator->isBackendAvailable(actual_backend), + "Backend not available for combine: ", + actual_backend); + auto* pg = communicator->getWorld(actual_backend); + NVF_CHECK(pg != nullptr, "Combine backend is null."); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); + auto recv_src_idx = at::empty({total_recv}, src_idx.options()); + + if (auto work = + pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)) { + work->wait(); + } + + auto combined_x = at::empty({total_recv, hidden}, x.options()); + combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); + + return CombineResult{combined_x, combined_topk_weights}; +} + +} // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h new file mode 100644 index 00000000000..0d8f75c9f6d --- /dev/null +++ b/csrc/multidevice/dispatch_combine.h @@ -0,0 +1,51 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. 
+ * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include "multidevice/communicator.h" +#include "visibility.h" + +namespace nvfuser { + +struct DispatchResult { + at::Tensor recv_x; + at::Tensor recv_topk_idx; + at::Tensor recv_topk_weights; + at::Tensor recv_src_idx; + at::Tensor recv_src_rank; + at::Tensor n_tokens_to_rank; + at::Tensor n_tokens_from_rank; +}; + +struct CombineResult { + at::Tensor combined_x; + at::Tensor combined_topk_weights; +}; + +NVF_API DispatchResult dispatchWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_idx, + const at::Tensor& topk_weights, + const at::Tensor& is_token_in_rank, + int64_t num_experts, + Communicator* communicator, + CommunicatorBackend backend); + +NVF_API CombineResult combineWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_weights, + const at::Tensor& src_idx, + const at::Tensor& src_rank, + const at::Tensor& n_tokens_to_rank, + const at::Tensor& n_tokens_from_rank, + Communicator* communicator, + CommunicatorBackend backend); + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp new file mode 100644 index 00000000000..be13743c8b8 --- /dev/null +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include + +#include "fusion.h" +#include "host_ir/container.h" +#include "host_ir/evaluator.h" +#include "ir/all_nodes.h" +#include "multidevice/communication.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { +namespace hir { + +class DispatchCombineTest : public MultiDeviceTest {}; + +TEST_F(DispatchCombineTest, DispatchCombineTop1) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + constexpr int64_t kNumExpertsPerRank = 2; + const int64_t num_experts = world_size * kNumExpertsPerRank; + constexpr int64_t kNumTokens = 8; + constexpr int64_t kHidden = 4; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in_x = makeSymbolicTensor(2); + auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); + + auto* recv_x = makeSymbolicTensor(2); + auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_topk_weights = makeSymbolicTensor(1); + auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); + + auto* dispatch = IrBuilder::create( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + in_x, + in_topk_idx, + in_topk_weights, + in_is_token_in_rank, + num_experts, + CommunicatorBackend::kCuda); + + auto* combined_x = makeSymbolicTensor(2); + auto* combined_topk_weights = makeSymbolicTensor(1); + auto* combine = IrBuilder::create( + combined_x, + 
combined_topk_weights, + recv_x, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + CommunicatorBackend::kCuda); + + hic->pushBackTopLevelExprs(dispatch); + hic->pushBackTopLevelExprs(combine); + + hic->addInput(in_x); + hic->addInput(in_topk_idx); + hic->addInput(in_topk_weights); + hic->addInput(in_is_token_in_rank); + hic->addOutput(combined_x); + + HostIrEvaluator hie(std::move(hic), communicator_); + + auto float_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kFloat); + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto x = at::arange(kNumTokens * kHidden, float_options) + .reshape({kNumTokens, kHidden}) + + static_cast(my_rank) * 1000.0; + auto topk_idx = + (at::arange(kNumTokens, int_options) + my_rank) % num_experts; + auto topk_weights = at::ones({kNumTokens}, float_options); + + auto token_rank = topk_idx.div(kNumExpertsPerRank, "trunc"); + auto rank_ids = at::arange(world_size, int_options); + auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + + auto outputs = hie.runWithInput( + {{in_x, x}, + {in_topk_idx, topk_idx}, + {in_topk_weights, topk_weights}, + {in_is_token_in_rank, is_token_in_rank}}); + auto combined = outputs.back().as(); + + EXPECT_TRUE(at::allclose(combined, x)) + << "Dispatch/Combine mismatch on rank " << my_rank; +} + +} // namespace hir +} // namespace nvfuser From 66e7811afa48f0ce819a66fd3191a699842d4254 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Jan 2026 05:27:05 -0800 Subject: [PATCH 02/15] add comments and cleanup --- csrc/host_ir/evaluator.cpp | 10 +- csrc/multidevice/communication.h | 16 +- csrc/multidevice/dispatch_combine.cpp | 152 +++++++++--------- csrc/multidevice/dispatch_combine.h | 97 +++++++++-- .../cpp/test_multidevice_dispatch_combine.cpp | 21 ++- 5 files changed, 186 insertions(+), 110 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 
a847a9d5f99..5f6bb83227d 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -393,14 +393,13 @@ void HostIrEvaluator::handle(MoEDispatch* dispatch) { "A valid communicator must be provided"); auto x = getKnownConcreteValue(dispatch->inX()).as(); - auto topk_idx = - getKnownConcreteValue(dispatch->inTopkIdx()).as(); + auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as(); auto topk_weights = getKnownConcreteValue(dispatch->inTopkWeights()).as(); auto is_token_in_rank = getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); - auto result = dispatchWithCudaBackend( + auto result = doMoEDispatch( x, topk_idx, topk_weights, @@ -434,7 +433,7 @@ void HostIrEvaluator::handle(MoECombine* combine) { auto n_tokens_from_rank = getKnownConcreteValue(combine->inTokensFromRank()).as(); - auto result = combineWithCudaBackend( + auto result = doMoECombine( x, topk_weights, src_idx, @@ -445,8 +444,7 @@ void HostIrEvaluator::handle(MoECombine* combine) { combine->backend()); expr_evaluator_.bind(combine->outX(), result.combined_x); - expr_evaluator_.bind( - combine->outTopkWeights(), result.combined_topk_weights); + expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights); } void HostIrEvaluator::handle(Wait* wait) { diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 9c880110b5e..a3f806b6c64 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -175,7 +175,13 @@ class P2PCommunication : public Expr { }; // Dispatch represents intra-node MoE token dispatch. It shuffles tokens from -// the local rank to destination ranks based on `is_token_in_rank`. +// the local rank to destination ranks based on `in_is_token_in_rank`. +// +// Example shapes (topk=1): +// in_x: [T, H], in_topk_idx: [T] or [T, 1], in_topk_weights: [T] or [T, 1], +// in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank. 
+// Outputs are recv-aligned tensors: out_x/out_topk_*/out_src_* with [T_recv, +// ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape [R]. class MoEDispatch : public Expr { public: using Expr::Expr; @@ -266,7 +272,13 @@ class MoEDispatch : public Expr { }; // Combine represents intra-node MoE token combine. It shuffles tokens back to -// their source ranks using `src_rank` and `src_idx`. +// their source ranks using `in_src_rank` and `in_src_idx`. +// +// Example shapes (topk=1): +// in_x: [T_recv, H], in_topk_weights: [T_recv], in_src_idx: [T_recv], +// in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank: +// [R]. Outputs are source-aligned: out_x/out_topk_weights with shape [T_src, +// ...]. class MoECombine : public Expr { public: using Expr::Expr; diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 7ac888c539a..738e27765d9 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ @@ -19,13 +19,6 @@ namespace nvfuser { namespace { -CommunicatorBackend getBackendForDispatch(CommunicatorBackend backend) { - if (backend == CommunicatorBackend::kCuda) { - return CommunicatorBackend::kNccl; - } - return backend; -} - std::vector toSplitSizes(const at::Tensor& sizes_tensor) { auto cpu_sizes = sizes_tensor.to(at::kCPU); auto* ptr = cpu_sizes.data_ptr(); @@ -40,32 +33,27 @@ int64_t sumSplitSizes(const std::vector& splits) { return total; } -at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { - if (topk.numel() == num_tokens) { - return topk.reshape({num_tokens}); - } - if (topk.dim() == 2 && topk.size(0) == num_tokens && - topk.size(1) == 1) { - return topk.reshape({num_tokens}); +void waitWork(const c10::intrusive_ptr& work) { + if (work) { + work->wait(); } - NVF_CHECK( - false, - "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], got: ", - topk.sizes()); } -void ensureTopk1Assignment(const at::Tensor& is_token_in_rank) { - auto token_counts = is_token_in_rank.to(at::kLong).sum(1); - auto min_val = token_counts.min().item(); - auto max_val = token_counts.max().item(); +at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { + const bool is_1d = topk.dim() == 1 && topk.size(0) == num_tokens; + const bool is_2d = + topk.dim() == 2 && topk.size(0) == num_tokens && topk.size(1) == 1; NVF_CHECK( - min_val == 1 && max_val == 1, - "Only topk=1 is supported. Each token must be assigned to exactly one rank."); + is_1d || is_2d, + "Only topk=1 supported. 
topk_idx/weights must be shape [T] or [T, 1], " + "got: ", + topk.sizes()); + return topk.reshape({num_tokens}); } } // namespace -DispatchResult dispatchWithCudaBackend( +DispatchResult doMoEDispatch( const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights, @@ -78,14 +66,12 @@ DispatchResult dispatchWithCudaBackend( NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA."); NVF_CHECK( - is_token_in_rank.is_cuda(), - "Dispatch is_token_in_rank must be on CUDA."); + is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA."); NVF_CHECK( is_token_in_rank.dim() == 2, "is_token_in_rank must be 2D [tokens, ranks], got: ", is_token_in_rank.sizes()); - NVF_CHECK( - x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); + NVF_CHECK(x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); const int64_t num_tokens = x.size(0); const int64_t hidden = x.size(1); @@ -97,33 +83,49 @@ DispatchResult dispatchWithCudaBackend( NVF_CHECK(num_experts % world_size == 0, "num_experts must be divisible."); c10::cuda::CUDAGuard device_guard(x.device()); - ensureTopk1Assignment(is_token_in_rank); + NVF_CHECK( + [&]() { + auto token_counts = is_token_in_rank.to(at::kLong).sum(1); + auto min_val = token_counts.min().item(); + auto max_val = token_counts.max().item(); + return min_val == 1 && max_val == 1; + }(), + "Only topk=1 is supported. Each token must be assigned to exactly one " + "rank."); auto topk_idx_flat = flattenTopk(topk_idx, num_tokens); auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); - auto rank_for_token = - is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + // Determine destination rank per token (topk=1). + auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + // Sort tokens by destination rank for contiguous alltoall slices. 
auto sorted = rank_for_token.sort(); auto sorted_indices = std::get<1>(sorted); + // Reorder payloads so alltoall can send contiguous chunks per rank. auto send_x = x.index_select(0, sorted_indices); auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + // Track original token indices and source rank for the combine step. auto send_src_idx = sorted_indices.to(at::kLong); + // All entries are identical, so no relayout is needed. auto send_src_rank = at::full( {num_tokens}, my_rank, at::TensorOptions().dtype(at::kLong).device(x.device())); - send_src_rank = send_src_rank.index_select(0, sorted_indices); + // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we + // sync/copy here. GPU-initiated comms can avoid this extra sync. auto rank_for_token_cpu = rank_for_token.to(at::kCPU); auto n_tokens_to_rank_cpu = at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); - CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + backend == CommunicatorBackend::kNccl, + "Only NCCL backend is supported for MoEDispatch."); + CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), "Backend not available for dispatch: ", @@ -131,43 +133,36 @@ DispatchResult dispatchWithCudaBackend( auto* pg = communicator->getWorld(actual_backend); NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + // Exchange per-rank token counts to build split sizes for alltoall. 
std::vector one_split(world_size, 1); - if (auto work = pg->alltoall_base( - n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)) { - work->wait(); - } + waitWork(pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); + // Convert count tensors to CPU split vectors and size the receive buffers. auto input_splits = toSplitSizes(n_tokens_to_rank); auto output_splits = toSplitSizes(n_tokens_from_rank); auto total_recv = sumSplitSizes(output_splits); + // Allocate receive buffers for payloads and metadata. + // TODO: support preallocated buffers. auto recv_x = at::empty({total_recv, hidden}, x.options()); auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); - if (auto work = - pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_idx, send_topk_idx, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)) { - work->wait(); - } - + // Alltoall exchange payloads with per-rank splits. 
+ waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)); + + // Locally reorder by expert id so each rank processes contiguous experts. const int64_t experts_per_rank = num_experts / world_size; auto local_expert = recv_topk_idx - my_rank * experts_per_rank; auto expert_sorted = local_expert.sort(); @@ -188,7 +183,7 @@ DispatchResult dispatchWithCudaBackend( n_tokens_from_rank}; } -CombineResult combineWithCudaBackend( +CombineResult doMoECombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, @@ -202,10 +197,10 @@ CombineResult combineWithCudaBackend( NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); - NVF_CHECK(n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); NVF_CHECK( - n_tokens_from_rank.is_cuda(), - "Combine n_tokens_from_rank must be CUDA."); + n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); + NVF_CHECK( + n_tokens_from_rank.is_cuda(), "Combine n_tokens_from_rank must be CUDA."); NVF_CHECK(x.dim() == 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK( src_idx.dim() == 1 && src_rank.dim() == 1, @@ -219,18 +214,23 @@ CombineResult combineWithCudaBackend( c10::cuda::CUDAGuard device_guard(x.device()); + // Sort by source rank so alltoall can send contiguous chunks per rank. 
auto sorted = src_rank.sort(); auto sorted_indices = std::get<1>(sorted); auto send_x = x.index_select(0, sorted_indices); auto send_topk_weights = topk_weights.index_select(0, sorted_indices); auto send_src_idx = src_idx.index_select(0, sorted_indices); + // Split sizes come from dispatch counts. auto input_splits = toSplitSizes(n_tokens_from_rank); auto output_splits = toSplitSizes(n_tokens_to_rank); auto total_recv = sumSplitSizes(output_splits); auto hidden = x.size(1); - CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + backend == CommunicatorBackend::kNccl, + "Only NCCL backend is supported for MoECombine."); + CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), "Backend not available for combine: ", @@ -238,27 +238,21 @@ CombineResult combineWithCudaBackend( auto* pg = communicator->getWorld(actual_backend); NVF_CHECK(pg != nullptr, "Combine backend is null."); + // Allocate receive buffers and exchange payloads back to source ranks. auto recv_x = at::empty({total_recv, hidden}, x.options()); auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); - if (auto work = - pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)) { - work->wait(); - } + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + // Scatter by original token index to restore local order. 
auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); - auto combined_topk_weights = - at::empty({total_recv}, topk_weights.options()); + auto combined_topk_weights = at::empty({total_recv}, topk_weights.options()); combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); return CombineResult{combined_x, combined_topk_weights}; diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 0d8f75c9f6d..5714a45a818 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ @@ -15,30 +15,95 @@ namespace nvfuser { struct DispatchResult { - at::Tensor recv_x; - at::Tensor recv_topk_idx; - at::Tensor recv_topk_weights; - at::Tensor recv_src_idx; - at::Tensor recv_src_rank; - at::Tensor n_tokens_to_rank; - at::Tensor n_tokens_from_rank; + at::Tensor recv_x; // Dispatched tokens received on this rank. + at::Tensor recv_topk_idx; // Expert ids aligned with recv_x. + at::Tensor recv_topk_weights; // Gating weights aligned with recv_x. + at::Tensor recv_src_idx; // Source token indices for combine. + at::Tensor recv_src_rank; // Source ranks for combine. + at::Tensor n_tokens_to_rank; // Tokens sent to each rank (this rank's view). + at::Tensor n_tokens_from_rank; // Tokens received from each rank. }; struct CombineResult { - at::Tensor combined_x; - at::Tensor combined_topk_weights; + at::Tensor combined_x; // Combined tokens back in original order. + at::Tensor combined_topk_weights; // Combined gating weights per token. 
}; -NVF_API DispatchResult dispatchWithCudaBackend( - const at::Tensor& x, - const at::Tensor& topk_idx, - const at::Tensor& topk_weights, - const at::Tensor& is_token_in_rank, +// Dispatch MoE tokens to the owning ranks. Only k=1 is supported for now. +// +// Args: +// x: Token embeddings on this rank, shape [T, H]. +// topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. +// topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1]. +// is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. +// num_experts: Total experts across all ranks (must be divisible by R). +// communicator: Communicator for alltoall exchange. +// backend: Communication backend (only NCCL is supported for now). +// +// Returns: +// DispatchResult with recv_* tensors on this rank. +// +// Example: +// // world_size=2, num_experts=4, T=4, H=2, topk=1 +// // Experts are partitioned by rank: +// // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} +// // Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x: +// // rank0 x = [x0, x1], rank1 x = [x2, x3] +// // token->rank: [0, 1, 1, 1] (rank0 keeps x0, sends x1; rank1 keeps x2,x3) +// // is_token_in_rank = +// // [[1, 0], +// // [0, 1], +// // [0, 1], +// // [0, 1]] +// // topk_idx = [0, 2, 3, 2] (global expert ids) +// // After dispatch on rank0: +// // recv_x has token {0} +// // recv_topk_idx aligned with recv_x (e.g., [0]) +// // recv_src_idx tells original token positions (e.g., [0]) +// // After dispatch on rank1: +// // recv_x has tokens {1, 2, 3} +// // recv_topk_idx aligned with recv_x (e.g., [2, 3, 2]) +// // recv_src_idx tells original token positions (e.g., [1, 2, 3]) +// auto out = doMoEDispatch( +// x, topk_idx, topk_weights, is_token_in_rank, 4, comm, +// CommunicatorBackend::kNccl); +NVF_API DispatchResult doMoEDispatch( + const at::Tensor& x, // [T, H] + const at::Tensor& topk_idx, // [T] or [T, 1] + const at::Tensor& topk_weights, // [T] or [T, 1] + const at::Tensor& is_token_in_rank, 
// [T, R] int64_t num_experts, Communicator* communicator, CommunicatorBackend backend); -NVF_API CombineResult combineWithCudaBackend( +// Combine dispatched MoE results back to original token order. +// +// Args: +// x: Token embeddings after expert compute, shape [T_recv, H]. +// topk_weights: Gating weights aligned with x, shape [T_recv]. +// src_idx: Original token indices for each row of x, shape [T_recv]. +// src_rank: Original source rank per token, shape [T_recv]. +// n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. +// n_tokens_from_rank: Tokens received from each rank (from dispatch), shape +// [R]. communicator: Communicator for alltoall exchange. backend: +// Communication backend (only NCCL is supported for now). +// +// Returns: +// CombineResult with tokens restored to original order on this rank. +// +// Example: +// // Continuing the dispatch example (experts partitioned by rank): +// // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} +// // After expert compute: +// // rank0 recv_x has token {0} with src_idx = [0], src_rank = [0] +// // rank1 recv_x has tokens {1, 2, 3} with src_idx = [1, 2, 3], +// // src_rank = [0, 1, 1] +// // n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank. +// // Combine scatters results back to original token order per rank. +// auto combined = doMoECombine( +// x, topk_weights, src_idx, src_rank, n_tokens_to_rank, +// n_tokens_from_rank, comm, CommunicatorBackend::kNccl); +NVF_API CombineResult doMoECombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index be13743c8b8..0d84dbc03e0 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ @@ -32,7 +32,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { const int64_t my_rank = communicator_->deviceId(); constexpr int64_t kNumExpertsPerRank = 2; const int64_t num_experts = world_size * kNumExpertsPerRank; - constexpr int64_t kNumTokens = 8; + constexpr int64_t kNumTokens = 4; constexpr int64_t kHidden = 4; auto hic = std::make_unique(); @@ -64,7 +64,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { in_topk_weights, in_is_token_in_rank, num_experts, - CommunicatorBackend::kCuda); + CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); auto* combined_topk_weights = makeSymbolicTensor(1); @@ -77,7 +77,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, - CommunicatorBackend::kCuda); + CommunicatorBackend::kNccl); hic->pushBackTopLevelExprs(dispatch); hic->pushBackTopLevelExprs(combine); @@ -98,14 +98,21 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto x = at::arange(kNumTokens * kHidden, float_options) .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; - auto topk_idx = - (at::arange(kNumTokens, int_options) + my_rank) % num_experts; + auto topk_idx = at::zeros({kNumTokens}, int_options); auto topk_weights = at::ones({kNumTokens}, float_options); - auto token_rank = topk_idx.div(kNumExpertsPerRank, "trunc"); + // Asymmetric example: + // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. auto rank_ids = at::arange(world_size, int_options); + auto token_rank = at::tensor({0, 1, 1, 1}, int_options); auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. 
+ topk_idx.index_put_({0}, 0); + topk_idx.index_put_({1}, kNumExpertsPerRank); + topk_idx.index_put_({2}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3}, kNumExpertsPerRank); + auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, From dda9aa7c2be35ef1e604fb12b63d8a5278834657 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 22 Jan 2026 09:33:18 -0800 Subject: [PATCH 03/15] add kernel based a2av and cuda backend for d/c --- CMakeLists.txt | 2 + csrc/multidevice/alltoallv.cu | 37 ++ csrc/multidevice/cuda_p2p.cpp | 315 ++++++++++++++++++ csrc/multidevice/cuda_p2p.h | 29 ++ csrc/multidevice/dispatch_combine.cpp | 309 +++++++++++++---- csrc/multidevice/dispatch_combine.h | 4 +- tests/cpp/test_multidevice_alltoallv.cpp | 82 +++++ .../cpp/test_multidevice_dispatch_combine.cpp | 20 +- 8 files changed, 726 insertions(+), 72 deletions(-) create mode 100644 csrc/multidevice/alltoallv.cu create mode 100644 tests/cpp/test_multidevice_alltoallv.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b325b325d9c..ff76e741b4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1144,6 +1144,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_alltoallv.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp @@ -1393,6 +1394,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/mbarrier.cu ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/multicast.cu + ${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor_memory.cu ${NVFUSER_ROOT}/runtime/tensor.cu diff --git a/csrc/multidevice/alltoallv.cu b/csrc/multidevice/alltoallv.cu new file mode 100644 index 00000000000..9725794f838 --- /dev/null 
+++ b/csrc/multidevice/alltoallv.cu @@ -0,0 +1,37 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +extern "C" __global__ void alltoallv_kernel( + const unsigned char* send, + const unsigned long long* recv_ptrs, + const long long* send_offsets, + const long long* send_sizes, + const long long* recv_offsets, + long long world_size, + long long elem_size, + long long max_send_bytes) { + const long long peer = static_cast(blockIdx.y); + if (peer >= world_size) { + return; + } + const long long bytes = send_sizes[peer] * elem_size; + if (bytes == 0) { + return; + } + const long long idx = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= bytes) { + return; + } + const long long send_byte_offset = send_offsets[peer] * elem_size + idx; + const long long recv_byte_offset = recv_offsets[peer] * elem_size + idx; + auto* dst = reinterpret_cast( + static_cast(recv_ptrs[peer])); + dst[recv_byte_offset] = send[send_byte_offset]; +} + diff --git a/csrc/multidevice/cuda_p2p.cpp b/csrc/multidevice/cuda_p2p.cpp index 6ad709fa062..8804c1a7a79 100644 --- a/csrc/multidevice/cuda_p2p.cpp +++ b/csrc/multidevice/cuda_p2p.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include "multidevice/cuda_p2p.h" +#include "nvfuser_resources/alltoallv.h" #include "nvfuser_resources/multicast.h" #include "cuda_utils.h" @@ -34,6 +35,143 @@ P2pProtocol getP2pProtocol() { } namespace { +void launchAlltoallvKernel( + const void* send, + const uint64_t* recv_ptrs, + const int64_t* send_offsets, + const int64_t* send_sizes, + const int64_t* recv_offsets, + int64_t world_size, + int64_t elem_size, + int64_t max_send_bytes, + CUstream stream) { + static CUmodule module = nullptr; + static CUfunction kernel = nullptr; + + if (module == nullptr) { + nvrtcProgram prog; + NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( + &prog, + 
nvfuser_resources::alltoallv_cu, + "alltoallv.cu", + 0, + nullptr, + nullptr)); + + int major = 0; + int minor = 0; + int device = 0; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); + cudaDeviceProp prop; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); + major = prop.major; + minor = prop.minor; + + std::string arch_arg = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + std::vector opts = {arch_arg.c_str(), "--std=c++17"}; + // NVRTC needs CUDA headers to compile alltoallv.cu. + opts.push_back("-I/usr/local/cuda/include"); + opts.push_back("-I/usr/local/cuda/include/cccl"); + + nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); + if (res != NVRTC_SUCCESS) { + size_t logSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); + std::vector log(logSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); + NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data()); + } + + size_t ptxSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); + std::vector ptx(ptxSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); + NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); + + CUresult load_result = cuModuleLoadData(&module, ptx.data()); + if (load_result != CUDA_SUCCESS) { + constexpr size_t kLogSize = 8192; + char error_log[kLogSize]; + char info_log[kLogSize]; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_INFO_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_LOG_VERBOSE}; + void* option_values[] = { + (void*)error_log, + (void*)kLogSize, + (void*)info_log, + (void*)kLogSize, + (void*)1}; + cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values); + NVF_ERROR( + false, + "Alltoallv kernel module load failed with error: ", + load_result, + "\nInfo Log:\n", + info_log, + "\nError Log:\n", + error_log); + } + + NVFUSER_CUDA_SAFE_CALL( + 
cuModuleGetFunction(&kernel, module, "alltoallv_kernel")); + } + + if (max_send_bytes == 0) { + return; + } + + constexpr int kThreads = 256; + const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads; + void* args_kernel[] = { + const_cast(static_cast(&send)), + const_cast(static_cast(&recv_ptrs)), + const_cast(static_cast(&send_offsets)), + const_cast(static_cast(&send_sizes)), + const_cast(static_cast(&recv_offsets)), + &world_size, + &elem_size, + &max_send_bytes}; + NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( + kernel, + blocks_x, + static_cast(world_size), + 1, + kThreads, + 1, + 1, + 0, + stream, + args_kernel, + nullptr)); +} + +std::vector serializeInt64Vector(const std::vector& values) { + std::vector bytes(values.size() * sizeof(int64_t)); + std::memcpy(bytes.data(), values.data(), bytes.size()); + return bytes; +} + +std::vector deserializeInt64Vector(const std::vector& bytes) { + NVF_CHECK( + bytes.size() % sizeof(int64_t) == 0, "Invalid int64 byte buffer size."); + const size_t count = bytes.size() / sizeof(int64_t); + std::vector values(count); + std::memcpy(values.data(), bytes.data(), bytes.size()); + return values; +} + +std::string alltoallvCountsKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank); +} + +std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank); +} void launchMulticastKernel( void* dst, @@ -710,4 +848,181 @@ void waitWithCudaBackend( } } +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& send_counts, + const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + const int64_t world_size = comm.size(); + const int64_t my_rank = comm.deviceId(); + NVF_CHECK( + send_counts.is_cuda(), "alltoallv send_counts must be CUDA tensor."); + NVF_CHECK( + send_counts.dim() == 1 && send_counts.numel() == world_size, + "alltoallv send_counts must be 1D 
[R]."); + + auto store = comm.getTcpStore(); + auto send_counts_cpu = send_counts.to(at::kCPU); + auto* send_ptr = send_counts_cpu.data_ptr(); + std::vector send_counts_vec(send_ptr, send_ptr + world_size); + + store->set( + alltoallvCountsKey(tag, my_rank), serializeInt64Vector(send_counts_vec)); + + std::vector> counts_matrix(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + auto bytes = store->get(alltoallvCountsKey(tag, rank)); + counts_matrix[rank] = deserializeInt64Vector(bytes); + NVF_CHECK( + (int64_t)counts_matrix[rank].size() == world_size, + "Invalid alltoallv counts size."); + } + comm.barrier(); + for (int64_t rank = 0; rank < world_size; ++rank) { + store->deleteKey(alltoallvCountsKey(tag, rank)); + } + + std::vector recv_counts_vec(world_size, 0); + for (int64_t sender = 0; sender < world_size; ++sender) { + recv_counts_vec[sender] = counts_matrix[sender][my_rank]; + } + + std::vector send_offsets_vec(world_size, 0); + int64_t prefix = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + send_offsets_vec[rank] = prefix; + prefix += send_counts_vec[rank]; + } + + std::vector recv_offsets_vec(world_size, 0); + for (int64_t peer = 0; peer < world_size; ++peer) { + int64_t offset = 0; + for (int64_t sender = 0; sender < my_rank; ++sender) { + offset += counts_matrix[sender][peer]; + } + recv_offsets_vec[peer] = offset; + } + + int64_t total_recv = 0; + for (auto value : recv_counts_vec) { + total_recv += value; + } + + int64_t max_recv = 0; + int64_t max_send_total = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t sender = 0; sender < world_size; ++sender) { + total += counts_matrix[sender][rank]; + } + if (total > max_recv) { + max_recv = total; + } + } + + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t dest = 0; dest < world_size; ++dest) { + total += counts_matrix[rank][dest]; + } + if (total > max_send_total) { + max_send_total = total; + } 
+ } + + int64_t max_send = 0; + for (auto value : send_counts_vec) { + if (value > max_send) { + max_send = value; + } + } + + auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto send_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + send_offsets_cpu.data_ptr(), + send_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_offsets_cpu.data_ptr(), + recv_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_counts_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_counts_cpu.data_ptr(), + recv_counts_vec.data(), + world_size * sizeof(int64_t)); + + AlltoallvMetadata metadata; + metadata.send_counts = send_counts; + metadata.recv_counts = recv_counts_cpu.to(send_counts.device()); + metadata.send_offsets = send_offsets_cpu.to(send_counts.device()); + metadata.recv_offsets = recv_offsets_cpu.to(send_counts.device()); + metadata.total_recv = total_recv; + metadata.max_recv = max_recv; + metadata.max_send_total = max_send_total; + metadata.max_send_bytes = max_send; + metadata.world_size = world_size; + return metadata; +} + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream) { + NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA."); + NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA."); + NVF_CHECK( + (int64_t)recv_ptrs.size() == metadata.world_size, + "recv_ptrs size must match world size."); + + auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options); + auto* ptrs = recv_ptrs_cpu.data_ptr(); + for (int64_t rank = 0; rank < metadata.world_size; ++rank) { + ptrs[rank] = + static_cast(reinterpret_cast(recv_ptrs[rank])); + } + auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device()); + + const int64_t 
elem_stride = + metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; + NVF_CHECK( + metadata.max_send_total == 0 || + send.numel() % metadata.max_send_total == 0, + "alltoallv send numel must be divisible by max_send_total."); + NVF_CHECK( + metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, + "alltoallv recv numel must be divisible by max_recv."); + + auto send_offsets = metadata.send_offsets; + auto send_counts = metadata.send_counts; + auto recv_offsets = metadata.recv_offsets; + int64_t max_send_bytes = metadata.max_send_bytes; + if (elem_stride > 1) { + send_offsets = metadata.send_offsets * elem_stride; + send_counts = metadata.send_counts * elem_stride; + recv_offsets = metadata.recv_offsets * elem_stride; + max_send_bytes = metadata.max_send_bytes * elem_stride; + } + + launchAlltoallvKernel( + send.data_ptr(), + reinterpret_cast(recv_ptrs_cuda.data_ptr()), + send_offsets.data_ptr(), + send_counts.data_ptr(), + recv_offsets.data_ptr(), + metadata.world_size, + send.element_size(), + max_send_bytes * send.element_size(), + stream); +} + +void alltoallvBarrier(const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + comm.barrier(); +} + } // namespace nvfuser diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index 4947e4e6ee1..e9fd5828597 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -9,6 +9,10 @@ #include +#include +#include +#include + #include "multidevice/ipc_handle.h" namespace nvfuser { @@ -43,4 +47,29 @@ void waitWithCudaBackend( CUstream stream, int64_t root); +struct AlltoallvMetadata { + at::Tensor send_counts; // CUDA [R] + at::Tensor recv_counts; // CUDA [R] + at::Tensor send_offsets; // CUDA [R] + at::Tensor recv_offsets; // CUDA [R] + int64_t total_recv = 0; + int64_t max_recv = 0; + int64_t max_send_total = 0; + int64_t max_send_bytes = 0; + int64_t world_size = 0; +}; + +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& 
send_counts, + const std::string& tag); + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream); + +void alltoallvBarrier(const std::string& tag); + } // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 738e27765d9..cbad812aa06 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -11,9 +11,12 @@ #include #include +#include #include #include "multidevice/communicator.h" +#include "multidevice/cuda_p2p.h" +#include "multidevice/symmetric_tensor.h" #include "utils.h" namespace nvfuser { @@ -114,53 +117,160 @@ DispatchResult doMoEDispatch( my_rank, at::TensorOptions().dtype(at::kLong).device(x.device())); - // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we - // sync/copy here. GPU-initiated comms can avoid this extra sync. + // Split metadata is exchanged via CPU (TCPStore), so we sync/copy here. 
auto rank_for_token_cpu = rank_for_token.to(at::kCPU); auto n_tokens_to_rank_cpu = at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); - auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for dispatch: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + + auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + std::vector one_split(world_size, 1); + waitWork(pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); + + auto input_splits = toSplitSizes(n_tokens_to_rank); + auto output_splits = toSplitSizes(n_tokens_from_rank); + auto total_recv = sumSplitSizes(output_splits); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = + at::empty({total_recv}, topk_weights_flat.options()); + auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); + auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)); + + const int64_t experts_per_rank = num_experts / world_size; + auto local_expert = recv_topk_idx - my_rank * experts_per_rank; + auto expert_sorted = local_expert.sort(); + auto expert_order = std::get<1>(expert_sorted); + recv_x = recv_x.index_select(0, expert_order); + 
recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); + recv_src_idx = recv_src_idx.index_select(0, expert_order); + recv_src_rank = recv_src_rank.index_select(0, expert_order); + + return DispatchResult{ + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoEDispatch."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for dispatch: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Dispatch backend is null."); - - // Exchange per-rank token counts to build split sizes for alltoall. - std::vector one_split(world_size, 1); - waitWork(pg->alltoall_base( - n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); - - // Convert count tensors to CPU split vectors and size the receive buffers. - auto input_splits = toSplitSizes(n_tokens_to_rank); - auto output_splits = toSplitSizes(n_tokens_from_rank); - auto total_recv = sumSplitSizes(output_splits); - - // Allocate receive buffers for payloads and metadata. - // TODO: support preallocated buffers. - auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); - auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); - auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); - - // Alltoall exchange payloads with per-rank splits. 
- waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_idx, send_topk_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoEDispatch."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_to_rank, "moe_dispatch_counts"); + auto n_tokens_from_rank = metadata.recv_counts; + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + + // Allocate symmetric buffers for send/recv payloads. + auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, num_tokens).copy_(send_x); + auto send_topk_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_idx_flat.scalar_type(), x.device()); + send_topk_idx_sym.narrow(0, 0, num_tokens).copy_(send_topk_idx); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights_flat.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, num_tokens).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, num_tokens).copy_(send_src_idx); + auto send_src_rank_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_rank.scalar_type(), x.device()); + send_src_rank_sym.narrow(0, 0, num_tokens).copy_(send_src_rank); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_idx_sym = SymmetricTensor::allocate( + {max_recv}, topk_idx_flat.scalar_type(), 
x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights_flat.scalar_type(), x.device()); + auto recv_src_idx_sym = SymmetricTensor::allocate( + {max_recv}, send_src_idx.scalar_type(), x.device()); + auto recv_src_rank_sym = SymmetricTensor::allocate( + {max_recv}, send_src_rank.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_idx_handle(recv_topk_idx_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + SymmetricTensor recv_src_rank_handle(recv_src_rank_sym); + recv_x_handle.setupRemoteHandles("moe_dispatch_recv_x"); + recv_topk_idx_handle.setupRemoteHandles("moe_dispatch_recv_topk_idx"); + recv_topk_weights_handle.setupRemoteHandles("moe_dispatch_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_dispatch_recv_src_idx"); + recv_src_rank_handle.setupRemoteHandles("moe_dispatch_recv_src_rank"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_idx_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + std::vector recv_src_idx_ptrs(world_size); + std::vector recv_src_rank_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_idx_ptrs[rank] = + recv_topk_idx_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + recv_src_rank_ptrs[rank] = + recv_src_rank_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_idx_sym, + recv_topk_idx_sym, + metadata, + recv_topk_idx_ptrs, + stream); + alltoallvWithCudaBackend( + 
send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvWithCudaBackend( + send_src_rank_sym, + recv_src_rank_sym, + metadata, + recv_src_rank_ptrs, + stream); + alltoallvBarrier("moe_dispatch_counts"); + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_idx = recv_topk_idx_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); + auto recv_src_rank = recv_src_rank_sym.narrow(0, 0, total_recv); // Locally reorder by expert id so each rank processes contiguous experts. const int64_t experts_per_rank = num_experts / world_size; @@ -212,6 +322,7 @@ CombineResult doMoECombine( n_tokens_from_rank.numel() == communicator->size(), "n_tokens_from_rank must match world size."); + const int64_t world_size = communicator->size(); c10::cuda::CUDAGuard device_guard(x.device()); // Sort by source rank so alltoall can send contiguous chunks per rank. @@ -222,32 +333,100 @@ CombineResult doMoECombine( auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. 
- auto input_splits = toSplitSizes(n_tokens_from_rank); - auto output_splits = toSplitSizes(n_tokens_to_rank); - auto total_recv = sumSplitSizes(output_splits); - auto hidden = x.size(1); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for combine: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Combine backend is null."); + + auto input_splits = toSplitSizes(n_tokens_from_rank); + auto output_splits = toSplitSizes(n_tokens_to_rank); + auto total_recv = sumSplitSizes(output_splits); + auto hidden = x.size(1); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); + auto recv_src_idx = at::empty({total_recv}, src_idx.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + + auto combined_x = at::empty({total_recv, hidden}, x.options()); + combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); + + return CombineResult{combined_x, combined_topk_weights}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoECombine."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for combine: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Combine backend is null."); - - // Allocate receive buffers and exchange payloads back to source ranks. 
- auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); - auto recv_src_idx = at::empty({total_recv}, src_idx.options()); - - waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoECombine."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_from_rank, "moe_combine_counts"); + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + auto hidden = x.size(1); + + // Allocate symmetric buffers for send/recv payloads. + auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, x.size(0)).copy_(send_x); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, x.size(0)).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, x.size(0)).copy_(send_src_idx); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights.scalar_type(), x.device()); + auto recv_src_idx_sym = + SymmetricTensor::allocate({max_recv}, src_idx.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + recv_x_handle.setupRemoteHandles("moe_combine_recv_x"); + 
recv_topk_weights_handle.setupRemoteHandles("moe_combine_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_combine_recv_src_idx"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + std::vector recv_src_idx_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvBarrier("moe_combine_counts"); + + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 5714a45a818..ceb0a2652b4 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -38,7 +38,7 @@ struct CombineResult { // is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. // num_experts: Total experts across all ranks (must be divisible by R). // communicator: Communicator for alltoall exchange. -// backend: Communication backend (only NCCL is supported for now). +// backend: Communication backend (CUDA or NCCL). // // Returns: // DispatchResult with recv_* tensors on this rank. 
@@ -86,7 +86,7 @@ NVF_API DispatchResult doMoEDispatch( // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. // n_tokens_from_rank: Tokens received from each rank (from dispatch), shape // [R]. communicator: Communicator for alltoall exchange. backend: -// Communication backend (only NCCL is supported for now). +// Communication backend (CUDA or NCCL). // // Returns: // CombineResult with tokens restored to original order on this rank. diff --git a/tests/cpp/test_multidevice_alltoallv.cpp b/tests/cpp/test_multidevice_alltoallv.cpp new file mode 100644 index 00000000000..02cb21b7892 --- /dev/null +++ b/tests/cpp/test_multidevice_alltoallv.cpp @@ -0,0 +1,82 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include + +#include "multidevice/cuda_p2p.h" +#include "multidevice/symmetric_tensor.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { +namespace hir { + +class AlltoallvCudaTest : public MultiDeviceTest {}; + +TEST_F(AlltoallvCudaTest, AlltoallvAsymmetric) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto count_for = [](int64_t sender, int64_t dest) { + return (sender + dest) % 3 + 1; + }; + auto send_counts = at::empty({world_size}, int_options); + for (int64_t dest = 0; dest < world_size; ++dest) { + send_counts.index_put_({dest}, count_for(my_rank, dest)); + } + + auto metadata = prepareAlltoallvMetadata(send_counts, "test_alltoallv_counts"); + const int64_t max_recv = metadata.max_recv; + const int64_t total_send = send_counts.sum().item(); + auto send_sym = 
SymmetricTensor::allocate( + {metadata.max_send_total}, at::kLong, communicator_->device()); + send_sym.narrow(0, 0, total_send) + .copy_(at::arange(total_send, int_options) + my_rank * 1000); + + auto recv_sym = SymmetricTensor::allocate( + {max_recv}, at::kLong, communicator_->device()); + SymmetricTensor recv_handle(recv_sym); + recv_handle.setupRemoteHandles("test_alltoallv_recv"); + + std::vector recv_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_ptrs[rank] = recv_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = at::cuda::getDefaultCUDAStream().stream(); + alltoallvWithCudaBackend(send_sym, recv_sym, metadata, recv_ptrs, stream); + alltoallvBarrier("test_alltoallv_counts"); + + auto recv_view = recv_sym.narrow(0, 0, metadata.total_recv); + std::vector expected_vec; + expected_vec.reserve(static_cast(metadata.total_recv)); + for (int64_t sender = 0; sender < world_size; ++sender) { + int64_t offset = 0; + for (int64_t dest = 0; dest < my_rank; ++dest) { + offset += count_for(sender, dest); + } + const int64_t count = count_for(sender, my_rank); + for (int64_t i = 0; i < count; ++i) { + expected_vec.push_back(offset + i + sender * 1000); + } + } + auto expected = at::tensor(expected_vec, int_options); + EXPECT_TRUE(at::equal(recv_view, expected)) + << "Alltoallv mismatch on rank " << my_rank; +} + +} // namespace hir +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 0d84dbc03e0..1a28c6e18d5 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -21,15 +21,21 @@ namespace nvfuser { namespace hir { -class DispatchCombineTest : public MultiDeviceTest {}; +class DispatchCombineTest + : public MultiDeviceTest, + public ::testing::WithParamInterface {}; -TEST_F(DispatchCombineTest, DispatchCombineTop1) { +TEST_P(DispatchCombineTest, DispatchCombineTop1) { if 
(!communicator_->is_available() || communicator_->size() < 2) {
     GTEST_SKIP() << "This test needs at least 2 ranks.";
   }
 
   const int64_t world_size = communicator_->size();
   const int64_t my_rank = communicator_->deviceId();
+  const auto backend = GetParam();
+  if (!communicator_->isBackendAvailable(backend)) {
+    GTEST_SKIP() << "Backend " << backend << " not available.";
+  }
   constexpr int64_t kNumExpertsPerRank = 2;
   const int64_t num_experts = world_size * kNumExpertsPerRank;
   constexpr int64_t kNumTokens = 4;
@@ -64,7 +70,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) {
       in_topk_weights,
       in_is_token_in_rank,
       num_experts,
-      CommunicatorBackend::kNccl);
+      backend);
 
   auto* combined_x = makeSymbolicTensor(2);
   auto* combined_topk_weights = makeSymbolicTensor(1);
@@ -77,7 +83,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) {
       recv_src_rank,
       n_tokens_to_rank,
       n_tokens_from_rank,
-      CommunicatorBackend::kNccl);
+      backend);
 
   hic->pushBackTopLevelExprs(dispatch);
   hic->pushBackTopLevelExprs(combine);
@@ -119,10 +125,14 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) {
        {in_topk_weights, topk_weights},
        {in_is_token_in_rank, is_token_in_rank}});
   auto combined = outputs.back().as();
-
   EXPECT_TRUE(at::allclose(combined, x))
       << "Dispatch/Combine mismatch on rank " << my_rank;
 }
 
+INSTANTIATE_TEST_SUITE_P(
+    DispatchCombineBackends,
+    DispatchCombineTest,
+    ::testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kCuda));
+
 } // namespace hir
 } // namespace nvfuser

From 7aa2de86034d62071317abc40b94d90fafe1f253 Mon Sep 17 00:00:00 2001
From: Eyal Chocron
Date: Wed, 25 Feb 2026 13:35:54 +0200
Subject: [PATCH 04/15] unstable - add nixl backend

---
 csrc/multidevice/nixl.cpp | 447 ++++++++++++++++++++++++++++
 csrc/multidevice/nixl.h | 139 +++++++++
 tests/cpp/test_multidevice_nixl.cpp | 289 ++++++++++++++++++
 3 files changed, 875 insertions(+)
 create mode 100644 csrc/multidevice/nixl.cpp
 create mode 100644 csrc/multidevice/nixl.h
 create mode 100644
tests/cpp/test_multidevice_nixl.cpp diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp new file mode 100644 index 00000000000..3f71c267cfd --- /dev/null +++ b/csrc/multidevice/nixl.cpp @@ -0,0 +1,447 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include "multidevice/nixl.h" + +#include + +#ifdef USE_NIXL +#include +#endif + +namespace nvfuser { + +// =================================================================== +// NixlTransferHandle +// =================================================================== + +class NixlTransferHandleImpl { + public: +#ifdef USE_NIXL + nixl_xfer_req_t xfer_handle{}; + bool prepared = false; + bool posted = false; +#endif +}; + +NixlTransferHandle::NixlTransferHandle() = default; +NixlTransferHandle::~NixlTransferHandle() = default; +NixlTransferHandle::NixlTransferHandle(NixlTransferHandle&&) noexcept = + default; +NixlTransferHandle& NixlTransferHandle::operator=( + NixlTransferHandle&&) noexcept = default; + +bool NixlTransferHandle::isValid() const { + if (!impl_) { + return false; + } +#ifdef USE_NIXL + return impl_->prepared; +#else + return false; +#endif +} + +// =================================================================== +// Tensor validation and descriptor helpers +// =================================================================== + +namespace { + +void validateCudaTensors(const std::vector& tensors) { + NVF_ERROR(!tensors.empty(), "Tensor list must not be empty"); + for (const auto& t : tensors) { + NVF_ERROR(t.is_cuda(), "All tensors must be CUDA tensors"); + NVF_ERROR(t.is_contiguous(), "All tensors must be contiguous"); + } +} + +#ifdef USE_NIXL +nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { + nixl_reg_dlist_t dlist(VRAM, tensors.size()); + for (const auto& t : tensors) { + dlist.addDesc( + 
{reinterpret_cast(t.data_ptr()),
+         static_cast(t.numel()) * t.element_size(),
+         static_cast(t.device().index())});
+  }
+  return dlist;
+}
+
+nixl_xfer_dlist_t buildXferDlist(const std::vector& tensors) {
+  nixl_xfer_dlist_t dlist(VRAM, tensors.size());
+  for (const auto& t : tensors) {
+    dlist.addDesc(
+        {reinterpret_cast(t.data_ptr()),
+         static_cast(t.numel()) * t.element_size(),
+         static_cast(t.device().index())});
+  }
+  return dlist;
+}
+
+nixl_xfer_op_t toNixlXferOp(NixlXferOp op) {
+  switch (op) {
+    case NixlXferOp::kRead:
+      return NIXL_XFER_READ;
+    case NixlXferOp::kWrite:
+      return NIXL_XFER_WRITE;
+  }
+  std::unreachable();
+}
+#endif
+
+} // namespace
+
+// ===================================================================
+// NixlBackend::Impl
+// ===================================================================
+
+class NixlBackend::Impl {
+ public:
+  explicit Impl(Communicator& communicator);
+  ~Impl();
+
+  bool isAvailable() const {
+    return available_;
+  }
+
+  void registerTensors(const std::vector& tensors);
+  void deregisterTensors(const std::vector& tensors);
+  void exchangeMetadata();
+
+  NixlTransferHandle prepareTransfer(
+      const std::vector& local_tensors,
+      const std::vector& remote_tensors,
+      int64_t remote_rank,
+      NixlXferOp op);
+
+  void postTransfer(NixlTransferHandle& handle);
+  NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const;
+  void waitTransfer(NixlTransferHandle& handle);
+  static std::string constructAgentName(int deviceId);
+ private:
+#ifdef USE_NIXL
+  std::unique_ptr agent_;
+#endif
+  Communicator& communicator_;
+  bool available_ = false;
+  bool metadata_exchanged_ = false;
+};
+
+// -------------------------------------------------------------------
+// Construction / destruction
+// -------------------------------------------------------------------
+
+NixlBackend::Impl::Impl(Communicator& communicator)
+    : communicator_(communicator) {
+#ifdef USE_NIXL
+  std::string agent_name = constructAgentName(communicator_.deviceId());
+  agent_ =
std::make_unique(agent_name); + if (!agent_) { + NVF_THROW("Failed to create NIXL agent"); + } + + nixl_b_params_t params; + nixl_status_t status = agent_->loadBackend("UCX", ¶ms); + if (status != NIXL_SUCCESS) { + agent_.reset(); + NVF_THROW("Failed to load UCX backend for NIXL agent"); + return; + } + + available_ = true; +#endif +} + +NixlBackend::Impl::~Impl() { +#ifdef USE_NIXL + agent_.reset(); +#endif +} + +std::string NixlBackend::Impl::constructAgentName(int deviceId){ + return "rank_" + std::to_string(deviceId); +} + +// ------------------------------------------------------------------- +// Memory registration +// ------------------------------------------------------------------- + +void NixlBackend::Impl::registerTensors( + const std::vector& tensors) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + validateCudaTensors(tensors); + + nixl_reg_dlist_t dlist = buildRegDlist(tensors); + nixl_status_t status = agent_->registerMem(dlist); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL registerMem failed with status ", + static_cast(status)); + + metadata_exchanged_ = false; +#else + (void)tensors; + NVF_THROW("NIXL support not compiled"); +#endif +} + +void NixlBackend::Impl::deregisterTensors( + const std::vector& tensors) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + validateCudaTensors(tensors); + + nixl_reg_dlist_t dlist = buildRegDlist(tensors); + nixl_status_t status = agent_->deregisterMem(dlist); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL deregisterMem failed with status ", + static_cast(status)); + + metadata_exchanged_ = false; +#else + (void)tensors; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// ------------------------------------------------------------------- +// Metadata exchange +// ------------------------------------------------------------------- + +void NixlBackend::Impl::exchangeMetadata() { +#ifdef USE_NIXL + NVF_ERROR(available_, 
"NIXL backend is not available"); + + std::string local_md = agent_->getLocalMD(); + auto* store = communicator_.getTcpStore(); + const int64_t my_rank = communicator_.deviceId(); + const int64_t world_size = communicator_.size(); + + std::string key_prefix = "nixl_agent_md_rank_"; + store->set( + key_prefix + std::to_string(my_rank), + std::vector(local_md.begin(), local_md.end())); + + for (int64_t rank = 0; rank < world_size; ++rank) { + if (rank == my_rank) { + continue; + } + auto bytes = store->get(key_prefix + std::to_string(rank)); + std::string remote_md(bytes.begin(), bytes.end()); + nixl_status_t status = agent_->loadRemoteMD(remote_md); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL loadRemoteMD failed for rank ", + rank, + " with status ", + static_cast(status)); + } + + // Barrier before deleting keys so no rank reads a deleted key. + communicator_.barrier(); + + store->deleteKey(key_prefix + std::to_string(my_rank)); + metadata_exchanged_ = true; +#else + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// ------------------------------------------------------------------- +// Transfer preparation +// ------------------------------------------------------------------- + +// Prepare a transfer between local and remote tensor pairs. +// +// The local and remote descriptor lists are built from the tensors' +// data pointers, byte sizes, and CUDA device indices. NIXL pairs +// local_tensors[i] with remote_tensors[i]. 
The direction depends on `op`: +// kRead -- data flows from remote_tensors[i] into local_tensors[i] +// kWrite -- data flows from local_tensors[i] into remote_tensors[i] +// +// Preconditions: +// - exchangeMetadata() has been called since the last registration change +// - local_tensors and remote_tensors have the same length +// - all tensors are contiguous CUDA tensors +// - remote tensors must have been registered on remote_rank's agent +NixlTransferHandle NixlBackend::Impl::prepareTransfer( + const std::vector& local_tensors, + const std::vector& remote_tensors, + int64_t remote_rank, + NixlXferOp op) { + NixlTransferHandle handle; +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); + NVF_ERROR( + local_tensors.size() == remote_tensors.size(), + "Local and remote tensor lists must have the same size. Got ", + local_tensors.size(), + " vs ", + remote_tensors.size()); + validateCudaTensors(local_tensors); + validateCudaTensors(remote_tensors); + + std::string remote_agent_name = constructAgentName(remote_rank); + + nixl_xfer_dlist_t local_dlist = buildXferDlist(local_tensors); + nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_tensors); + + auto impl = std::make_unique(); + nixl_status_t status = agent_->prepXferDlist( + toNixlXferOp(op), + local_dlist, + remote_dlist, + remote_agent_name, + impl->xfer_handle); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL prepXferDlist failed with status ", + static_cast(status)); + + impl->prepared = true; + handle.impl_ = std::move(impl); +#else + (void)local_tensors; + (void)remote_tensors; + (void)remote_rank; + (void)op; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif + return handle; +} + +// ------------------------------------------------------------------- +// Transfer posting +// ------------------------------------------------------------------- + +void 
NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(handle.isValid(), "Cannot post an invalid transfer handle"); + NVF_ERROR( + !handle.impl_->posted, + "Transfer already posted. Wait for completion before re-posting."); + + nixl_status_t status = agent_->postXferReq(handle.impl_->xfer_handle); + NVF_ERROR( + status == NIXL_SUCCESS || status == NIXL_IN_PROG, + "NIXL postXferReq failed with status ", + static_cast(status)); + + handle.impl_->posted = true; +#else + (void)handle; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// ------------------------------------------------------------------- +// Transfer status / wait +// ------------------------------------------------------------------- + +NixlXferStatus NixlBackend::Impl::getTransferStatus( + const NixlTransferHandle& handle) const { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(handle.isValid(), "Cannot query status of an invalid handle"); + NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + + nixl_status_t status = agent_->getXferStatus(handle.impl_->xfer_handle); + switch (status) { + case NIXL_SUCCESS: + return NixlXferStatus::kDone; + case NIXL_IN_PROG: + return NixlXferStatus::kInProgress; + default: + return NixlXferStatus::kError; + } +#else + (void)handle; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); + NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + + NixlXferStatus xfer_status; + do { + xfer_status = getTransferStatus(handle); + NVF_ERROR( + xfer_status != NixlXferStatus::kError, + "NIXL transfer completed with an error"); + } while (xfer_status == 
NixlXferStatus::kInProgress); + + handle.impl_->posted = false; +#else + (void)handle; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// =================================================================== +// NixlBackend singleton + public API +// =================================================================== + +NixlBackend::NixlBackend() + : impl_(std::make_unique(Communicator::getInstance())) {} + +NixlBackend& NixlBackend::getInstance() { + static auto* instance = new NixlBackend(); + return *instance; +} + +void NixlBackend::cleanup() { + impl_.reset(); +} + +bool NixlBackend::isAvailable() const { + return impl_ && impl_->isAvailable(); +} + +void NixlBackend::registerTensors(const std::vector& tensors) { + impl_->registerTensors(tensors); +} + +void NixlBackend::deregisterTensors(const std::vector& tensors) { + impl_->deregisterTensors(tensors); +} + +void NixlBackend::exchangeMetadata() { + impl_->exchangeMetadata(); +} + +NixlTransferHandle NixlBackend::prepareTransfer( + const std::vector& local_tensors, + const std::vector& remote_tensors, + int64_t remote_rank, + NixlXferOp op) { + return impl_->prepareTransfer( + local_tensors, remote_tensors, remote_rank, op); +} + +void NixlBackend::postTransfer(NixlTransferHandle& handle) { + impl_->postTransfer(handle); +} + +NixlXferStatus NixlBackend::getTransferStatus( + const NixlTransferHandle& handle) const { + return impl_->getTransferStatus(handle); +} + +void NixlBackend::waitTransfer(NixlTransferHandle& handle) { + impl_->waitTransfer(handle); +} + +} // namespace nvfuser diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h new file mode 100644 index 00000000000..3b1dedf7cb9 --- /dev/null +++ b/csrc/multidevice/nixl.h @@ -0,0 +1,139 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include +#include + +#include "multidevice/communicator.h" +#include "visibility.h" + +namespace nvfuser { + +// Transfer direction. NIXL uses a one-sided model: +// Read = pull remote data into local buffers +// Write = push local data into remote buffers +enum class NixlXferOp { + kRead, + kWrite, +}; + +enum class NixlXferStatus { + kDone, + kInProgress, + kError, +}; + +// ------------------------------------------------------------------- +// NixlTransferHandle: opaque handle for a prepared transfer +// ------------------------------------------------------------------- +// Returned by NixlBackend::prepareTransfer(). Callers hold this handle +// and pass it to postTransfer() / waitTransfer(). The actual NIXL +// transfer handle lives inside the impl; this is just an owning wrapper. +class NixlTransferHandleImpl; + +class NVF_API NixlTransferHandle { + public: + NixlTransferHandle(); + ~NixlTransferHandle(); + NixlTransferHandle(NixlTransferHandle&&) noexcept; + NixlTransferHandle& operator=(NixlTransferHandle&&) noexcept; + + NixlTransferHandle(const NixlTransferHandle&) = delete; + NixlTransferHandle& operator=(const NixlTransferHandle&) = delete; + + bool isValid() const; + + private: + friend class NixlBackend; + std::unique_ptr impl_; +}; + +// ------------------------------------------------------------------- +// NixlBackend: singleton NIXL backend over UCX for GPU tensors +// ------------------------------------------------------------------- +// Singleton - Wraps a nixlAgent with the UCX backend and provides a tensor-level +// API for registering GPU memory and performing RDMA transfers. +// +// Lifecycle: +// 1. getInstance() - creates agent, loads UCX backend +// 2. registerTensors() - register GPU tensors for RDMA access +// 3. exchangeMetadata() - all ranks share their registration info +// 4. 
prepareTransfer() - expensive one-time setup per transfer pattern +// 5. postTransfer() - cheap, non-blocking data movement +// 6. waitTransfer() - block until complete +// +// Thread safety: methods are NOT thread-safe. The caller must +// synchronize if the same NixlBackend is used from multiple threads. +class NixlBackend { + public: + static NixlBackend& getInstance(); + + NixlBackend(const NixlBackend&) = delete; + NixlBackend& operator=(const NixlBackend&) = delete; + ~NixlBackend() = delete; + + // Explicitly tear down the singleton. Must be called before program + // exit (same pattern as Communicator::cleanup). + void cleanup(); + + bool isAvailable() const; + + // ------------------------------------------------------------------ + // Memory registration + // ------------------------------------------------------------------ + + // Register CUDA tensors with the NIXL agent so they can participate + // in RDMA transfers. Tensors must be contiguous and remain alive + // until deregisterTensors() is called. + void registerTensors(const std::vector& tensors); + + void deregisterTensors(const std::vector& tensors); + + // ------------------------------------------------------------------ + // Metadata exchange + // ------------------------------------------------------------------ + // Exchange local agent metadata with all peers through the TCPStore. + // Must be called after registerTensors() and before prepareTransfer() + // whenever the set of registered tensors changes. + void exchangeMetadata(); + + // ------------------------------------------------------------------ + // Transfer lifecycle + // ------------------------------------------------------------------ + + // Prepare a transfer between pairs of tensors. + // local_tensors[i] and remote_tensors[i] must have the same byte size. + // All tensors must be contiguous CUDA tensors and previously registered. + // The returned handle can be posted multiple times (preparation is + // amortized). 
+ NixlTransferHandle prepareTransfer( + const std::vector& local_tensors, + const std::vector& remote_tensors, + int64_t remote_rank, + NixlXferOp op); + + // Post a previously prepared transfer for execution (non-blocking). + void postTransfer(NixlTransferHandle& handle); + + // Poll the status of a posted transfer without blocking. + NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; + + // Block until the transfer completes (or errors out). + void waitTransfer(NixlTransferHandle& handle); + + private: + NixlBackend(); + + class Impl; + std::unique_ptr impl_; +}; + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_nixl.cpp b/tests/cpp/test_multidevice_nixl.cpp new file mode 100644 index 00000000000..eb8de2ba3b8 --- /dev/null +++ b/tests/cpp/test_multidevice_nixl.cpp @@ -0,0 +1,289 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include "multidevice/nixl.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { + +using NixlTest = MultiDeviceTest; + +// ------------------------------------------------------------------- +// NixlTransferHandle tests +// ------------------------------------------------------------------- + +TEST_F(NixlTest, TransferHandleDefaultConstruction) { + NixlTransferHandle handle; + EXPECT_FALSE(handle.isValid()); +} + +TEST_F(NixlTest, TransferHandleMoveConstruction) { + NixlTransferHandle h1; + EXPECT_FALSE(h1.isValid()); + + NixlTransferHandle h2(std::move(h1)); + EXPECT_FALSE(h2.isValid()); +} + +TEST_F(NixlTest, TransferHandleMoveAssignment) { + NixlTransferHandle h1; + NixlTransferHandle h2; + h2 = std::move(h1); + EXPECT_FALSE(h2.isValid()); +} + +// ------------------------------------------------------------------- +// NixlBackend singleton tests +// ------------------------------------------------------------------- + 
+TEST_F(NixlTest, SingletonIsAccessible) { + NixlBackend& backend = NixlBackend::getInstance(); + // isAvailable() returns true only when USE_NIXL is defined and the + // UCX backend loaded successfully. Either outcome is valid here. + (void)backend.isAvailable(); +} + +// ------------------------------------------------------------------- +// Input validation tests (these exercise the guards in the impl) +// ------------------------------------------------------------------- + +TEST_F(NixlTest, RegisterEmptyTensorListThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + std::vector empty; + EXPECT_THROW(backend.registerTensors(empty), nvfError); +} + +TEST_F(NixlTest, RegisterCpuTensorThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto cpu_tensor = at::randn({64}); + EXPECT_THROW(backend.registerTensors({cpu_tensor}), nvfError); +} + +TEST_F(NixlTest, RegisterNonContiguousTensorThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto t = at::randn({8, 8}, tensor_options_); + auto non_contig = t.transpose(0, 1); + ASSERT_FALSE(non_contig.is_contiguous()); + EXPECT_THROW(backend.registerTensors({non_contig}), nvfError); +} + +TEST_F(NixlTest, DeregisterEmptyTensorListThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + std::vector empty; + EXPECT_THROW(backend.deregisterTensors(empty), nvfError); +} + +// ------------------------------------------------------------------- +// Transfer preparation validation +// ------------------------------------------------------------------- + +TEST_F(NixlTest, PrepareTransferWithoutMetadataExchangeThrows) { + NixlBackend& backend = 
NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto local = at::randn({64}, tensor_options_); + auto remote = at::randn({64}, tensor_options_); + backend.registerTensors({local}); + backend.registerTensors({remote}); + + EXPECT_THROW( + (void)backend.prepareTransfer({toTensorDesc(local)}, {toTensorDesc(remote)}, 0, NixlXferOp::kRead), + nvfError); + + backend.deregisterTensors({local}); + backend.deregisterTensors({remote}); +} + +TEST_F(NixlTest, PrepareTransferMismatchedSizesThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto t1 = at::randn({64}, tensor_options_); + auto t2 = at::randn({64}, tensor_options_); + auto t3 = at::randn({64}, tensor_options_); + backend.registerTensors({t1, t2, t3}); + backend.exchangeMetadata(); + + EXPECT_THROW( + (void)backend.prepareTransfer({toTensorDesc(t1), toTensorDesc(t2)}, {toTensorDesc(t3)}, 0, NixlXferOp::kRead), nvfError); + + backend.deregisterTensors({t1, t2, t3}); +} + +// ------------------------------------------------------------------- +// Post / wait on invalid handles +// ------------------------------------------------------------------- + +TEST_F(NixlTest, PostInvalidHandleThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + NixlTransferHandle invalid_handle; + EXPECT_THROW(backend.postTransfer(invalid_handle), nvfError); +} + +TEST_F(NixlTest, WaitInvalidHandleThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + NixlTransferHandle invalid_handle; + EXPECT_THROW(backend.waitTransfer(invalid_handle), nvfError); +} + +TEST_F(NixlTest, GetStatusInvalidHandleThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + 
GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  NixlTransferHandle invalid_handle;
+  EXPECT_THROW((void)backend.getTransferStatus(invalid_handle), nvfError);
+}
+
+// -------------------------------------------------------------------
+// End-to-end transfer test (requires >= 2 devices with NIXL)
+// -------------------------------------------------------------------
+
+TEST_F(NixlTest, ReadTransferEndToEnd) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+  if (communicator_->size() < 2) {
+    GTEST_SKIP() << "Need at least 2 devices for transfer test";
+  }
+
+  const int64_t rank = communicator_->deviceId();
+  const int64_t world_size = communicator_->size();
+  const int64_t peer_rank = (rank + 1) % world_size;
+  constexpr int64_t kSize = 1024;
+
+  // Ring-style transfer: each rank reads its peer's source tensor into its local tensor.
+  auto src = at::full({kSize}, static_cast<float>(rank + 1), tensor_options_);
+  auto dst = at::zeros({kSize}, tensor_options_);
+  cudaDeviceSynchronize();
+
+  backend.registerTensors({src, dst});
+  backend.exchangeMetadata();
+
+  // Fetch the remote tensor descriptor from the peer
+  std::string src_key_prefix = "nixl_test_read_transfer_src_rank_";
+  storeTensorDescs(*communicator_, src_key_prefix + std::to_string(rank), {src});
+  auto remote_src_descs = fetchTensorDescs(*communicator_, src_key_prefix + std::to_string(peer_rank));
+  communicator_->barrier();
+  communicator_->getTcpStore()->deleteKey(src_key_prefix + std::to_string(rank));
+  auto remote_src_desc = remote_src_descs[0]; // Only one remote tensor is expected
+
+  // Each rank reads from its peer. After the read, the local tensor should
+  // contain the values that the peer stored in *its* source tensor.
+ auto handle = backend.prepareTransfer( + {toTensorDesc(dst)}, {remote_src_desc}, peer_rank, NixlXferOp::kRead); + ASSERT_TRUE(handle.isValid()); + + backend.postTransfer(handle); + backend.waitTransfer(handle); + + auto local_cpu = dst.cpu(); + float expected_val = static_cast(peer_rank + 1); + EXPECT_TRUE(at::allclose(local_cpu, at::full({kSize}, expected_val))); + + backend.deregisterTensors({dst, src}); +} + +TEST_F(NixlTest, WriteTransferEndToEnd) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + if (communicator_->size() < 2) { + GTEST_SKIP() << "Need at least 2 devices for transfer test"; + } + + const int64_t rank = communicator_->deviceId(); + const int64_t world_size = communicator_->size(); + const int64_t peer_rank = (rank + 1) % world_size; + constexpr int64_t kSize = 512; + + // Each rank writes its local to the remote of its peer in a ring style + auto src = at::full({kSize}, static_cast(rank + 1), tensor_options_); + auto dst = at::zeros({kSize}, tensor_options_); + cudaDeviceSynchronize(); + + backend.registerTensors({src, dst}); + backend.exchangeMetadata(); + + // Fetch the remote tensor descriptor from the peer + std::string dst_key_prefix = "nixl_test_write_transfer_dst_rank_"; + storeTensorDescs(*communicator_, dst_key_prefix + std::to_string(rank), {dst}); + auto remote_dst_descs = fetchTensorDescs(*communicator_, dst_key_prefix + std::to_string(peer_rank)); + communicator_->barrier(); + communicator_->getTcpStore()->deleteKey(dst_key_prefix + std::to_string(rank)); + auto remote_dst_desc = remote_dst_descs[0]; // Only one remote tensor is expected + + // Each rank writes its local tensor into its peer's remote tensor. 
+  auto handle = backend.prepareTransfer(
+      {toTensorDesc(src)}, {remote_dst_desc}, peer_rank, NixlXferOp::kWrite);
+  ASSERT_TRUE(handle.isValid());
+
+  backend.postTransfer(handle);
+  backend.waitTransfer(handle);
+
+  // After the barrier, the peer is guaranteed to have written into our local dst tensor.
+  communicator_->barrier();
+
+  auto remote_cpu = dst.cpu();
+  int64_t writer_rank = (rank - 1 + world_size) % world_size;
+  float expected_val = static_cast<float>(writer_rank + 1);
+  EXPECT_TRUE(at::allclose(remote_cpu, at::full({kSize}, expected_val)));
+
+  backend.deregisterTensors({src, dst});
+}
+
+TEST_F(NixlTest, RegisterDeregisterRoundTrip) {
+  NixlBackend& backend = NixlBackend::getInstance();
+  if (!backend.isAvailable()) {
+    GTEST_SKIP() << "NIXL backend not available";
+  }
+
+  auto t1 = at::randn({256}, tensor_options_);
+  auto t2 = at::randn({128}, tensor_options_);
+
+  backend.registerTensors({t1, t2});
+  backend.deregisterTensors({t1, t2});
+
+  // Re-registering the same tensors should succeed.
+ backend.registerTensors({t1, t2}); + backend.deregisterTensors({t1, t2}); +} + +} // namespace nvfuser From 9a8a377109a5038c55e4a53d11b7e26ee53cf4b9 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:36:56 +0200 Subject: [PATCH 05/15] unstable --- CMakeLists.txt | 39 ++++++ csrc/multidevice/communicator.cpp | 3 + csrc/multidevice/communicator.h | 11 ++ csrc/multidevice/multidevice.h | 2 +- csrc/multidevice/nixl.cpp | 189 +++++++++++++++++++----------- csrc/multidevice/nixl.h | 105 ++++++++++++++++- 6 files changed, 276 insertions(+), 73 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ff76e741b4c..94b87209c21 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,7 @@ set(NVFUSER_CUTLASS "${NVFUSER_ROOT}/cutlass") set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party") option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) +option(NVFUSER_STANDALONE_BUILD_WITH_NIXL "" OFF) option(NVFUSER_EXPLICIT_ERROR_CHECK "" OFF) option(NVFUSER_ENABLE_DEPENDENCY_REPORT "Enable Python-based dependency reporting and log capture" ON) @@ -240,6 +241,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp ${NVFUSER_SRCS_DIR}/multidevice/executor.cpp + ${NVFUSER_SRCS_DIR}/multidevice/nixl.cpp ${NVFUSER_SRCS_DIR}/multidevice/execution_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/propagation.cpp ${NVFUSER_SRCS_DIR}/multidevice/resharding.cpp @@ -586,6 +588,37 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) target_compile_definitions(codegen_internal PRIVATE NVFUSER_BUILD_WITH_UCC) endif() +if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) + # User may need to set NIXL_PREFIX to the NIXL install directory. 
+ find_path(NIXL_INCLUDE_DIR nixl.h + HINTS $ENV{NIXL_PREFIX}/include ENV CPATH + ) + find_library(NIXL_LIBRARY nixl + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + find_library(NIXL_BUILD_LIBRARY nixl_build + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + + if(NOT NIXL_INCLUDE_DIR OR NOT NIXL_LIBRARY) + message(FATAL_ERROR "NIXL not found. Set NIXL_PREFIX to the NIXL install directory.") + endif() + + message(STATUS "Found NIXL: ${NIXL_LIBRARY} (include: ${NIXL_INCLUDE_DIR})") + if(NIXL_BUILD_LIBRARY) + message(STATUS "Found NIXL build lib: ${NIXL_BUILD_LIBRARY}") + endif() + + add_library(__nvfuser_nixl INTERFACE) + target_include_directories(__nvfuser_nixl INTERFACE ${NIXL_INCLUDE_DIR}) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_LIBRARY}) + if(NIXL_BUILD_LIBRARY) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_BUILD_LIBRARY}) + endif() + target_link_libraries(codegen_internal PRIVATE __nvfuser_nixl) + target_compile_definitions(codegen_internal PRIVATE USE_NIXL) +endif() + add_dependencies(codegen_internal flatc build_flatbuffer_config) # installing nvfuser headers @@ -1153,6 +1186,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_nixl.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_stream_parallel_type.cpp @@ -1457,6 +1491,11 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) message(STATUS " UCX_DIR : $ENV{UCX_DIR}") endif() message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") +message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_NIXL : ${NVFUSER_STANDALONE_BUILD_WITH_NIXL}") 
+if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) + message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") + message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") +endif() message(STATUS " NVFUSER_BUILD_WITH_ASAN : ${NVFUSER_BUILD_WITH_ASAN}") message(STATUS " NVFUSER_DISTRIBUTED : ${NVFUSER_DISTRIBUTED}") message(STATUS " NVFUSER_CPP_STANDARD : ${NVFUSER_CPP_STANDARD}") diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index 208277f98a1..98c36ab8d3b 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -41,6 +41,9 @@ std::ostream& operator<<(std::ostream& out, const CommunicatorBackend& cb) { case CommunicatorBackend::kCuda: out << "CUDA"; break; + case CommunicatorBackend::kNixl: + out << "NIXL"; + break; } return out; } diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index b56e6fee3aa..c4a1eb3d09b 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -11,6 +11,9 @@ #include #include +#include +#include + #ifdef NVFUSER_DISTRIBUTED #include #include @@ -116,6 +119,12 @@ class NVF_API Communicator { return ucc_available_; } else if (backend == CommunicatorBackend::kNccl) { return nccl_available_; + } else if (backend == CommunicatorBackend::kNixl) { +#ifdef USE_NIXL + return true; +#else + return false; +#endif } return false; } @@ -124,6 +133,7 @@ class NVF_API Communicator { return store_.get(); } + private: Communicator( CommunicatorBackend backend = comm_backend_default, @@ -155,4 +165,5 @@ class NVF_API Communicator { std::unordered_map> backends_; }; + } // namespace nvfuser diff --git a/csrc/multidevice/multidevice.h b/csrc/multidevice/multidevice.h index 288a89fe952..7915f5e3d92 100644 --- a/csrc/multidevice/multidevice.h +++ b/csrc/multidevice/multidevice.h @@ -19,5 +19,5 @@ using DeviceType = c10::Device; using Team = std::vector; // Supported backends. 
-enum class CommunicatorBackend { kNccl, kUcc, kCuda }; +enum class CommunicatorBackend { kNccl, kUcc, kCuda, kNixl }; } // namespace nvfuser diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 3f71c267cfd..fa0ac7c7a94 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -6,13 +6,16 @@ */ // clang-format on #include "multidevice/nixl.h" +#include "exceptions.h" +#include +#include +#include #include #ifdef USE_NIXL #include #endif - namespace nvfuser { // =================================================================== @@ -22,10 +25,11 @@ namespace nvfuser { class NixlTransferHandleImpl { public: #ifdef USE_NIXL - nixl_xfer_req_t xfer_handle{}; + // TODO - is it leaking when handleimpl is destroyed ? + nixlXferReqH* xfer_handle = nullptr; +#endif bool prepared = false; bool posted = false; -#endif }; NixlTransferHandle::NixlTransferHandle() = default; @@ -61,8 +65,9 @@ void validateCudaTensors(const std::vector& tensors) { } #ifdef USE_NIXL + nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { - nixl_reg_dlist_t dlist(VRAM, tensors.size()); + nixl_reg_dlist_t dlist(VRAM_SEG, tensors.size()); for (const auto& t : tensors) { dlist.addDesc( {reinterpret_cast(t.data_ptr()), @@ -72,13 +77,10 @@ nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { return dlist; } -nixl_xfer_dlist_t buildXferDlist(const std::vector& tensors) { - nixl_xfer_dlist_t dlist(VRAM, tensors.size()); - for (const auto& t : tensors) { - dlist.addDesc( - {reinterpret_cast(t.data_ptr()), - static_cast(t.numel()) * t.element_size(), - static_cast(t.device().index())}); +nixl_xfer_dlist_t buildXferDlist(const std::vector& descs) { + nixl_xfer_dlist_t dlist(VRAM_SEG, descs.size()); + for (const auto& desc : descs) { + dlist.addDesc({desc.addr, desc.size, desc.dev}); } return dlist; } @@ -86,12 +88,13 @@ nixl_xfer_dlist_t buildXferDlist(const std::vector& tensors) { nixl_xfer_op_t toNixlXferOp(NixlXferOp op) { switch (op) { case 
NixlXferOp::kRead: - return NIXL_XFER_READ; + return NIXL_READ; case NixlXferOp::kWrite: - return NIXL_XFER_WRITE; + return NIXL_WRITE; } - std::unreachable(); + NVF_THROW("Invalid NIXL transfer operation: ", static_cast(op)); } + #endif } // namespace @@ -114,8 +117,8 @@ class NixlBackend::Impl { void exchangeMetadata(); NixlTransferHandle prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op); @@ -124,8 +127,11 @@ class NixlBackend::Impl { void waitTransfer(NixlTransferHandle& handle); private: + std::string constructAgentName(int64_t rank); + #ifdef USE_NIXL std::unique_ptr agent_; + nixlBackendH* backend_ = nullptr; #endif Communicator& communicator_; bool available_ = false; @@ -140,37 +146,73 @@ NixlBackend::Impl::Impl(Communicator& communicator) : communicator_(communicator) { #ifdef USE_NIXL std::string agent_name = constructAgentName(communicator_.deviceId()); - agent_ = std::make_unique(agent_name); - if (!agent_) { - NVF_THROW("Failed to create NIXL agent"); - } + nixlAgentConfig cfg(false); + agent_ = std::make_unique(agent_name, cfg); nixl_b_params_t params; - nixl_status_t status = agent_->loadBackend("UCX", ¶ms); + nixl_status_t status = agent_->createBackend("UCX", params, backend_); if (status != NIXL_SUCCESS) { agent_.reset(); - NVF_THROW("Failed to load UCX backend for NIXL agent"); - return; + NVF_THROW("Failed to create UCX backend for NIXL agent"); + } + + // Probe: verify that VRAM (CUDA GPU memory) is actually usable with + // the UCX backend. Some UCX installations lack CUDA support, causing + // registerMem to silently misclassify VRAM as host memory. We detect + // this by registering a small buffer and asking NIXL to prepare a + // local descriptor list for VRAM -- if no backend claims VRAM, the + // probe fails and we mark the backend as unavailable. 
+ { + auto probe = at::empty( + {1}, + at::TensorOptions().dtype(at::kByte).device( + at::kCUDA, communicator_.deviceId())); + nixl_reg_dlist_t reg_dlist(VRAM_SEG, 1); + reg_dlist.addDesc( + {reinterpret_cast(probe.data_ptr()), + probe.nbytes(), + static_cast(probe.device().index())}); + + nixl_status_t reg_status = agent_->registerMem(reg_dlist); + if (reg_status != NIXL_SUCCESS) { + return; + } + + nixl_xfer_dlist_t xfer_dlist(VRAM_SEG, 1); + xfer_dlist.addDesc( + {reinterpret_cast(probe.data_ptr()), + probe.nbytes(), + static_cast(probe.device().index())}); + + nixlDlistH* dlist_handle = nullptr; + nixl_status_t prep_status = + agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); + + if (dlist_handle) { + agent_->releasedDlistH(dlist_handle); + } + agent_->deregisterMem(reg_dlist); + + if (prep_status != NIXL_SUCCESS) { + return; + } } available_ = true; #endif } -NixlBackend::Impl::~Impl() { -#ifdef USE_NIXL - agent_.reset(); -#endif -} +NixlBackend::Impl::~Impl() = default; -std::string NixlBackend::Impl::constructAgentName(int deviceId){ - return "rank_" + std::to_string(deviceId); +std::string NixlBackend::Impl::constructAgentName(int64_t rank){ + return "rank_" + std::to_string(rank); } // ------------------------------------------------------------------- // Memory registration // ------------------------------------------------------------------- +// TODO - consider adding RAII wrapper void NixlBackend::Impl::registerTensors( const std::vector& tensors) { #ifdef USE_NIXL @@ -219,23 +261,31 @@ void NixlBackend::Impl::exchangeMetadata() { #ifdef USE_NIXL NVF_ERROR(available_, "NIXL backend is not available"); - std::string local_md = agent_->getLocalMD(); + nixl_blob_t local_md; + nixl_status_t md_status = agent_->getLocalMD(local_md); + NVF_ERROR( + md_status == NIXL_SUCCESS, + "NIXL getLocalMD failed with status ", + static_cast(md_status)); + auto* store = communicator_.getTcpStore(); - const int64_t my_rank = communicator_.deviceId(); - 
const int64_t world_size = communicator_.size(); + const auto my_rank = communicator_.deviceId(); + const auto world_size = communicator_.size(); - std::string key_prefix = "nixl_agent_md_rank_"; + std::string md_key_prefix = "nixl_agent_md_rank_"; store->set( - key_prefix + std::to_string(my_rank), + md_key_prefix + std::to_string(my_rank), std::vector(local_md.begin(), local_md.end())); for (int64_t rank = 0; rank < world_size; ++rank) { if (rank == my_rank) { continue; } - auto bytes = store->get(key_prefix + std::to_string(rank)); - std::string remote_md(bytes.begin(), bytes.end()); - nixl_status_t status = agent_->loadRemoteMD(remote_md); + // Fetch & load MD + auto bytes = store->get(md_key_prefix + std::to_string(rank)); + nixl_blob_t remote_md(bytes.begin(), bytes.end()); + std::string remote_agent_name; + nixl_status_t status = agent_->loadRemoteMD(remote_md, remote_agent_name); NVF_ERROR( status == NIXL_SUCCESS, "NIXL loadRemoteMD failed for rank ", @@ -247,7 +297,7 @@ void NixlBackend::Impl::exchangeMetadata() { // Barrier before deleting keys so no rank reads a deleted key. communicator_.barrier(); - store->deleteKey(key_prefix + std::to_string(my_rank)); + store->deleteKey(md_key_prefix + std::to_string(my_rank)); metadata_exchanged_ = true; #else NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); @@ -260,20 +310,18 @@ void NixlBackend::Impl::exchangeMetadata() { // Prepare a transfer between local and remote tensor pairs. // -// The local and remote descriptor lists are built from the tensors' -// data pointers, byte sizes, and CUDA device indices. NIXL pairs -// local_tensors[i] with remote_tensors[i]. The direction depends on `op`: -// kRead -- data flows from remote_tensors[i] into local_tensors[i] -// kWrite -- data flows from local_tensors[i] into remote_tensors[i] +// NIXL pairs local_tensors[i] with remote_tensors[i]. 
The direction +// depends on `op`: +// kRead -- data flows from remote into local +// kWrite -- data flows from local into remote // -// Preconditions: -// - exchangeMetadata() has been called since the last registration change -// - local_tensors and remote_tensors have the same length -// - all tensors are contiguous CUDA tensors -// - remote tensors must have been registered on remote_rank's agent +// remote_tensors are LOCAL tensors whose data_ptr identifies the +// corresponding registration slot. The actual remote addresses are +// looked up from the descriptors exchanged during exchangeMetadata(). +// This requires all ranks to register tensors in the same order. NixlTransferHandle NixlBackend::Impl::prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + const std::vector& local_descs, // Local addresses + const std::vector& remote_descs, // Remote tensors (not valid on this rank) int64_t remote_rank, NixlXferOp op) { NixlTransferHandle handle; @@ -281,21 +329,19 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); NVF_ERROR( - local_tensors.size() == remote_tensors.size(), + local_descs.size() == remote_descs.size(), "Local and remote tensor lists must have the same size. 
Got ", - local_tensors.size(), + local_descs.size(), " vs ", - remote_tensors.size()); - validateCudaTensors(local_tensors); - validateCudaTensors(remote_tensors); + remote_descs.size()); std::string remote_agent_name = constructAgentName(remote_rank); - nixl_xfer_dlist_t local_dlist = buildXferDlist(local_tensors); - nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_tensors); + nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); + nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); auto impl = std::make_unique(); - nixl_status_t status = agent_->prepXferDlist( + nixl_status_t status = agent_->createXferReq( toNixlXferOp(op), local_dlist, remote_dlist, @@ -303,14 +349,14 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( impl->xfer_handle); NVF_ERROR( status == NIXL_SUCCESS, - "NIXL prepXferDlist failed with status ", + "NIXL createXferReq failed with status ", static_cast(status)); impl->prepared = true; handle.impl_ = std::move(impl); #else - (void)local_tensors; - (void)remote_tensors; + (void)local_descs; + (void)remote_descs; (void)remote_rank; (void)op; NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); @@ -375,6 +421,7 @@ void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + // TODO - check this spin loop NixlXferStatus xfer_status; do { xfer_status = getTransferStatus(handle); @@ -399,10 +446,12 @@ NixlBackend::NixlBackend() NixlBackend& NixlBackend::getInstance() { static auto* instance = new NixlBackend(); + NVF_CHECK(!instance->cleaned_up_, "NIXL backend has been cleaned up"); return *instance; } void NixlBackend::cleanup() { + cleaned_up_ = true; impl_.reset(); } @@ -411,37 +460,45 @@ bool NixlBackend::isAvailable() const { } void NixlBackend::registerTensors(const std::vector& tensors) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); 
impl_->registerTensors(tensors); } void NixlBackend::deregisterTensors(const std::vector& tensors) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->deregisterTensors(tensors); } void NixlBackend::exchangeMetadata() { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->exchangeMetadata(); } NixlTransferHandle NixlBackend::prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); return impl_->prepareTransfer( - local_tensors, remote_tensors, remote_rank, op); + local_descs, remote_descs, remote_rank, op); } void NixlBackend::postTransfer(NixlTransferHandle& handle) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->postTransfer(handle); } NixlXferStatus NixlBackend::getTransferStatus( const NixlTransferHandle& handle) const { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); return impl_->getTransferStatus(handle); } void NixlBackend::waitTransfer(NixlTransferHandle& handle) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->waitTransfer(handle); } -} // namespace nvfuser + +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index 3b1dedf7cb9..1cc5b84a7b8 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -11,7 +11,9 @@ #include #include #include +#include +#include "exceptions.h" #include "multidevice/communicator.h" #include "visibility.h" @@ -31,6 +33,93 @@ enum class NixlXferStatus { kError, }; +// ------------------------------------------------------------------ +// Todo - those functions should be moved to a more global file +// Helper functions for serializing and deserializing tensors descriptors for TCP store +struct TensorDesc { + uintptr_t addr; + size_t size; + uint32_t dev; +}; 
+static_assert(std::is_trivially_copyable_v, + "TensorDesc must be trivially copyable for serialization"); + +inline TensorDesc toTensorDesc(const at::Tensor& tensor) { + return { + .addr = reinterpret_cast(tensor.data_ptr()), + .size = static_cast(tensor.numel()) * tensor.element_size(), + .dev = static_cast(tensor.device().index()) + }; +} + +inline at::Tensor fromTensorDesc(const TensorDesc& desc) { + /* + Tensors must be valid on this device + */ + return at::from_blob( + reinterpret_cast(desc.addr), + {static_cast(desc.size)}, + at::TensorOptions().device(at::Device(at::kCUDA, desc.dev)).dtype(at::kByte) + ); +} + +inline std::vector serializeTensorsDescs( + const std::vector& descs) { + size_t count = descs.size(); + std::vector buf(sizeof(count) + count * sizeof(TensorDesc)); + std::memcpy(buf.data(), &count, sizeof(count)); + if (count == 0) + return buf; + + std::memcpy( + buf.data() + sizeof(count), + descs.data(), + descs.size() * sizeof(TensorDesc)); + return buf; +} + +inline std::vector deserializeTensorsDescs( + const std::vector& buf) { + NVF_ERROR(buf.size() >= sizeof(size_t), "Invalid serialized descriptor data"); + size_t count; + std::memcpy(&count, buf.data(), sizeof(count)); + NVF_ERROR( + buf.size() == sizeof(count) + count * sizeof(TensorDesc), + "Corrupted serialized descriptor data"); + + std::vector descs(count); + if (count > 0) { + std::memcpy( + descs.data(), + buf.data() + sizeof(count), + count * sizeof(TensorDesc)); + } + return descs; +} + +inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& descs) { + NVF_CHECK(communicator.is_available(), "Communicator is not available"); + communicator.getTcpStore()->set(key, serializeTensorsDescs(descs)); +} + +inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& tensors) { + std::vector descs; + descs.reserve(tensors.size()); + for (const auto& tensor : tensors) { + 
descs.push_back(toTensorDesc(tensor)); + } + storeTensorDescs(communicator, key, descs); +} + +inline std::vector fetchTensorDescs(Communicator& communicator, const std::string& key) { + NVF_CHECK(communicator.is_available(), "Communicator is not available"); + auto bytes = communicator.getTcpStore()->get(key); + return deserializeTensorsDescs(bytes); +} + +// End of Todo - those functions should be moved to a more global file +// ------------------------------------------------------------------ + // ------------------------------------------------------------------- // NixlTransferHandle: opaque handle for a prepared transfer // ------------------------------------------------------------------- @@ -49,7 +138,7 @@ class NVF_API NixlTransferHandle { NixlTransferHandle(const NixlTransferHandle&) = delete; NixlTransferHandle& operator=(const NixlTransferHandle&) = delete; - bool isValid() const; + [[nodiscard]] bool isValid() const; private: friend class NixlBackend; @@ -84,7 +173,7 @@ class NixlBackend { // exit (same pattern as Communicator::cleanup). void cleanup(); - bool isAvailable() const; + [[nodiscard]] bool isAvailable() const; // ------------------------------------------------------------------ // Memory registration @@ -114,9 +203,9 @@ class NixlBackend { // All tensors must be contiguous CUDA tensors and previously registered. // The returned handle can be posted multiple times (preparation is // amortized). - NixlTransferHandle prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + [[nodiscard]] NixlTransferHandle prepareTransfer( + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op); @@ -124,16 +213,20 @@ class NixlBackend { void postTransfer(NixlTransferHandle& handle); // Poll the status of a posted transfer without blocking. 
- NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; + [[nodiscard]] NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; // Block until the transfer completes (or errors out). void waitTransfer(NixlTransferHandle& handle); private: NixlBackend(); + bool cleaned_up_ = false; class Impl; std::unique_ptr impl_; }; + + + } // namespace nvfuser From 0f2152890df45889c6664f1a03cb4e7f0a12929c Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:44:15 +0200 Subject: [PATCH 06/15] add python build changes for nixl --- python/setup.py | 3 +++ python/utils.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index 9d340016d5d..aba47bee0f1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -32,6 +32,9 @@ # NVFUSER_BUILD_WITH_UCC # Build nvfuser with UCC support. You may need to specify environment variables of UCC_HOME, UCC_DIR, UCX_HOME, UCX_DIR. # +# NVFUSER_BUILD_WITH_NIXL +# Build nvfuser with NIXL support. You may need to set NIXL_PREFIX to the NIXL install directory. 
+# # NVFUSER_BUILD_WITHOUT_DISTRIBUTED # Build nvfuser without multidevice support # diff --git a/python/utils.py b/python/utils.py index 272d347c23e..303220ca867 100644 --- a/python/utils.py +++ b/python/utils.py @@ -22,6 +22,7 @@ class BuildConfig: no_benchmark: bool = False no_ninja: bool = False build_with_ucc: bool = False + build_with_nixl: bool = False build_with_asan: bool = False build_without_distributed: bool = False explicit_error_check: bool = False @@ -98,6 +99,12 @@ def parse_args(): action="store_true", help="Build nvfuser with UCC support", ) + parser.add_argument( + "--build-with-nixl", + dest="build_with_nixl", + action="store_true", + help="Build nvfuser with NIXL support", + ) parser.add_argument( "--explicit-error-check", dest="explicit_error_check", @@ -200,6 +207,7 @@ def create_build_config(): no_benchmark=args.no_benchmark, no_ninja=args.no_ninja, build_with_ucc=args.build_with_ucc, + build_with_nixl=args.build_with_nixl, build_with_asan=args.build_with_asan, build_without_distributed=args.build_without_distributed, explicit_error_check=args.explicit_error_check, @@ -245,6 +253,8 @@ def override_build_config_from_env(config): config.no_ninja = get_env_flag_bool("NVFUSER_BUILD_NO_NINJA") if "NVFUSER_BUILD_WITH_UCC" in os.environ: config.build_with_ucc = get_env_flag_bool("NVFUSER_BUILD_WITH_UCC") + if "NVFUSER_BUILD_WITH_NIXL" in os.environ: + config.build_with_nixl = get_env_flag_bool("NVFUSER_BUILD_WITH_NIXL") if "NVFUSER_BUILD_WITH_ASAN" in os.environ: config.build_with_asan = get_env_flag_bool("NVFUSER_BUILD_WITH_ASAN") if "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" in os.environ: @@ -442,7 +452,11 @@ def cmake(config, relative_path): logger_level = logger.getEffectiveLevel() logger.setLevel(logging.CRITICAL) - pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + get_pytorch_cmake_prefix() + cmake_prefix_path = get_pytorch_cmake_prefix() + llvm_dir = os.environ.get("LLVM_DIR") + if llvm_dir: + cmake_prefix_path += ";" + llvm_dir + 
pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path logger.setLevel(logger_level) @@ -469,6 +483,7 @@ def on_or_off(flag: bool) -> str: f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", f"-DNVFUSER_BUILD_WITH_ASAN={on_or_off(config.build_with_asan)}", f"-DNVFUSER_STANDALONE_BUILD_WITH_UCC={on_or_off(config.build_with_ucc)}", + f"-DNVFUSER_STANDALONE_BUILD_WITH_NIXL={on_or_off(config.build_with_nixl)}", f"-DNVFUSER_EXPLICIT_ERROR_CHECK={on_or_off(config.explicit_error_check)}", f"-DBUILD_TEST={on_or_off(not config.no_test)}", f"-DBUILD_PYTHON={on_or_off(not config.no_python)}", @@ -480,6 +495,25 @@ def on_or_off(flag: bool) -> str: "-B", cmake_build_dir, ] + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home: + cmd_str.append(f"-DCUDA_TOOLKIT_ROOT_DIR={cuda_home}") + nvcc_path = os.path.join(cuda_home, "bin", "nvcc") + if os.path.isfile(nvcc_path): + cmd_str.append(f"-DCMAKE_CUDA_COMPILER={nvcc_path}") + cudahostcxx = os.environ.get("CUDAHOSTCXX") + if cudahostcxx: + resolved = shutil.which(cudahostcxx) or cudahostcxx + cmd_str.append(f"-DCMAKE_CUDA_HOST_COMPILER={resolved}") + os.environ["CUDAHOSTCXX"] = resolved + cc = os.environ.get("CC") + if cc: + resolved = shutil.which(cc) or cc + cmd_str.append(f"-DCMAKE_C_COMPILER={resolved}") + cxx = os.environ.get("CXX") + if cxx: + resolved = shutil.which(cxx) or cxx + cmd_str.append(f"-DCMAKE_CXX_COMPILER={resolved}") if config.cutlass_max_jobs: cmd_str.append(f"-DCUTLASS_MAX_JOBS={config.cutlass_max_jobs}") if config.nvmmh_include_dir: From 6144827540126d54452e8fd71e15341a3beb6db1 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:44:56 +0200 Subject: [PATCH 07/15] fix typo --- python/utils.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/python/utils.py b/python/utils.py index 303220ca867..908433ec4cb 100644 --- a/python/utils.py +++ b/python/utils.py @@ -495,25 +495,6 @@ def on_or_off(flag: bool) -> str: "-B", cmake_build_dir, ] 
- cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - cmd_str.append(f"-DCUDA_TOOLKIT_ROOT_DIR={cuda_home}") - nvcc_path = os.path.join(cuda_home, "bin", "nvcc") - if os.path.isfile(nvcc_path): - cmd_str.append(f"-DCMAKE_CUDA_COMPILER={nvcc_path}") - cudahostcxx = os.environ.get("CUDAHOSTCXX") - if cudahostcxx: - resolved = shutil.which(cudahostcxx) or cudahostcxx - cmd_str.append(f"-DCMAKE_CUDA_HOST_COMPILER={resolved}") - os.environ["CUDAHOSTCXX"] = resolved - cc = os.environ.get("CC") - if cc: - resolved = shutil.which(cc) or cc - cmd_str.append(f"-DCMAKE_C_COMPILER={resolved}") - cxx = os.environ.get("CXX") - if cxx: - resolved = shutil.which(cxx) or cxx - cmd_str.append(f"-DCMAKE_CXX_COMPILER={resolved}") if config.cutlass_max_jobs: cmd_str.append(f"-DCUTLASS_MAX_JOBS={config.cutlass_max_jobs}") if config.nvmmh_include_dir: From b32587a19354c9b3f36b0048e9eaf8c0d1ba9f31 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:52:45 +0200 Subject: [PATCH 08/15] restore main: --- csrc/multidevice/cuda_p2p.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index 38ae6c549fc..514195c0746 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -10,10 +10,6 @@ #include #include -#include -#include -#include - #include "multidevice/ipc_handle.h" namespace nvfuser { From f8a94fcae21d909d4935306e4e22ff1e5af1d4a0 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 18:04:23 +0200 Subject: [PATCH 09/15] fix bug where zero-length buffer was passed to nixl --- csrc/multidevice/communicator.h | 2 +- csrc/multidevice/nixl.cpp | 44 +++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index c4a1eb3d09b..127276f6cb4 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -131,7 +131,7 @@ class 
NVF_API Communicator { c10d::TCPStore* getTcpStore() { return store_.get(); - } +} private: diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index fa0ac7c7a94..f71c70b7f4a 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -67,7 +68,7 @@ void validateCudaTensors(const std::vector& tensors) { #ifdef USE_NIXL nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { - nixl_reg_dlist_t dlist(VRAM_SEG, tensors.size()); + nixl_reg_dlist_t dlist(VRAM_SEG); for (const auto& t : tensors) { dlist.addDesc( {reinterpret_cast(t.data_ptr()), @@ -78,7 +79,7 @@ nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { } nixl_xfer_dlist_t buildXferDlist(const std::vector& descs) { - nixl_xfer_dlist_t dlist(VRAM_SEG, descs.size()); + nixl_xfer_dlist_t dlist(VRAM_SEG); for (const auto& desc : descs) { dlist.addDesc({desc.addr, desc.size, desc.dev}); } @@ -163,30 +164,47 @@ NixlBackend::Impl::Impl(Communicator& communicator) // local descriptor list for VRAM -- if no backend claims VRAM, the // probe fails and we mark the backend as unavailable. 
{ + constexpr int64_t kProbeBytes = 64; auto probe = at::empty( - {1}, + {kProbeBytes}, at::TensorOptions().dtype(at::kByte).device( at::kCUDA, communicator_.deviceId())); - nixl_reg_dlist_t reg_dlist(VRAM_SEG, 1); - reg_dlist.addDesc( - {reinterpret_cast(probe.data_ptr()), - probe.nbytes(), - static_cast(probe.device().index())}); + size_t nbytes = static_cast(probe.nbytes()); + uintptr_t addr = reinterpret_cast(probe.data_ptr()); + uint32_t dev_idx = static_cast(probe.device().index()); + + std::cerr << "[NixlBackend probe] device=" << dev_idx + << " addr=0x" << std::hex << addr << std::dec + << " nbytes=" << nbytes + << " numel=" << probe.numel() + << " element_size=" << probe.element_size() << std::endl; + + NVF_ERROR(nbytes > 0, "NIXL probe: unexpected zero-byte tensor"); + NVF_ERROR(addr != 0, "NIXL probe: null data pointer"); + + nixl_reg_dlist_t reg_dlist(VRAM_SEG); + reg_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); + + std::cerr << "[NixlBackend probe] reg_dlist desc: addr=0x" << std::hex + << reg_dlist[0].addr << std::dec + << " len=" << reg_dlist[0].len + << " devId=" << reg_dlist[0].devId << std::endl; nixl_status_t reg_status = agent_->registerMem(reg_dlist); + std::cerr << "[NixlBackend probe] registerMem returned " + << reg_status << std::endl; if (reg_status != NIXL_SUCCESS) { return; } - nixl_xfer_dlist_t xfer_dlist(VRAM_SEG, 1); - xfer_dlist.addDesc( - {reinterpret_cast(probe.data_ptr()), - probe.nbytes(), - static_cast(probe.device().index())}); + nixl_xfer_dlist_t xfer_dlist(VRAM_SEG); + xfer_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); nixlDlistH* dlist_handle = nullptr; nixl_status_t prep_status = agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); + std::cerr << "[NixlBackend probe] prepXferDlist returned " + << prep_status << std::endl; if (dlist_handle) { agent_->releasedDlistH(dlist_handle); From a6b6f870737b5e9348d75cf312dd2bab6edd69c1 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 
18:13:55 +0200 Subject: [PATCH 10/15] Reduce probe size to 1 --- csrc/multidevice/nixl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index f71c70b7f4a..37881137e40 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -164,7 +164,7 @@ NixlBackend::Impl::Impl(Communicator& communicator) // local descriptor list for VRAM -- if no backend claims VRAM, the // probe fails and we mark the backend as unavailable. { - constexpr int64_t kProbeBytes = 64; + constexpr int64_t kProbeBytes = 1; auto probe = at::empty( {kProbeBytes}, at::TensorOptions().dtype(at::kByte).device( From 95460af2cd1857ffa3e5a8dea7d3892c37dcdfbc Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 1 Mar 2026 16:09:47 +0200 Subject: [PATCH 11/15] Address PR comments. --- csrc/multidevice/communicator.cpp | 7 ++++++- csrc/multidevice/communicator.h | 7 ++----- csrc/multidevice/nixl.cpp | 26 +++++++++----------------- csrc/multidevice/nixl.h | 11 ----------- 4 files changed, 17 insertions(+), 34 deletions(-) diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index e4c1c1cc584..c021a129670 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -186,7 +186,8 @@ Communicator::Communicator( master_port_( c10d::TCPStoreOptions::kDefaultPort + 42), // to avoid collision ucc_available_(false), - nccl_available_(false) { + nccl_available_(false), + nixl_available_(false) { if (isOptionDisabled(DisableOption::Multidevice)) { TORCH_WARN( "Multi-device support is disabled. 
All communication operations will " @@ -239,6 +240,10 @@ Communicator::Communicator( #ifdef USE_C10D_NCCL nccl_available_ = true; #endif + +#ifdef USE_NIXL + nixl_available_ = true; +#endif } namespace { diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index 127276f6cb4..f54d0535434 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -120,11 +120,7 @@ class NVF_API Communicator { } else if (backend == CommunicatorBackend::kNccl) { return nccl_available_; } else if (backend == CommunicatorBackend::kNixl) { -#ifdef USE_NIXL - return true; -#else - return false; -#endif + return nixl_available_; } return false; } @@ -159,6 +155,7 @@ class NVF_API Communicator { int master_port_; bool ucc_available_; bool nccl_available_; + bool nixl_available_; // stores the world's store used for the backend init c10::intrusive_ptr store_; // cache for the created backends. The keys are strings generated from Teams diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 37881137e40..06d4939b5fd 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -26,8 +26,15 @@ namespace nvfuser { class NixlTransferHandleImpl { public: #ifdef USE_NIXL - // TODO - is it leaking when handleimpl is destroyed ? 
+ explicit NixlTransferHandleImpl(nixlAgent* agent) : agent(agent) {} + nixlAgent* agent; nixlXferReqH* xfer_handle = nullptr; + + ~NixlTransferHandleImpl() { + if (xfer_handle) { + agent->releaseXferReq(xfer_handle); + } + } #endif bool prepared = false; bool posted = false; @@ -173,26 +180,13 @@ NixlBackend::Impl::Impl(Communicator& communicator) uintptr_t addr = reinterpret_cast(probe.data_ptr()); uint32_t dev_idx = static_cast(probe.device().index()); - std::cerr << "[NixlBackend probe] device=" << dev_idx - << " addr=0x" << std::hex << addr << std::dec - << " nbytes=" << nbytes - << " numel=" << probe.numel() - << " element_size=" << probe.element_size() << std::endl; - NVF_ERROR(nbytes > 0, "NIXL probe: unexpected zero-byte tensor"); NVF_ERROR(addr != 0, "NIXL probe: null data pointer"); nixl_reg_dlist_t reg_dlist(VRAM_SEG); reg_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); - std::cerr << "[NixlBackend probe] reg_dlist desc: addr=0x" << std::hex - << reg_dlist[0].addr << std::dec - << " len=" << reg_dlist[0].len - << " devId=" << reg_dlist[0].devId << std::endl; - nixl_status_t reg_status = agent_->registerMem(reg_dlist); - std::cerr << "[NixlBackend probe] registerMem returned " - << reg_status << std::endl; if (reg_status != NIXL_SUCCESS) { return; } @@ -203,8 +197,6 @@ NixlBackend::Impl::Impl(Communicator& communicator) nixlDlistH* dlist_handle = nullptr; nixl_status_t prep_status = agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); - std::cerr << "[NixlBackend probe] prepXferDlist returned " - << prep_status << std::endl; if (dlist_handle) { agent_->releasedDlistH(dlist_handle); @@ -358,7 +350,7 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); - auto impl = std::make_unique(); + auto impl = std::make_unique(agent_.get()); nixl_status_t status = agent_->createXferReq( toNixlXferOp(op), local_dlist, diff 
--git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index 1cc5b84a7b8..e9de9c8384b 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -52,17 +52,6 @@ inline TensorDesc toTensorDesc(const at::Tensor& tensor) { }; } -inline at::Tensor fromTensorDesc(const TensorDesc& desc) { - /* - Tensors must be valid on this device - */ - return at::from_blob( - reinterpret_cast(desc.addr), - {static_cast(desc.size)}, - at::TensorOptions().device(at::Device(at::kCUDA, desc.dev)).dtype(at::kByte) - ); -} - inline std::vector serializeTensorsDescs( const std::vector& descs) { size_t count = descs.size(); From 41ec0ac51decd5d80056dc57a04d840746db888a Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 11:56:41 +0200 Subject: [PATCH 12/15] typos --- csrc/multidevice/communicator.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index f54d0535434..25fed9eebfc 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -11,7 +11,6 @@ #include #include -#include #include #ifdef NVFUSER_DISTRIBUTED @@ -129,7 +128,6 @@ class NVF_API Communicator { return store_.get(); } - private: Communicator( CommunicatorBackend backend = comm_backend_default, From d63ffd7cd02d800c930b0880c3abb7cedbd6655d Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 12:01:32 +0200 Subject: [PATCH 13/15] set getAgentName to inline --- csrc/multidevice/nixl.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 06d4939b5fd..6cbfd737121 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -135,7 +135,7 @@ class NixlBackend::Impl { void waitTransfer(NixlTransferHandle& handle); private: - std::string constructAgentName(int64_t rank); + inline std::string getAgentName(int64_t rank); #ifdef USE_NIXL std::unique_ptr agent_; @@ -153,7 +153,7 @@ class NixlBackend::Impl { 
NixlBackend::Impl::Impl(Communicator& communicator) : communicator_(communicator) { #ifdef USE_NIXL - std::string agent_name = constructAgentName(communicator_.deviceId()); + std::string agent_name = getAgentName(communicator_.deviceId()); nixlAgentConfig cfg(false); agent_ = std::make_unique(agent_name, cfg); @@ -214,7 +214,7 @@ NixlBackend::Impl::Impl(Communicator& communicator) NixlBackend::Impl::~Impl() = default; -std::string NixlBackend::Impl::constructAgentName(int64_t rank){ +std::string NixlBackend::Impl::getAgentName(int64_t rank){ return "rank_" + std::to_string(rank); } @@ -345,7 +345,7 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( " vs ", remote_descs.size()); - std::string remote_agent_name = constructAgentName(remote_rank); + std::string remote_agent_name = getAgentName(remote_rank); nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); From 86e50288cc42b754656fd90a80627341d1cf4066 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 13:41:15 +0200 Subject: [PATCH 14/15] fix comments in nixl.cpp --- csrc/multidevice/nixl.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 6cbfd737121..7a3b1b173d5 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -325,13 +325,9 @@ void NixlBackend::Impl::exchangeMetadata() { // kRead -- data flows from remote into local // kWrite -- data flows from local into remote // -// remote_tensors are LOCAL tensors whose data_ptr identifies the -// corresponding registration slot. The actual remote addresses are -// looked up from the descriptors exchanged during exchangeMetadata(). -// This requires all ranks to register tensors in the same order. 
NixlTransferHandle NixlBackend::Impl::prepareTransfer( const std::vector& local_descs, // Local addresses - const std::vector& remote_descs, // Remote tensors (not valid on this rank) + const std::vector& remote_descs, // Remote tensors (cannot be dereferenced on this rank) int64_t remote_rank, NixlXferOp op) { NixlTransferHandle handle; From 7283aa89823f4dedc969cf3e1a742fb25e0069d0 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 14:23:42 +0200 Subject: [PATCH 15/15] clean ifdef USE_NIXL statements --- csrc/multidevice/nixl.cpp | 118 +++++++++++++------------------------- 1 file changed, 41 insertions(+), 77 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 7a3b1b173d5..253298bc9bc 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -111,21 +111,19 @@ nixl_xfer_op_t toNixlXferOp(NixlXferOp op) { // NixlBackend::Impl // =================================================================== +#ifdef USE_NIXL + class NixlBackend::Impl { public: - explicit Impl(Communicator& communicator); + static std::unique_ptr create(Communicator& communicator); ~Impl(); - bool isAvailable() const { - return available_; - } - void registerTensors(const std::vector& tensors); void deregisterTensors(const std::vector& tensors); void exchangeMetadata(); NixlTransferHandle prepareTransfer( - const std::vector& local_descs, + const std::vector& local_descs, const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op); @@ -135,14 +133,12 @@ class NixlBackend::Impl { void waitTransfer(NixlTransferHandle& handle); private: + explicit Impl(Communicator& communicator); inline std::string getAgentName(int64_t rank); -#ifdef USE_NIXL std::unique_ptr agent_; nixlBackendH* backend_ = nullptr; -#endif Communicator& communicator_; - bool available_ = false; bool metadata_exchanged_ = false; }; @@ -151,16 +147,21 @@ class NixlBackend::Impl { // ------------------------------------------------------------------- 
NixlBackend::Impl::Impl(Communicator& communicator) - : communicator_(communicator) { -#ifdef USE_NIXL - std::string agent_name = getAgentName(communicator_.deviceId()); + : communicator_(communicator) {} + +std::unique_ptr NixlBackend::Impl::create( + Communicator& communicator) { + std::unique_ptr impl(new Impl(communicator)); + + std::string agent_name = impl->getAgentName(communicator.deviceId()); nixlAgentConfig cfg(false); - agent_ = std::make_unique(agent_name, cfg); + impl->agent_ = std::make_unique(agent_name, cfg); nixl_b_params_t params; - nixl_status_t status = agent_->createBackend("UCX", params, backend_); + nixl_status_t status = + impl->agent_->createBackend("UCX", params, impl->backend_); if (status != NIXL_SUCCESS) { - agent_.reset(); + impl->agent_.reset(); NVF_THROW("Failed to create UCX backend for NIXL agent"); } @@ -175,7 +176,7 @@ NixlBackend::Impl::Impl(Communicator& communicator) auto probe = at::empty( {kProbeBytes}, at::TensorOptions().dtype(at::kByte).device( - at::kCUDA, communicator_.deviceId())); + at::kCUDA, communicator.deviceId())); size_t nbytes = static_cast(probe.nbytes()); uintptr_t addr = reinterpret_cast(probe.data_ptr()); uint32_t dev_idx = static_cast(probe.device().index()); @@ -186,9 +187,9 @@ NixlBackend::Impl::Impl(Communicator& communicator) nixl_reg_dlist_t reg_dlist(VRAM_SEG); reg_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); - nixl_status_t reg_status = agent_->registerMem(reg_dlist); + nixl_status_t reg_status = impl->agent_->registerMem(reg_dlist); if (reg_status != NIXL_SUCCESS) { - return; + return nullptr; } nixl_xfer_dlist_t xfer_dlist(VRAM_SEG); @@ -196,25 +197,24 @@ NixlBackend::Impl::Impl(Communicator& communicator) nixlDlistH* dlist_handle = nullptr; nixl_status_t prep_status = - agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); + impl->agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); if (dlist_handle) { - agent_->releasedDlistH(dlist_handle); + 
impl->agent_->releasedDlistH(dlist_handle); } - agent_->deregisterMem(reg_dlist); + impl->agent_->deregisterMem(reg_dlist); if (prep_status != NIXL_SUCCESS) { - return; + return nullptr; } } - available_ = true; -#endif + return impl; } NixlBackend::Impl::~Impl() = default; -std::string NixlBackend::Impl::getAgentName(int64_t rank){ +std::string NixlBackend::Impl::getAgentName(int64_t rank) { return "rank_" + std::to_string(rank); } @@ -225,8 +225,6 @@ std::string NixlBackend::Impl::getAgentName(int64_t rank){ // TODO - consider adding RAII wrapper void NixlBackend::Impl::registerTensors( const std::vector& tensors) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); validateCudaTensors(tensors); nixl_reg_dlist_t dlist = buildRegDlist(tensors); @@ -237,16 +235,10 @@ void NixlBackend::Impl::registerTensors( static_cast(status)); metadata_exchanged_ = false; -#else - (void)tensors; - NVF_THROW("NIXL support not compiled"); -#endif } void NixlBackend::Impl::deregisterTensors( const std::vector& tensors) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); validateCudaTensors(tensors); nixl_reg_dlist_t dlist = buildRegDlist(tensors); @@ -257,10 +249,6 @@ void NixlBackend::Impl::deregisterTensors( static_cast(status)); metadata_exchanged_ = false; -#else - (void)tensors; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } // ------------------------------------------------------------------- @@ -268,9 +256,6 @@ void NixlBackend::Impl::deregisterTensors( // ------------------------------------------------------------------- void NixlBackend::Impl::exchangeMetadata() { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); - nixl_blob_t local_md; nixl_status_t md_status = agent_->getLocalMD(local_md); NVF_ERROR( @@ -291,7 +276,7 @@ void NixlBackend::Impl::exchangeMetadata() { if (rank == my_rank) { continue; } - // Fetch & load MD + // Fetch & load MD auto bytes = 
store->get(md_key_prefix + std::to_string(rank)); nixl_blob_t remote_md(bytes.begin(), bytes.end()); std::string remote_agent_name; @@ -309,9 +294,6 @@ void NixlBackend::Impl::exchangeMetadata() { store->deleteKey(md_key_prefix + std::to_string(my_rank)); metadata_exchanged_ = true; -#else - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } // ------------------------------------------------------------------- @@ -326,13 +308,10 @@ void NixlBackend::Impl::exchangeMetadata() { // kWrite -- data flows from local into remote // NixlTransferHandle NixlBackend::Impl::prepareTransfer( - const std::vector& local_descs, // Local addresses - const std::vector& remote_descs, // Remote tensors (cannot be dereferenced on this rank) + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op) { - NixlTransferHandle handle; -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); NVF_ERROR( local_descs.size() == remote_descs.size(), @@ -359,14 +338,8 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( static_cast(status)); impl->prepared = true; + NixlTransferHandle handle; handle.impl_ = std::move(impl); -#else - (void)local_descs; - (void)remote_descs; - (void)remote_rank; - (void)op; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif return handle; } @@ -375,8 +348,6 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( // ------------------------------------------------------------------- void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(handle.isValid(), "Cannot post an invalid transfer handle"); NVF_ERROR( !handle.impl_->posted, @@ -389,10 +360,6 @@ void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { static_cast(status)); handle.impl_->posted = true; -#else - 
(void)handle; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } // ------------------------------------------------------------------- @@ -401,8 +368,6 @@ void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { NixlXferStatus NixlBackend::Impl::getTransferStatus( const NixlTransferHandle& handle) const { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(handle.isValid(), "Cannot query status of an invalid handle"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); @@ -415,15 +380,9 @@ NixlXferStatus NixlBackend::Impl::getTransferStatus( default: return NixlXferStatus::kError; } -#else - (void)handle; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); @@ -437,18 +396,23 @@ void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { } while (xfer_status == NixlXferStatus::kInProgress); handle.impl_->posted = false; -#else - (void)handle; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } +#else // !USE_NIXL + +class NixlBackend::Impl {}; + +#endif // USE_NIXL + // =================================================================== // NixlBackend singleton + public API // =================================================================== -NixlBackend::NixlBackend() - : impl_(std::make_unique(Communicator::getInstance())) {} +NixlBackend::NixlBackend() { +#ifdef USE_NIXL + impl_ = Impl::create(Communicator::getInstance()); +#endif +} NixlBackend& NixlBackend::getInstance() { static auto* instance = new NixlBackend(); @@ -462,7 +426,7 @@ void NixlBackend::cleanup() { } bool NixlBackend::isAvailable() const { - return impl_ && impl_->isAvailable(); + return 
impl_ != nullptr; } void NixlBackend::registerTensors(const std::vector& tensors) {