From d1dd491e3329f48cc84db2184ab2a69ab7fe563a Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Tue, 10 Feb 2026 16:48:40 -0800 Subject: [PATCH 01/10] test case and python bindings --- python/python_direct/ir.cpp | 21 +++++++++++++++ .../python/multidevice/test_communication.py | 26 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/python/python_direct/ir.cpp b/python/python_direct/ir.cpp index 93c032e4ef6..c149663da38 100644 --- a/python/python_direct/ir.cpp +++ b/python/python_direct/ir.cpp @@ -501,6 +501,27 @@ Returns ------- TensorView A TensorView with the swizzled axes in its loop domain. +)") + .def( + "swizzle1d", + [](TensorView* self, int64_t x, ParallelType parallel_type) { + return self->swizzle1d(x, parallel_type); + }, + py::return_value_policy::reference, + py::arg("x"), + py::arg("parallel_type"), + R"( +Swizzle the specified axis with the device index corresponding to the given parallel type. +Parameters +---------- +x : int +The axis to swizzle. +parallel_type : ParallelType +The device parallel type for the 1D swizzle. +Returns +------- +TensorView +A TensorView with the swizzled axis in its loop domain. )") .def( "rfactor", diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 833ab511ff3..cd2c8f9e550 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -171,3 +171,29 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) + + +def test_collective_permute(multidevice_test): + d = multidevice_test.size + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Float) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + inp_tv.set_device_mesh(mesh) + inp_tv.outer_split(0, d) + inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + out_tv.set_device_mesh(mesh) + out_tv.outer_split(0, d) + out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x) + out_tv.axis(0).parallelize(nvfuser.ParallelType.Stream) + + inp_ref = torch.randn(d * 3) + inp = multidevice_test.shard_tensor(inp_ref, inp_tv) + with torch.profiler.profile() as prof: + (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"]) + print(prof.key_averages()) + torch.testing.assert_close(out.cpu(), inp_ref) From 363a0f458b458c27b7ab920c7a6a4c54c7e8baed Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 11 Feb 2026 14:05:17 -0800 Subject: [PATCH 02/10] add CollectivePermute wip --- csrc/dispatch.h | 1 + csrc/host_ir/evaluator.cpp | 32 ++++++ csrc/host_ir/evaluator.h | 1 + csrc/host_ir/lower_to_communication.cpp | 47 +++++++- csrc/host_ir/lowering.cpp | 24 +++-- csrc/multidevice/communication.cpp | 100 ++++++++++++++++++ csrc/multidevice/communication.h | 66 ++++++++++++ csrc/multidevice/utils.cpp | 19 ++++ csrc/multidevice/utils.h | 6 ++ .../python/multidevice/test_communication.py | 2 +- 10 files changed, 282 insertions(+), 16 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 7c53c86d903..6ef541a8037 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -125,6 +125,7 @@ class Val; f(SdpaFwdOp); \ f(SdpaBwdOp); \ f(EmbeddingFwdOp); \ + f(CollectivePermute); \ f(Communication); \ f(P2PCommunication); #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \ diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 2396767d5b0..e299bc9649d 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -310,6 +310,38 @@ void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) { ipc_handle_cache_.exchangeHandles(share_mem_handles->communications()); } +void HostIrEvaluator::handle(CollectivePermute* communication) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + at::Tensor input_tensor = getKnownTensorOrUndefined(communication->input(0)); + at::Tensor output_tensor = + getKnownTensorOrUndefined(communication->output(0)); + +#ifndef NDEBUG + validateSizesAndStrides( + {input_tensor, output_tensor}, + {communication->in(), communication->out()}, + expr_evaluator_); +#endif + + CommunicatorBackend backend_type = communication->backend(); + // CollectivePermute is only supported with NCCL backend because + // UCC does not support coalescing. + NVF_CHECK_EQ(backend_type, CommunicatorBackend::kNccl); + c10d::Backend* backend = + communicator_->getBackendForTeam(communication->team(), backend_type); + works_[communication] = postSingleCommunication( + communication, + communicator_->deviceId(), + backend, + input_tensor, + output_tensor, + expr_evaluator_.evaluate(communication->sendPeer()).as(), + expr_evaluator_.evaluate(communication->recvPeer()).as()); +} + void HostIrEvaluator::handle(Communication* communication) { NVF_ERROR( communicator_ != nullptr && communicator_->is_available(), diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index 22833156cab..9acbb03750a 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -96,6 +96,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(Synchronize*) override; void handle(PostOnStream*) override; void handle(LaunchKernel*) override; + void handle(CollectivePermute*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; void handle(Wait*) override; diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 6bd394f8b7f..a703bea6f91 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -345,6 +345,33 @@ void lowerToAllToAll( backend)); } +void lowerToCollectivePermute( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms, + Val* root, + DeviceIdxType my_device_idx) { + NVF_ERROR_EQ( + input_tv->getDeviceMesh(), + output_tv->getDeviceMesh(), + "CollectivePermute sender and receiver meshes must be the same. Given ", + input_tv->getDeviceMesh(), + " and ", + output_tv->getDeviceMesh()); + + IterDomain* stream_id = + getShardedIterDomain(output_tv, ParallelType::Stream, DomainType::kLoop); + Swizzle1D* swizzle = stream_id->definition()->as(); + ParallelType pt = swizzle->parallelType(); + + const auto& [send_peer, recv_peer] = + dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh()); + Team team = input_tv->getDeviceMesh().vector(); + comms.push_back(IrBuilder::create( + output_tv, input_tv, team, send_peer, recv_peer, backend)); +} + IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) { std::unordered_set logical_ids = getInputsInTargetDomain({loop_id}, tv->getLogicalDomain()); @@ -422,7 +449,6 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // Check if we are going from DIDx -> Stream, which is a ring allgather. // This can be executed as a broadcast or send recvs, which is decided // by the presence of a swizzle in the stream id definition. - // TODO: Lower to SendRecv if swizzle is present. if (c_stream_id != nullptr) { IterDomain* c_stream_logical_id = getLogicalFromLoopId(consumer, c_stream_id); @@ -431,10 +457,11 @@ CommunicationInfo getCommunicationInfo(Expr* e) { same_mesh, "Broadcast based allgather in stream parallel requires same " "mesh."); - fill_communication_info( - CommunicationType::StreamBroadcast, - p_logical_id, - c_stream_logical_id); + auto* swizzle = dynamic_cast(c_stream_id->definition()); + CommunicationType type = swizzle != nullptr + ? CommunicationType::CollectivePermute + : CommunicationType::StreamBroadcast; + fill_communication_info(type, p_logical_id, c_stream_logical_id); continue; } } @@ -525,6 +552,7 @@ Layout getCommunicationLayout( type == CommunicationType::Allreduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv || + type == CommunicationType::CollectivePermute || type == CommunicationType::AllToAll || type == CommunicationType::StreamBroadcast) { return layout; @@ -660,6 +688,15 @@ std::vector convertSingleOpToCommunication( "StreamBroadcast requires a root value passed in through lowering"); lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root); break; + case CommunicationType::CollectivePermute: + // FIXME: Rename this to host loop index. Collective Permute has no root. + // The send and recv peer indices are computed using the host loop index. + NVF_ERROR( + root != nullptr, + "CollectivePermute requires a root value passed in through lowering"); + lowerToCollectivePermute( + input_tv, output_tv, backend, comms, root, my_device_idx); + break; } return comms; diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 01f271ea636..0f36d4f65db 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -181,13 +181,19 @@ void lowerSegment( for (Expr* c : convertSingleOpToCommunication( e, device_id, innermost.loop->index())) { NVF_ERROR( - c->isA(), - "Exprs in a Communication group should be Communication: ", + c->isA() || c->isA(), + "Exprs in a Communication group should be Communication or " + "CollectivePermute: ", c); - auto* communication = c->as(); - TensorView* in = communication->in(); - TensorView* out = communication->out(); - if (communication->type() != CommunicationType::StreamBroadcast && + TensorView* in = c->input(0)->as(); + TensorView* out = c->output(0)->as(); + bool can_shard_in = true; + if (c->isA() || + c->as()->type() == + CommunicationType::StreamBroadcast) { + can_shard_in = false; + } + if (can_shard_in && haveDifferentShardings( in, DomainType::kAllocation, @@ -196,8 +202,7 @@ void lowerSegment( {ParallelType::Stream})) { Val*& sharded_in = replacement_map[in]; if (sharded_in == nullptr) { - sharded_in = - hir::shardByStream(in, innermost.loop->index(), communication); + sharded_in = hir::shardByStream(in, innermost.loop->index(), c); innermost_scope.pushBack(sharded_in->definition()); } } @@ -213,8 +218,7 @@ void lowerSegment( innermost.parent_scope->insert( innermost.parent_insertion_point, allocate); auto [i, inserted] = replacement_map.emplace( - out, - hir::shardByStream(out, innermost.loop->index(), communication)); + out, hir::shardByStream(out, innermost.loop->index(), c)); NVF_ERROR(inserted, "The input segmented fusion should be SSA."); innermost_scope.pushBack(i->second->definition()); } else { diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 21bc4994628..5b432f8876d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -60,6 +60,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { case CommunicationType::StreamBroadcast: os << "StreamBroadcast"; break; + case CommunicationType::CollectivePermute: + os << "CollectivePermute"; + break; } return os; } @@ -158,6 +161,7 @@ bool hasRoot(CommunicationType type) { case CommunicationType::Allreduce: case CommunicationType::ReduceScatter: case CommunicationType::AllToAll: + case CommunicationType::CollectivePermute: return false; } std::unreachable(); @@ -176,6 +180,7 @@ bool isReduction(CommunicationType type) { case CommunicationType::SendRecv: case CommunicationType::AllToAll: case CommunicationType::StreamBroadcast: + case CommunicationType::CollectivePermute: return false; default: NVF_THROW("unrecognized CommunicationType: ", type); @@ -326,6 +331,47 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } +CollectivePermute::CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend) + : Expr(passkey) { + NVF_ERROR( + in->getDeviceMesh().size() > 0, + "The input mesh size must be greater than 0."); + NVF_ERROR( + out->getDeviceMesh().size() > 0, + "The output mesh size must be greater than 0."); + addInput(in); + addInput(send_peer); + addInput(recv_peer); + addOutput(out); + addDataAttribute(CommunicationType::CollectivePermute); + addDataAttribute(team); + addDataAttribute(backend); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(CollectivePermute) + +std::string CollectivePermute::toInlineString(const int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "CollectivePermute " << name() << " (" + << "team=(" << team() << ")" + << ", send_peer=" << sendPeer()->toInlineString() + << ", recv_peer=" << recvPeer()->toInlineString() + << ", input=" << in() << ", output=" << out() + << ", backend=" << backend() << ")"; + return ss.str(); +} + +std::string CollectivePermute::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + namespace { c10::intrusive_ptr postBroadcast( Communication* communication, @@ -650,6 +696,28 @@ c10::intrusive_ptr postAllToAll( empty_split_sizes, /*options=*/{}); } + +c10::intrusive_ptr postCollectivePermute( + CollectivePermute* communication, + DeviceIdxType my_device_index, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { + backend->startCoalescing(); + std::vector send_tensors = {input_tensor}; + backend->send( + send_tensors, + send_peer_index, + /*tag=*/0); + std::vector recv_tensors = {output_tensor}; + backend->recv( + recv_tensors, + recv_peer_index, + /*tag=*/0); + return backend->endCoalescing(); +} } // namespace c10::intrusive_ptr postSingleCommunication( @@ -746,6 +814,38 @@ c10::intrusive_ptr postSingleCommunication( } } +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index) { + const Team& team = communication->team(); + if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { + return nullptr; + } + NVF_CHECK(backend != nullptr); + + if (isDebugDumpEnabled(DebugDumpOption::Communication)) { + debug() << "Posting " << communication->toInlineString() + << " with input_tensor " << input_tensor.sizes() + << " and output_tensor " << output_tensor.sizes() + << " send_peer=" << send_peer_index + << " recv_peer=" << recv_peer_index << std::endl; + } + + return postCollectivePermute( + communication, + my_device_index, + send_peer_index, + recv_peer_index, + backend, + input_tensor, + output_tensor); +} + namespace { c10::intrusive_ptr postSend( diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index ab864c3a44a..d1bfa6d6cb6 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -35,6 +35,7 @@ enum class CommunicationType { SendRecv, AllToAll, StreamBroadcast, + CollectivePermute, }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); @@ -130,6 +131,62 @@ class Communication : public Expr { void validate(); }; +// CollectivePermute: send to send_peer, recv from recv_peer. Separate from +// Communication (no root, no reduce op). Layout: inputs [in, send_peer, +// recv_peer], output [out], attributes [type, team, backend]. +class CollectivePermute : public Expr { + public: + using Expr::Expr; + + CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + CollectivePermute(const CollectivePermute& other) = delete; + CollectivePermute& operator=(const CollectivePermute& other) = delete; + CollectivePermute(CollectivePermute&& other) = delete; + CollectivePermute& operator=(CollectivePermute&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "CollectivePermute"; + } + + CommunicationType type() const { + return attribute(0); + } + + TensorView* in() const { + return input(0)->as(); + } + TensorView* out() const { + return output(0)->as(); + } + Val* sendPeer() const { + return input(1); + } + Val* recvPeer() const { + return input(2); + } + const Team& team() const { + return attribute(1); + } + int64_t team_size() const { + return static_cast(team().size()); + } + CommunicatorBackend backend() const { + return attribute(2); + } +}; + enum class P2PCommunicationType { SEND, RECV }; std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type); @@ -246,6 +303,15 @@ c10::intrusive_ptr postSingleCommunication( at::Tensor output_tensor, DeviceIdxType root_index = -1); +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index); + c10::intrusive_ptr postSingleCommunication( P2PCommunication* communication, DeviceIdxType my_device_index, diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index beb7283c5a1..1b08456d753 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -14,8 +14,10 @@ #include #include "compute_at_map.h" +#include "ir/builder.h" #include "ir/internal_base_nodes.h" #include "ir/internal_nodes.h" +#include "ops/arith.h" #include "transform_replay.h" #include "type.h" @@ -355,4 +357,21 @@ int64_t getRFactorDeviceDimensionIndex(const TensorView* tv) { return rfactor_did_idx; } +std::pair dispatchSwizzle1D( + Val* host_loop_index, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh) { + int64_t team_size = mesh.size(pt); + at::Tensor md_index = mesh.multiDimensionalIndexOf(device_id); + auto pt_axis = mesh.parallelTypeToAxis(pt); + int64_t team_index = md_index[pt_axis].item(); + Val* team_size_val = IrBuilder::create(team_size, DataType::Index); + Val* team_index_val = IrBuilder::create(team_index, DataType::Index); + return std::make_pair( + mod(add(host_loop_index, team_index_val), team_size_val), + mod(add(team_size_val, sub(team_index_val, host_loop_index)), + team_size_val)); +} + } // namespace nvfuser diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index e924e7fcc75..bad7730fee1 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -91,4 +91,10 @@ bool isValidDeviceSplit(Expr* expr); // See tests/python/test_multidevice.py/test_matmul_allreduce_loop_split int64_t getRFactorDeviceDimensionIndex(const TensorView* tv); +std::pair dispatchSwizzle1D( + Val* my_rank, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh); + } // namespace nvfuser diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index cd2c8f9e550..2d4676823f3 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -189,7 +189,7 @@ def test_collective_permute(multidevice_test): out_tv.set_device_mesh(mesh) out_tv.outer_split(0, d) out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x) - out_tv.axis(0).parallelize(nvfuser.ParallelType.Stream) + out_tv.axis(0).parallelize(nvfuser.ParallelType.stream) inp_ref = torch.randn(d * 3) inp = multidevice_test.shard_tensor(inp_ref, inp_tv) From 4aa8346df7fce0c61a508af21bcd289705be41f3 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 11 Feb 2026 22:44:06 -0800 Subject: [PATCH 03/10] replay swizzle1d in logical to alloc traversal --- csrc/host_ir/ir.cpp | 10 +++-- csrc/host_ir/lower_to_communication.cpp | 2 +- csrc/host_ir/lowering.cpp | 45 ++++++++++++++++--- csrc/multidevice/communication.cpp | 8 +++- csrc/tensor_metadata.cpp | 26 +++++++++++ .../python/multidevice/test_communication.py | 11 ++++- 6 files changed, 88 insertions(+), 14 deletions(-) diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp index 198601355fb..bc16a39931d 100644 --- a/csrc/host_ir/ir.cpp +++ b/csrc/host_ir/ir.cpp @@ -257,9 +257,13 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr) this, "must be registered in a HostIrContainer"); NVF_ERROR( - (expr->isOneOf()), - expr, - " must be a Communication, a P2PCommunication, or a EndCoalescing"); + (expr->isOneOf< + Communication, + CollectivePermute, + P2PCommunication, + EndCoalescing>()), + "Got: ", + expr); } NVFUSER_DEFINE_CLONE_AND_CREATE(Wait) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index a703bea6f91..ae12472cce4 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -365,7 +365,7 @@ void lowerToCollectivePermute( Swizzle1D* swizzle = stream_id->definition()->as(); ParallelType pt = swizzle->parallelType(); - const auto& [send_peer, recv_peer] = + const auto& [recv_peer, send_peer] = dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh()); Team team = input_tv->getDeviceMesh().vector(); comms.push_back(IrBuilder::create( diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 0f36d4f65db..ea803151b15 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -12,8 +12,10 @@ #include "host_ir/ir.h" #include "host_ir/lower_to_communication.h" #include "host_ir/ops.h" +#include "ir/builder.h" #include "ir/iostream.h" #include "ir/utils.h" +#include "kernel_ir.h" #include "multidevice/propagation.h" #include "multidevice/resharding.h" #include "multidevice/utils.h" @@ -148,6 +150,24 @@ Expr* cloneWithNewOperands( return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes()); } +// If all allocation domain extents of tv are constant, returns a new constant +// Val for the total size. Otherwise returns nullptr. Using a constant size +// makes the Allocate independent of the loop index so it is not invalidated +// when the index changes in the evaluator. +Val* getConstantAllocationSizeIfAvailable(TensorView* tv) { + const auto* domain = tv->domain(); + int64_t size = 1; + for (IterDomain* axis : + domain->maybeAllocation() | TensorDomain::kNoReductions) { + Val* extent = axis->extent(); + if (!extent->isConst()) { + return nullptr; + } + size *= extent->evaluate().as(); + } + return IrBuilder::create(size, DataType::Index); +} + void lowerSegment( const SegmentedGroup& group, const AliasInfoMap& aliases, @@ -207,9 +227,14 @@ void lowerSegment( } } - // Allocate the recv buffers of communications - auto* allocate = - IrBuilder::create(out, out->getMemoryType()); + // Allocate the recv buffers of communications. Use a constant size + // when all extents are constant so the Allocate is independent of the + // loop index and not invalidated when it changes. + Val* constant_size = getConstantAllocationSizeIfAvailable(out); + auto* allocate = constant_size != nullptr + ? IrBuilder::create( + out, out->getMemoryType(), constant_size) + : IrBuilder::create(out, out->getMemoryType()); if (getShardedIterDomain( out, ParallelType::Stream, DomainType::kLoop) != nullptr && getShardedIterDomain( @@ -314,8 +339,11 @@ void lowerSegment( if (getShardedIterDomain( out, ParallelType::Stream, DomainType::kAllocation) == nullptr) { - auto* allocate = - IrBuilder::create(out, out->getMemoryType()); + Val* constant_size = getConstantAllocationSizeIfAvailable(out); + auto* allocate = constant_size != nullptr + ? IrBuilder::create( + out, out->getMemoryType(), constant_size) + : IrBuilder::create(out, out->getMemoryType()); innermost.parent_scope->insert( innermost.parent_insertion_point, allocate); // Loop is stream parallelized but allocation is not. Therefore, @@ -351,8 +379,11 @@ void lowerSegment( " must not be an alias, got ", alias); - auto* allocate = - IrBuilder::create(out_tv, out_tv->getMemoryType()); + Val* constant_size = getConstantAllocationSizeIfAvailable(out_tv); + auto* allocate = constant_size != nullptr + ? IrBuilder::create( + out_tv, out_tv->getMemoryType(), constant_size) + : IrBuilder::create(out_tv, out_tv->getMemoryType()); innermost_scope.pushBack(allocate); } diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 33fa785157a..78521149cab 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -842,6 +842,11 @@ c10::intrusive_ptr postCollectivePermute( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { + if (my_device_index == send_peer_index && + my_device_index == recv_peer_index) { + doLocalCopy(output_tensor, input_tensor); + return nullptr; + } backend->startCoalescing(); std::vector send_tensors = {input_tensor}; backend->send( @@ -965,7 +970,8 @@ c10::intrusive_ptr postSingleCommunication( } NVF_CHECK(backend != nullptr); - if (isDebugDumpEnabled(DebugDumpOption::Communication)) { + if (isDebugDumpEnabled(DebugDumpOption::Communication) && + my_device_index == 0) { debug() << "Posting " << communication->toInlineString() << " with input_tensor " << input_tensor.sizes() << " and output_tensor " << output_tensor.sizes() diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 676e27e8805..e66b3b8e793 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -95,11 +95,24 @@ class ForwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto in_it = active_ids_.find(in); + auto [in_size, in_stride] = in_it->second; + NVF_ERROR(active_ids_.erase(in) == 1); + NVF_ERROR( + active_ids_.emplace(out, std::make_pair(in_size, in_stride)).second); + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } @@ -190,11 +203,24 @@ class BackwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto out_it = active_ids_.find(out); + auto [out_size, out_stride] = out_it->second; + NVF_ERROR(active_ids_.erase(out) == 1); + NVF_ERROR( + active_ids_.emplace(in, std::make_pair(out_size, out_stride)).second); + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 2d4676823f3..fa5b88178a6 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -195,5 +195,12 @@ def test_collective_permute(multidevice_test): inp = multidevice_test.shard_tensor(inp_ref, inp_tv) with torch.profiler.profile() as prof: (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"]) - print(prof.key_averages()) - torch.testing.assert_close(out.cpu(), inp_ref) + if multidevice_test.rank == 0: + print("\nOriginal input: ", inp_ref) + + print("\nOutput: ", out) + # torch.testing.assert_close(out.cpu(), inp_ref) + # collective_permute_events = [ + # event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name + # ] + # assert len(collective_permute_events) == (d - 1) From 30cd1f25529470e812a5cebe80bbfe271ab200d5 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 12 Feb 2026 15:26:44 -0800 Subject: [PATCH 04/10] working allgather example --- csrc/host_ir/evaluator.cpp | 95 +++++++++++++++++++ csrc/host_ir/lowering.cpp | 43 ++------- .../python/multidevice/test_communication.py | 14 +-- 3 files changed, 107 insertions(+), 45 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index bedf802b41d..99851e21db4 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -564,6 +564,101 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) { auto stop = expr_evaluator_.evaluate(for_loop->stop()).as(); for (auto i = start; i < stop; i++) { + // This is not ideal. In lowering, we create communication expr. + // The collective permute has the output tensorview, and input vals of + // send_peer and recv_peer. While the definition of output_tv is not + // modified and remains `set`, this output_tv is a use of the vals Even + // though we shardByStream, the use of vals is not modified and has a + // dependency on T1. Cloned e: T1_g_float[istreamIdx6{1}, iS5{3}] + // (DeviceMesh{0}) + // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // cache_op=Streaming ) + + // c: CollectivePermute 77 (team=(0), send_peer=( ( 1 + ( 0 - i140 ) ) % 1 + // ), recv_peer=( ( i140 + 0 ) % 1 ), input=T0_g_float[ideviceIdx.x2{1}, + // iS3{3}] (DeviceMesh{0}), output=T1_g_float[istreamIdx6{1}, iS5{3}] + // (DeviceMesh{0}), backend=NCCL) + + // %HostIrContainer { (T0_g_float[ideviceIdx.x2{1}, iS3{3}] + // (DeviceMesh{0})) -> (T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})) + // : + // T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) = + // ALLOCATE(buffer=T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // mem_type=global, size=3, zero_init=false, resets_to_zero=false) Stream + // 0x281a6c60 = GetCurrentStream() FOR i140 from 0 to 1: + // SetCurrentStream(Stream i140) + // Synchronize(Stream 0x281a6c60) + // T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) CollectivePermute 82 (team=(0), send_peer=( ( 1 + + // ( 0 - i140 ) ) % 1 ), recv_peer=( ( i140 + 0 ) % 1 ), + // input=T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // output=T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}), + // backend=NCCL) Wait(Communication 82) + // SetCurrentStream(Stream 0x281a6c60) + // FOR i140 from 0 to 1: + // Synchronize(Stream i140) + // } // %HostIrContainer + + // Invalidating index: i140 + // allConsumerValsOf(i140) + // Visited val: i140 + // Consumer of i140: i163 definition: i163 = 0 - i140; + + // Visited val: i163 + // Consumer of i163: i165 definition: i165 = 1 + i163; + + // Visited val: i165 + // Consumer of i165: i167 definition: i167 = i165 % 1; + + // Visited val: i167 + // Consumer of i167: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // cache_op=Streaming ) + + // Visited val: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // Consumer of i167: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) + + // Visited val: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // Consumer of i140: i169 definition: i169 = i140 + 0; + + // Visited val: i169 + // Consumer of i169: i171 definition: i171 = i169 % 1; + + // Visited val: i171 + // Consumer of i171: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // cache_op=Streaming ) + + // Consumer of i171: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) + + // Consumer of i140: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) + + // consumer_vals: 8 + // Invalidating consumer: i169 + // Invalidating consumer: T2_l_float[istreamIdx10{1}, iS9{3}] + // (DeviceMesh{0}) Invalidating consumer: T1_g_float[istreamIdx6{1}, + // iS5{3}] (DeviceMesh{0}) Invalidating consumer: i167 Invalidating + // consumer: i165 Invalidating consumer: i163 Invalidating consumer: i171 + // Invalidating consumer: i140 + + expr_evaluator_.invalidate(for_loop->index()); + for (auto consumer : allConsumerValsOf(for_loop->index())) { + if (!consumer->isA()) { + expr_evaluator_.invalidate(consumer); + } + } expr_evaluator_.bind(for_loop->index(), i); for (Expr* e : for_loop->body().exprs()) { dispatch(e); diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index ea803151b15..7c6a18865b1 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -150,24 +150,6 @@ Expr* cloneWithNewOperands( return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes()); } -// If all allocation domain extents of tv are constant, returns a new constant -// Val for the total size. Otherwise returns nullptr. Using a constant size -// makes the Allocate independent of the loop index so it is not invalidated -// when the index changes in the evaluator. -Val* getConstantAllocationSizeIfAvailable(TensorView* tv) { - const auto* domain = tv->domain(); - int64_t size = 1; - for (IterDomain* axis : - domain->maybeAllocation() | TensorDomain::kNoReductions) { - Val* extent = axis->extent(); - if (!extent->isConst()) { - return nullptr; - } - size *= extent->evaluate().as(); - } - return IrBuilder::create(size, DataType::Index); -} - void lowerSegment( const SegmentedGroup& group, const AliasInfoMap& aliases, @@ -194,6 +176,7 @@ void lowerSegment( // If a value is already cloned, IrCloner::clone returns the cloned value // without cloning the value again. Expr* e = ir_cloner.clone(group.exprs().front()); + debug() << "Cloned e: " << e << std::endl; // TODO: `replacement_map` should be associated with the scope so // ShardByStream across segments in the same for-loop can be reused. @@ -227,14 +210,8 @@ void lowerSegment( } } - // Allocate the recv buffers of communications. Use a constant size - // when all extents are constant so the Allocate is independent of the - // loop index and not invalidated when it changes. - Val* constant_size = getConstantAllocationSizeIfAvailable(out); - auto* allocate = constant_size != nullptr - ? IrBuilder::create( - out, out->getMemoryType(), constant_size) - : IrBuilder::create(out, out->getMemoryType()); + auto* allocate = + IrBuilder::create(out, out->getMemoryType()); if (getShardedIterDomain( out, ParallelType::Stream, DomainType::kLoop) != nullptr && getShardedIterDomain( @@ -339,11 +316,8 @@ void lowerSegment( if (getShardedIterDomain( out, ParallelType::Stream, DomainType::kAllocation) == nullptr) { - Val* constant_size = getConstantAllocationSizeIfAvailable(out); - auto* allocate = constant_size != nullptr - ? IrBuilder::create( - out, out->getMemoryType(), constant_size) - : IrBuilder::create(out, out->getMemoryType()); + auto* allocate = + IrBuilder::create(out, out->getMemoryType()); innermost.parent_scope->insert( innermost.parent_insertion_point, allocate); // Loop is stream parallelized but allocation is not. Therefore, @@ -379,11 +353,8 @@ void lowerSegment( " must not be an alias, got ", alias); - Val* constant_size = getConstantAllocationSizeIfAvailable(out_tv); - auto* allocate = constant_size != nullptr - ? IrBuilder::create( - out_tv, out_tv->getMemoryType(), constant_size) - : IrBuilder::create(out_tv, out_tv->getMemoryType()); + auto* allocate = + IrBuilder::create(out_tv, out_tv->getMemoryType()); innermost_scope.pushBack(allocate); } diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index fa5b88178a6..be7d28aef43 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -195,12 +195,8 @@ def test_collective_permute(multidevice_test): inp = multidevice_test.shard_tensor(inp_ref, inp_tv) with torch.profiler.profile() as prof: (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"]) - if multidevice_test.rank == 0: - print("\nOriginal input: ", inp_ref) - - print("\nOutput: ", out) - # torch.testing.assert_close(out.cpu(), inp_ref) - # collective_permute_events = [ - # event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name - # ] - # assert len(collective_permute_events) == (d - 1) + torch.testing.assert_close(out.cpu(), inp_ref) + collective_permute_events = [ + event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name + ] + assert len(collective_permute_events) == (d - 1) From e5ed38f0da92bb969d4e7a844ecb9df312fa77de Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 20 Feb 2026 22:26:47 -0800 Subject: [PATCH 05/10] add send and recv peer definition chains --- csrc/host_ir/evaluator.cpp | 98 -------------------------------------- csrc/host_ir/lowering.cpp | 13 ++++- 2 files changed, 12 insertions(+), 99 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 99851e21db4..9b2ddfb6373 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -564,101 +564,6 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) { auto stop = expr_evaluator_.evaluate(for_loop->stop()).as(); for (auto i = start; i < stop; i++) { - // This is not ideal. In lowering, we create communication expr. - // The collective permute has the output tensorview, and input vals of - // send_peer and recv_peer. While the definition of output_tv is not - // modified and remains `set`, this output_tv is a use of the vals Even - // though we shardByStream, the use of vals is not modified and has a - // dependency on T1. Cloned e: T1_g_float[istreamIdx6{1}, iS5{3}] - // (DeviceMesh{0}) - // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), - // cache_op=Streaming ) - - // c: CollectivePermute 77 (team=(0), send_peer=( ( 1 + ( 0 - i140 ) ) % 1 - // ), recv_peer=( ( i140 + 0 ) % 1 ), input=T0_g_float[ideviceIdx.x2{1}, - // iS3{3}] (DeviceMesh{0}), output=T1_g_float[istreamIdx6{1}, iS5{3}] - // (DeviceMesh{0}), backend=NCCL) - - // %HostIrContainer { (T0_g_float[ideviceIdx.x2{1}, iS3{3}] - // (DeviceMesh{0})) -> (T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})) - // : - // T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) = - // ALLOCATE(buffer=T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), - // mem_type=global, size=3, zero_init=false, resets_to_zero=false) Stream - // 0x281a6c60 = GetCurrentStream() FOR i140 from 0 to 1: - // SetCurrentStream(Stream i140) - // Synchronize(Stream 0x281a6c60) - // T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = - // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), - // stream_index=i140) CollectivePermute 82 (team=(0), send_peer=( ( 1 + - // ( 0 - i140 ) ) % 1 ), recv_peer=( ( i140 + 0 ) % 1 ), - // input=T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), - // output=T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}), - // backend=NCCL) Wait(Communication 82) - // SetCurrentStream(Stream 0x281a6c60) - // FOR i140 from 0 to 1: - // Synchronize(Stream i140) - // } // %HostIrContainer - - // Invalidating index: i140 - // allConsumerValsOf(i140) - // Visited val: i140 - // Consumer of i140: i163 definition: i163 = 0 - i140; - - // Visited val: i163 - // Consumer of i163: i165 definition: i165 = 1 + i163; - - // Visited val: i165 - // Consumer of i165: i167 definition: i167 = i165 % 1; - - // Visited val: i167 - // Consumer of i167: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) - // definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) - // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), - // cache_op=Streaming ) - - // Visited val: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) - // Consumer of i167: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) - // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = - // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), - // stream_index=i140) - - // Visited val: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) - // Consumer of i140: i169 definition: i169 = i140 + 0; - - // Visited val: i169 - // Consumer of i169: i171 definition: i171 = i169 % 1; - - // Visited val: i171 - // Consumer of i171: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) - // definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) - // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), - // cache_op=Streaming ) - - // Consumer of i171: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) - // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = - // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), - // stream_index=i140) - - // Consumer of i140: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) - // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = - // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), - // stream_index=i140) - - // consumer_vals: 8 - // Invalidating consumer: i169 - // Invalidating consumer: T2_l_float[istreamIdx10{1}, iS9{3}] - // (DeviceMesh{0}) Invalidating consumer: T1_g_float[istreamIdx6{1}, - // iS5{3}] (DeviceMesh{0}) Invalidating consumer: i167 Invalidating - // consumer: i165 Invalidating consumer: i163 Invalidating consumer: i171 - // Invalidating consumer: i140 - - expr_evaluator_.invalidate(for_loop->index()); - for (auto consumer : allConsumerValsOf(for_loop->index())) { - if (!consumer->isA()) { - expr_evaluator_.invalidate(consumer); - } - } expr_evaluator_.bind(for_loop->index(), i); for (Expr* e : for_loop->body().exprs()) { dispatch(e); @@ -799,9 +704,6 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { "Allocation must be on a TensorView but got ", allocate->buffer()); auto* tv = allocate->buffer()->as(); - if (expr_evaluator_.isKnown(tv)) { - return; - } // Check the cache if enabled if (params_.use_allocation_cache) { diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 7c6a18865b1..dfa3c8212e3 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -15,6 +15,7 @@ #include "ir/builder.h" #include "ir/iostream.h" #include "ir/utils.h" +#include "iter_visitor.h" #include "kernel_ir.h" #include "multidevice/propagation.h" #include "multidevice/resharding.h" @@ -176,7 +177,6 @@ void lowerSegment( // If a value is already cloned, IrCloner::clone returns the cloned value // without cloning the value again. Expr* e = ir_cloner.clone(group.exprs().front()); - debug() << "Cloned e: " << e << std::endl; // TODO: `replacement_map` should be associated with the scope so // ShardByStream across segments in the same for-loop can be reused. @@ -227,8 +227,19 @@ void lowerSegment( innermost_scope.pushBack(allocate); } + if (auto* cp = dynamic_cast(c)) { + auto add_definition_chain = [&innermost_scope](Val* val) -> void { + for (Expr* expr : StmtSort::getExprsTo({val})) { + innermost_scope.pushBack(expr); + } + }; + add_definition_chain(cp->sendPeer()); + add_definition_chain(cp->recvPeer()); + } + Expr* new_c = cloneWithNewOperands(c, replacement_map); innermost_scope.pushBack(new_c); + debug() << "Cloned new_c: " << new_c << std::endl; auto* wait = IrBuilder::create(new_c); innermost_scope.pushBack(wait); From 6b988692f49fd99dd730e2fc11f53b2fd5b1ea5b Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 4 Mar 2026 14:42:34 -0800 Subject: [PATCH 06/10] merge --- csrc/host_ir/evaluator.cpp | 3 ++ csrc/host_ir/lower_to_communication.cpp | 13 ++++-- csrc/host_ir/lowering.cpp | 15 +++++-- csrc/multidevice/communication.cpp | 3 ++ csrc/multidevice/communication.h | 6 +-- csrc/multidevice/post_communication.cpp | 60 +++++++++++++++++++++++++ csrc/multidevice/post_communication.h | 9 ++++ 7 files changed, 100 insertions(+), 9 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 02ed2566b70..fa98e5392cd 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -709,6 +709,9 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { "Allocation must be on a TensorView but got ", allocate->buffer()); auto* tv = allocate->buffer()->as(); + if (expr_evaluator_.isKnown(tv)) { + return; + } // Check the cache if enabled if (params_.use_allocation_cache) { diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index b5a05362bc3..2c1a70a3be5 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -426,15 +426,18 @@ std::optional getCommunicationInfoForParallelType( if (p_loop_id && !c_loop_id) { // Check if we are going from DID -> Stream, which is a ring allgather. - // This can be executed as a broadcast or send recvs, which is decided + // This can be executed as a broadcast or collective permute, which is decided // by the presence of a swizzle in the stream id definition. if (c_logical_stream_id == p2c.at(p_logical_id)) { NVF_CHECK( same_mesh, "Broadcast based allgather in stream parallel requires same " "mesh.") + CommunicationType type = c_stream_id->definition()->isA() + ? CommunicationType::Broadcast + : CommunicationType::CollectivePermute; return CommunicationInfo{ - .type = CommunicationType::Broadcast, + .type = type, .p_sharded_id = p_logical_id, .c_sharded_id = c_logical_stream_id}; } @@ -552,7 +555,8 @@ Layout getCommunicationLayout( type == CommunicationType::Allreduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv || - type == CommunicationType::AllToAll) { + type == CommunicationType::AllToAll || + type == CommunicationType::CollectivePermute) { return layout; } @@ -694,6 +698,9 @@ std::vector convertSingleOpToCommunication( case CommunicationType::AllToAll: lowerToAllToAll(input_tv, output_tv, backend, comms); break; + case CommunicationType::CollectivePermute: + lowerToCollectivePermute(input_tv, output_tv, backend, comms, root, my_device_idx); + break; } return comms; diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index b5a1009a013..d56347b506c 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -234,13 +234,22 @@ void lowerSegment( Val* root = loop_nest.empty() ? nullptr : innermost.loop->index(); for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) { NVF_ERROR( - c->isA(), - "Exprs in a Communication group should be Communication: ", + c->isA() || c->isA(), + "Exprs in a Communication group should be Communication or CollectivePermute: ", c); + if (auto* cp = dynamic_cast(c)) { + auto add_definition_chain = [&innermost_scope](Val* val) -> void { + for (Expr* expr : StmtSort::getExprsTo({val})) { + innermost_scope.pushBack(expr); + } + }; + add_definition_chain(cp->sendPeer()); + add_definition_chain(cp->recvPeer()); + } + Expr* new_c = cloneWithNewOperands(c, replacement_map); innermost_scope.pushBack(new_c); - debug() << "Cloned new_c: " << new_c << std::endl; auto* wait = IrBuilder::create(new_c); innermost_scope.pushBack(wait); diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 151441d77d0..1d701d13062 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -46,6 +46,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { case CommunicationType::AllToAll: os << "AllToAll"; break; + case CommunicationType::CollectivePermute: + os << "CollectivePermute"; + break; } return os; } diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index d2598a7cdb7..4719180f8f1 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -31,7 +31,8 @@ enum class CommunicationType { ReduceScatter, Broadcast, SendRecv, - AllToAll + AllToAll, + CollectivePermute }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); @@ -123,8 +124,7 @@ class Communication : public Expr { }; // CollectivePermute: send to send_peer, recv from recv_peer. Separate from -// Communication (no root, no reduce op). Layout: inputs [in, send_peer, -// recv_peer], output [out], attributes [type, team, backend]. +// Communication (no root, no reduce op). class CollectivePermute : public Expr { public: using Expr::Expr; diff --git a/csrc/multidevice/post_communication.cpp b/csrc/multidevice/post_communication.cpp index fd04ad26eda..0cd275b03fd 100644 --- a/csrc/multidevice/post_communication.cpp +++ b/csrc/multidevice/post_communication.cpp @@ -447,6 +447,33 @@ c10::intrusive_ptr postRecv( return backend->recv(packed_buffer, static_cast(peer), /*tag=*/0); } +c10::intrusive_ptr postCollectivePermute( + CollectivePermute* communication, + DeviceIdxType my_device_index, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { +if (my_device_index == send_peer_index && + my_device_index == recv_peer_index) { + doLocalCopy(output_tensor, input_tensor); + return nullptr; +} +backend->startCoalescing(); +std::vector send_tensors = {input_tensor}; +backend->send( + send_tensors, + send_peer_index, + /*tag=*/0); +std::vector recv_tensors = {output_tensor}; +backend->recv( + recv_tensors, + recv_peer_index, + /*tag=*/0); +return backend->endCoalescing(); +} + } // namespace c10::intrusive_ptr postSingleCommunication( @@ -561,4 +588,37 @@ c10::intrusive_ptr postSingleCommunication( } } +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index) { +const Team& team = communication->team(); +if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { + return nullptr; +} +NVF_CHECK(backend != nullptr); + +if (isDebugDumpEnabled(DebugDumpOption::Communication) && + my_device_index == 0) { + debug() << "Posting " << communication->toInlineString() + << " with input_tensor " << input_tensor.sizes() + << " and output_tensor " << output_tensor.sizes() + << " send_peer=" << send_peer_index + << " recv_peer=" << recv_peer_index << std::endl; +} + +return postCollectivePermute( + communication, + my_device_index, + send_peer_index, + recv_peer_index, + backend, + input_tensor, + output_tensor); +} + } // namespace nvfuser diff --git a/csrc/multidevice/post_communication.h b/csrc/multidevice/post_communication.h index 18be2b537af..1ff2c983485 100644 --- a/csrc/multidevice/post_communication.h +++ b/csrc/multidevice/post_communication.h @@ -85,4 +85,13 @@ c10::intrusive_ptr postSingleCommunication( c10d::Backend* backend, at::Tensor buffer); +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index); + } // namespace nvfuser From 989065220b9983a41a7e3146cd7919a86380c788 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 4 Mar 2026 15:18:01 -0800 Subject: [PATCH 07/10] add swizzle1d to computeLoopIndex --- csrc/host_ir/lower_to_communication.cpp | 4 +- csrc/multidevice/resharding.cpp | 36 +++++++++++- tests/cpp/test_resharding.cpp | 73 +++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 5 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 2c1a70a3be5..80c9fa9e9e3 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -434,8 +434,8 @@ std::optional getCommunicationInfoForParallelType( "Broadcast based allgather in stream parallel requires same " "mesh.") CommunicationType type = c_stream_id->definition()->isA() - ? CommunicationType::Broadcast - : CommunicationType::CollectivePermute; + ? CommunicationType::CollectivePermute + : CommunicationType::Broadcast; return CommunicationInfo{ .type = type, .p_sharded_id = p_logical_id, diff --git a/csrc/multidevice/resharding.cpp b/csrc/multidevice/resharding.cpp index 245927fc4be..b71a0869e9b 100644 --- a/csrc/multidevice/resharding.cpp +++ b/csrc/multidevice/resharding.cpp @@ -59,7 +59,8 @@ const std::vector& getDomainOf( std::pair computeLoopIndex( IterDomain* id, const std::vector& sources, - std::unordered_map>& id_to_index) { + std::unordered_map>& id_to_index, + const std::unordered_map& pt_to_device_index) { if (id == nullptr) { return {nullptr, false}; } @@ -96,6 +97,18 @@ std::pair computeLoopIndex( id_to_index[out] = { add(mul(outer_info.first, inner->extent()), inner_info.first), outer_info.second || inner_info.second}; + } else if (auto* swizzle = dynamic_cast(transform)) { + auto* in = swizzle->in()->as(); + auto* out = swizzle->out()->as(); + + const auto& in_info = id_to_index.at(in); + Val* extent = out->extent(); + Val* pt_val = pt_to_device_index.at(swizzle->parallelType()); + // Inverse of the swizzle formula in_idx = (out_idx + pt_val) % extent: + // out_idx = (in_idx - pt_val + extent) % extent + id_to_index[out] = { + mod(add(sub(in_info.first, pt_val), extent), extent), + in_info.second}; } else { NVF_THROW("Unexpected transform: ", transform); } @@ -303,6 +316,17 @@ bool haveDifferentShardings( } } + // Create symbolic Vals for each device parallel type, representing the + // device's index within the team for that type. These are used by + // computeLoopIndex to symbolically compute Swizzle1D outputs. + std::unordered_map pt_to_device_index; + for (ParallelType pt : kParallelTypeDIDs) { + Val* device_idx = IrBuilder::create(DataType::Index); + pt_to_device_index[pt] = device_idx; + assumptions.push_back( + SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx)); + } + // For each parallel type, check whether the corresponding loop index in the // producer and that in the consumer are equivalent. If they can't be proven // to be equivalent, return is-resharding. @@ -311,7 +335,10 @@ bool haveDifferentShardings( Val* p_index = nullptr; bool p_mapped = false; std::tie(p_index, p_mapped) = computeLoopIndex( - p_id, getDomainOf(producer, DomainType::kLogical), id_to_index); + p_id, + getDomainOf(producer, DomainType::kLogical), + id_to_index, + pt_to_device_index); if (!p_mapped) { p_index = nullptr; } @@ -320,7 +347,10 @@ bool haveDifferentShardings( Val* c_index = nullptr; bool c_mapped = false; std::tie(c_index, c_mapped) = computeLoopIndex( - c_id, getDomainOf(consumer, DomainType::kRoot), id_to_index); + c_id, + getDomainOf(consumer, DomainType::kRoot), + id_to_index, + pt_to_device_index); if (!c_mapped) { c_index = nullptr; } diff --git a/tests/cpp/test_resharding.cpp b/tests/cpp/test_resharding.cpp index 7bcd950a1a2..7ef48b95ba2 100644 --- a/tests/cpp/test_resharding.cpp +++ b/tests/cpp/test_resharding.cpp @@ -631,4 +631,77 @@ TEST_F(ReshardingSelectOpTest, ReshardingSelectIntoNonDeviceDim) { EXPECT_TRUE(isResharding(tv1->definition())); } +// DID -> Stream with Swizzle (the collective permute pattern). +// Producer is sharded on DIDx, consumer has swizzle1d + Stream. +TEST_F(ReshardingTest, Swizzle1D_DIDToStream) { + Fusion fusion; + FusionGuard fg(&fusion); + const int d = 2; + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + in->outer_split(0, d); + in->axis(0)->parallelize(ParallelType::DIDx); + + TensorView* out = set(in); + out->setDeviceMesh(mesh); + out->outer_split(0, d); + out->swizzle1d(0, ParallelType::DIDx); + out->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_TRUE(haveDifferentShardings( + in, + DomainType::kAllocation, + out, + DomainType::kLoop, + {ParallelType::Stream})); +} + +// Both sides have the same split + swizzle1d + Stream. +// Symbolic expressions should match, so this is not resharding. +TEST_F(ReshardingTest, Swizzle1D_ConsistentSwizzle) { + Fusion fusion; + FusionGuard fg(&fusion); + const int d = 2; + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + in->outer_split(0, d); + in->swizzle1d(0, ParallelType::DIDx); + in->axis(0)->parallelize(ParallelType::Stream); + + TensorView* out = set(in); + out->setDeviceMesh(mesh); + out->outer_split(0, d); + out->swizzle1d(0, ParallelType::DIDx); + out->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_FALSE(haveDifferentShardings( + in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream})); +} + +// Same topology as DIDToStream, but checking DIDx only. +// The DIDx path doesn't traverse the swizzle (consumer has no DIDx loop ID). +TEST_F(ReshardingTest, Swizzle1D_DIDxCheck) { + Fusion fusion; + FusionGuard fg(&fusion); + const int d = 2; + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + in->outer_split(0, d); + in->axis(0)->parallelize(ParallelType::DIDx); + + TensorView* out = set(in); + out->setDeviceMesh(mesh); + out->outer_split(0, d); + out->swizzle1d(0, ParallelType::DIDx); + out->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_TRUE(haveDifferentShardings(in, out, {ParallelType::DIDx})); +} + } // namespace nvfuser From d016f2d3e155d70126085383730cbe3d8d8a653b Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 4 Mar 2026 18:51:46 -0800 Subject: [PATCH 08/10] add upper bound --- csrc/multidevice/resharding.cpp | 40 ++++++++++++++++++++------------- tests/cpp/test_resharding.cpp | 35 +++++++---------------------- 2 files changed, 32 insertions(+), 43 deletions(-) diff --git a/csrc/multidevice/resharding.cpp b/csrc/multidevice/resharding.cpp index b71a0869e9b..b3640c0530e 100644 --- a/csrc/multidevice/resharding.cpp +++ b/csrc/multidevice/resharding.cpp @@ -60,7 +60,7 @@ std::pair computeLoopIndex( IterDomain* id, const std::vector& sources, std::unordered_map>& id_to_index, - const std::unordered_map& pt_to_device_index) { + const std::unordered_map& pt_to_index) { if (id == nullptr) { return {nullptr, false}; } @@ -103,7 +103,7 @@ std::pair computeLoopIndex( const auto& in_info = id_to_index.at(in); Val* extent = out->extent(); - Val* pt_val = pt_to_device_index.at(swizzle->parallelType()); + Val* pt_val = pt_to_index.at(swizzle->parallelType()); // Inverse of the swizzle formula in_idx = (out_idx + pt_val) % extent: // out_idx = (in_idx - pt_val + extent) % extent id_to_index[out] = { @@ -254,8 +254,27 @@ bool haveDifferentShardings( std::vector assumptions; assumptions.reserve( (producer->getLogicalDomain().size() + - consumer->getMaybeRootDomain().size()) * + consumer->getMaybeRootDomain().size() + + kParallelTypeDIDs.size()) * 2); + + // Create symbolic Vals for each device parallel type present in the mesh, + // representing the device's index within the team for that type. These are + // used by computeLoopIndex to symbolically compute Swizzle1D outputs. + std::unordered_map pt_to_index; + const DeviceMesh& mesh = producer->getDeviceMesh(); + for (ParallelType pt : kParallelTypeDIDs) { + if (!mesh.hasParallelType(pt)) { + continue; + } + Val* device_idx = IrBuilder::create(DataType::Index); + pt_to_index[pt] = device_idx; + Val* team_size = IrBuilder::create(mesh.size(pt), DataType::Index); + assumptions.push_back( + SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx)); + assumptions.push_back( + SimplifyingIrBuilder::ltExpr(device_idx, team_size)); + } auto create_index = [&](IterDomain* id, bool mapped) { auto* index = IrBuilder::create(DataType::Index); @@ -316,17 +335,6 @@ bool haveDifferentShardings( } } - // Create symbolic Vals for each device parallel type, representing the - // device's index within the team for that type. These are used by - // computeLoopIndex to symbolically compute Swizzle1D outputs. - std::unordered_map pt_to_device_index; - for (ParallelType pt : kParallelTypeDIDs) { - Val* device_idx = IrBuilder::create(DataType::Index); - pt_to_device_index[pt] = device_idx; - assumptions.push_back( - SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx)); - } - // For each parallel type, check whether the corresponding loop index in the // producer and that in the consumer are equivalent. If they can't be proven // to be equivalent, return is-resharding. @@ -338,7 +346,7 @@ bool haveDifferentShardings( p_id, getDomainOf(producer, DomainType::kLogical), id_to_index, - pt_to_device_index); + pt_to_index); if (!p_mapped) { p_index = nullptr; } @@ -350,7 +358,7 @@ bool haveDifferentShardings( c_id, getDomainOf(consumer, DomainType::kRoot), id_to_index, - pt_to_device_index); + pt_to_index); if (!c_mapped) { c_index = nullptr; } diff --git a/tests/cpp/test_resharding.cpp b/tests/cpp/test_resharding.cpp index 7ef48b95ba2..d3a1427415c 100644 --- a/tests/cpp/test_resharding.cpp +++ b/tests/cpp/test_resharding.cpp @@ -631,8 +631,6 @@ TEST_F(ReshardingSelectOpTest, ReshardingSelectIntoNonDeviceDim) { EXPECT_TRUE(isResharding(tv1->definition())); } -// DID -> Stream with Swizzle (the collective permute pattern). -// Producer is sharded on DIDx, consumer has swizzle1d + Stream. TEST_F(ReshardingTest, Swizzle1D_DIDToStream) { Fusion fusion; FusionGuard fg(&fusion); @@ -652,14 +650,19 @@ TEST_F(ReshardingTest, Swizzle1D_DIDToStream) { EXPECT_TRUE(haveDifferentShardings( in, - DomainType::kAllocation, + DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream})); + + EXPECT_TRUE(haveDifferentShardings( + in, + DomainType::kLoop, + out, + DomainType::kLoop, + {ParallelType::DIDx})); } -// Both sides have the same split + swizzle1d + Stream. -// Symbolic expressions should match, so this is not resharding. TEST_F(ReshardingTest, Swizzle1D_ConsistentSwizzle) { Fusion fusion; FusionGuard fg(&fusion); @@ -682,26 +685,4 @@ TEST_F(ReshardingTest, Swizzle1D_ConsistentSwizzle) { in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream})); } -// Same topology as DIDToStream, but checking DIDx only. -// The DIDx path doesn't traverse the swizzle (consumer has no DIDx loop ID). -TEST_F(ReshardingTest, Swizzle1D_DIDxCheck) { - Fusion fusion; - FusionGuard fg(&fusion); - const int d = 2; - auto mesh = DeviceMesh::createForNumDevices(d); - - TensorView* in = makeContigTensor(1); - in->setDeviceMesh(mesh); - in->outer_split(0, d); - in->axis(0)->parallelize(ParallelType::DIDx); - - TensorView* out = set(in); - out->setDeviceMesh(mesh); - out->outer_split(0, d); - out->swizzle1d(0, ParallelType::DIDx); - out->axis(0)->parallelize(ParallelType::Stream); - - EXPECT_TRUE(haveDifferentShardings(in, out, {ParallelType::DIDx})); -} - } // namespace nvfuser From 80b79093540aac9c0899561a5c8390c3b7f41dd4 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 4 Mar 2026 18:52:01 -0800 Subject: [PATCH 09/10] add upper bound --- csrc/multidevice/resharding.cpp | 2 +- tests/cpp/test_resharding.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/resharding.cpp b/csrc/multidevice/resharding.cpp index b3640c0530e..bfdb04c152c 100644 --- a/csrc/multidevice/resharding.cpp +++ b/csrc/multidevice/resharding.cpp @@ -257,7 +257,7 @@ bool haveDifferentShardings( consumer->getMaybeRootDomain().size() + kParallelTypeDIDs.size()) * 2); - + // Create symbolic Vals for each device parallel type present in the mesh, // representing the device's index within the team for that type. These are // used by computeLoopIndex to symbolically compute Swizzle1D outputs. diff --git a/tests/cpp/test_resharding.cpp b/tests/cpp/test_resharding.cpp index d3a1427415c..a20db841652 100644 --- a/tests/cpp/test_resharding.cpp +++ b/tests/cpp/test_resharding.cpp @@ -654,7 +654,7 @@ TEST_F(ReshardingTest, Swizzle1D_DIDToStream) { out, DomainType::kLoop, {ParallelType::Stream})); - + EXPECT_TRUE(haveDifferentShardings( in, DomainType::kLoop, From a7bc714e5dc94afd8f0c4ac82440bd60f090ae71 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 4 Mar 2026 19:27:52 -0800 Subject: [PATCH 10/10] update the test and replay --- csrc/runtime/allocations.cpp | 25 +++++++++++++++ tests/python/multidevice/test_overlap.py | 39 ++++++++++++++++-------- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index a6d372f33b4..d94be664678 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -9,6 +9,7 @@ #include "runtime/allocations.h" #include "expr_evaluator.h" +#include "ir/internal_nodes.h" #include "instrumentation.h" #include "multidevice/execution_utils.h" #include "multidevice/utils.h" @@ -595,11 +596,23 @@ class ForwardTraverseFromAllocToLogical { } } + void handle(Swizzle1D* swizzle1d) { + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto in_it = std::find(frontier_.begin(), frontier_.end(), in); + if (in_it == frontier_.end()) { + return; + } + *in_it = out; + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } @@ -728,11 +741,23 @@ class BackwardTraverseFromAllocToLogical { frontier_.erase(out_it); } + void handle(Swizzle1D* swizzle1d) { + auto out = swizzle1d->out(); + auto in = swizzle1d->in(); + auto out_it = std::find(frontier_.begin(), frontier_.end(), out); + if (out_it == frontier_.end()) { + return; + } + *out_it = in; + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 258d3fe9d31..017df1fe683 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -417,7 +417,7 @@ def test_column_parallel_linear_forward_reference_benchmark( benchmark.pedantic(benchmark_fn, rounds=5) -def column_parallel_linear_forward(h: int, d: int): +def column_parallel_linear_forward(h: int, d: int, parallelism: str): with FusionDefinition() as fd: inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16) weight_tv = fd.define_tensor( @@ -436,6 +436,8 @@ def column_parallel_linear_forward(h: int, d: int): ag_out.set_device_mesh(mesh) ag_out.outer_split(0, d) + if parallelism == "collective_permute": + ag_out.swizzle1d(0, nvfuser.ParallelType.mesh_x) ag_out.axis(0).parallelize(nvfuser.ParallelType.stream) # Fusion IR before segmentation will look like this: @@ -444,11 +446,15 @@ def column_parallel_linear_forward(h: int, d: int): # d # (deviceIdx.x) # | - # | set (lowered to Broadcast. This decomposition is done manually in the definition above. It will later be done by preseg) + # | set (lowered to Broadcast/CollectivePermute. This decomposition is done + # | manually in the definition above. It will later be done + # | by preseg.) # | # [t, h] [4h, h] # /\ /\. - # s d + # d d + # | swizzle1d (if parallelism == "collective_permute") + # s # (streamIdx) # | # | linear @@ -461,21 +467,22 @@ def column_parallel_linear_forward(h: int, d: int): @pytest.mark.mpi -def test_column_parallel_linear_forward(multidevice_test): +@pytest.mark.parametrize("parallelism", ["collective_permute", "broadcast"]) +def test_column_parallel_linear_forward(multidevice_test, parallelism: str): # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward. # The difference is we are using broadcast based overlapping instead of send/recv. h, t = 2, 24 d = multidevice_test.size if (h * 4) % d != 0: pytest.skip( - f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + f"Column-parallel linear requires {h * 4} to be divisible by world size {d}." ) if t % d != 0: pytest.skip( f"Column-parallel linear requires {t} to be divisible by world size {d}." ) - fd = column_parallel_linear_forward(h, d) + fd = column_parallel_linear_forward(h, d, parallelism) inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to( torch.bfloat16 @@ -492,15 +499,23 @@ def test_column_parallel_linear_forward(multidevice_test): with torch.profiler.profile(record_shapes=True) as prof: (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) torch.testing.assert_close(out, out_ref) - broadcast_events = [ - event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name - ] - assert len(broadcast_events) == (d if d > 1 else 0) + + if parallelism == "collective_permute": + collective_permute_events = [ + event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name + ] + assert len(collective_permute_events) == (d - 1) + else: + broadcast_events = [ + event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name + ] + assert len(broadcast_events) == (d if d > 1 else 0) @pytest.mark.mpi @pytest.mark.benchmark -def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): +@pytest.mark.parametrize("parallelism", ["collective_permute", "broadcast"]) +def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark, parallelism: str): # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. h, t = 8192, 8192 d = multidevice_test.size @@ -513,7 +528,7 @@ def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): f"Column-parallel linear requires {t} to be divisible by world size {d}." ) - fd = column_parallel_linear_forward(h, d) + fd = column_parallel_linear_forward(h, d, parallelism) inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu") weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu")