diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 05e2e98c569..45174de5a72 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -128,6 +128,7 @@ class Val; f(SdpaFwdOp); \ f(SdpaBwdOp); \ f(EmbeddingFwdOp); \ + f(CollectivePermute); \ f(Communication); \ f(P2PCommunication); #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \ diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 020a4ddb3a1..fa98e5392cd 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -310,6 +310,38 @@ void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) { ipc_handle_cache_.exchangeHandles(share_mem_handles->communications()); } +void HostIrEvaluator::handle(CollectivePermute* communication) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + at::Tensor input_tensor = getKnownTensorOrUndefined(communication->input(0)); + at::Tensor output_tensor = + getKnownTensorOrUndefined(communication->output(0)); + +#ifndef NDEBUG + validateSizesAndStrides( + {input_tensor, output_tensor}, + {communication->in(), communication->out()}, + expr_evaluator_); +#endif + + CommunicatorBackend backend_type = communication->backend(); + // CollectivePermute is only supported with NCCL backend because + // UCC does not support coalescing. + NVF_CHECK_EQ(backend_type, CommunicatorBackend::kNccl); + c10d::Backend* backend = + communicator_->getBackendForTeam(communication->team(), backend_type); + works_[communication] = postSingleCommunication( + communication, + communicator_->deviceId(), + backend, + input_tensor, + output_tensor, + expr_evaluator_.evaluate(communication->sendPeer()).as(), + expr_evaluator_.evaluate(communication->recvPeer()).as()); +} + void HostIrEvaluator::handle(Communication* communication) { NVF_ERROR( communicator_ != nullptr && communicator_->is_available(), diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index f2c26c15797..e1d499addd8 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -96,6 +96,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(Synchronize*) override; void handle(PostOnStream*) override; void handle(LaunchKernel*) override; + void handle(CollectivePermute*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; void handle(MoeDispatch*) override; diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp index 1519c185c9c..4ea35034fa6 100644 --- a/csrc/host_ir/ir.cpp +++ b/csrc/host_ir/ir.cpp @@ -258,9 +258,13 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr) this, "must be registered in a HostIrContainer"); NVF_ERROR( - (expr->isOneOf()), - expr, - " must be a Communication, a P2PCommunication, or a EndCoalescing"); + (expr->isOneOf< + Communication, + CollectivePermute, + P2PCommunication, + EndCoalescing>()), + "Got: ", + expr); } NVFUSER_DEFINE_CLONE_AND_CREATE(Wait) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 8717873d063..80c9fa9e9e3 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -334,6 +334,33 @@ void lowerToAllToAll( backend)); } +void lowerToCollectivePermute( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms, + Val* root, + DeviceIdxType my_device_idx) { + NVF_ERROR_EQ( + input_tv->getDeviceMesh(), + output_tv->getDeviceMesh(), + "CollectivePermute sender and receiver meshes must be the same. Given ", + input_tv->getDeviceMesh(), + " and ", + output_tv->getDeviceMesh()); + + IterDomain* stream_id = + getShardedIterDomain(output_tv, ParallelType::Stream, DomainType::kLoop); + Swizzle1D* swizzle = stream_id->definition()->as(); + ParallelType pt = swizzle->parallelType(); + + const auto& [recv_peer, send_peer] = + dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh()); + Team team = input_tv->getDeviceMesh().vector(); + comms.push_back(IrBuilder::create( + output_tv, input_tv, team, send_peer, recv_peer, backend)); +} + IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) { std::vector logical_ids = ir_utils::getReachableIds(tv->getLogicalDomain(), {loop_id}); @@ -399,15 +426,18 @@ std::optional getCommunicationInfoForParallelType( if (p_loop_id && !c_loop_id) { // Check if we are going from DID -> Stream, which is a ring allgather. - // This can be executed as a broadcast or send recvs, which is decided + // This can be executed as a broadcast or collective permute, which is decided // by the presence of a swizzle in the stream id definition. if (c_logical_stream_id == p2c.at(p_logical_id)) { NVF_CHECK( same_mesh, "Broadcast based allgather in stream parallel requires same " "mesh.") + CommunicationType type = c_stream_id->definition()->isA() + ? CommunicationType::CollectivePermute + : CommunicationType::Broadcast; return CommunicationInfo{ - .type = CommunicationType::Broadcast, + .type = type, .p_sharded_id = p_logical_id, .c_sharded_id = c_logical_stream_id}; } @@ -525,7 +555,8 @@ Layout getCommunicationLayout( type == CommunicationType::Allreduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv || - type == CommunicationType::AllToAll) { + type == CommunicationType::AllToAll || + type == CommunicationType::CollectivePermute) { return layout; } @@ -667,6 +698,9 @@ std::vector convertSingleOpToCommunication( case CommunicationType::AllToAll: lowerToAllToAll(input_tv, output_tv, backend, comms); break; + case CommunicationType::CollectivePermute: + lowerToCollectivePermute(input_tv, output_tv, backend, comms, root, my_device_idx); + break; } return comms; diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index b20e226fda5..d56347b506c 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -12,8 +12,11 @@ #include "host_ir/ir.h" #include "host_ir/lower_to_communication.h" #include "host_ir/ops.h" +#include "ir/builder.h" #include "ir/iostream.h" #include "ir/utils.h" +#include "iter_visitor.h" +#include "kernel_ir.h" #include "multidevice/propagation.h" #include "multidevice/resharding.h" #include "multidevice/utils.h" @@ -231,10 +234,20 @@ void lowerSegment( Val* root = loop_nest.empty() ? nullptr : innermost.loop->index(); for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) { NVF_ERROR( - c->isA(), - "Exprs in a Communication group should be Communication: ", + c->isA() || c->isA(), + "Exprs in a Communication group should be Communication or CollectivePermute: ", c); + if (auto* cp = dynamic_cast(c)) { + auto add_definition_chain = [&innermost_scope](Val* val) -> void { + for (Expr* expr : StmtSort::getExprsTo({val})) { + innermost_scope.pushBack(expr); + } + }; + add_definition_chain(cp->sendPeer()); + add_definition_chain(cp->recvPeer()); + } + Expr* new_c = cloneWithNewOperands(c, replacement_map); innermost_scope.pushBack(new_c); diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 1d726072147..1d701d13062 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -46,6 +46,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { case CommunicationType::AllToAll: os << "AllToAll"; break; + case CommunicationType::CollectivePermute: + os << "CollectivePermute"; + break; } return os; } @@ -64,6 +67,7 @@ bool hasRoot(CommunicationType type) { case CommunicationType::Allreduce: case CommunicationType::ReduceScatter: case CommunicationType::AllToAll: + case CommunicationType::CollectivePermute: return false; } std::unreachable(); @@ -216,6 +220,47 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } +CollectivePermute::CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend) + : Expr(passkey) { + NVF_ERROR( + in->getDeviceMesh().size() > 0, + "The input mesh size must be greater than 0."); + NVF_ERROR( + out->getDeviceMesh().size() > 0, + "The output mesh size must be greater than 0."); + addInput(in); + addInput(send_peer); + addInput(recv_peer); + addOutput(out); + addDataAttribute(CommunicationType::CollectivePermute); + addDataAttribute(team); + addDataAttribute(backend); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(CollectivePermute) + +std::string CollectivePermute::toInlineString(const int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "CollectivePermute " << name() << " (" + << "team=(" << team() << ")" + << ", send_peer=" << sendPeer()->toInlineString() + << ", recv_peer=" << recvPeer()->toInlineString() + << ", input=" << in() << ", output=" << out() + << ", backend=" << backend() << ")"; + return ss.str(); +} + +std::string CollectivePermute::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + MoeDispatch::MoeDispatch( IrBuilderPasskey passkey, TensorView* out_x, diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 4c387b1695e..4719180f8f1 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -31,7 +31,8 @@ enum class CommunicationType { ReduceScatter, Broadcast, SendRecv, - AllToAll + AllToAll, + CollectivePermute }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); @@ -122,6 +123,61 @@ class Communication : public Expr { void validate(); }; +// CollectivePermute: send to send_peer, recv from recv_peer. Separate from +// Communication (no root, no reduce op). +class CollectivePermute : public Expr { + public: + using Expr::Expr; + + CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + CollectivePermute(const CollectivePermute& other) = delete; + CollectivePermute& operator=(const CollectivePermute& other) = delete; + CollectivePermute(CollectivePermute&& other) = delete; + CollectivePermute& operator=(CollectivePermute&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "CollectivePermute"; + } + + CommunicationType type() const { + return attribute(0); + } + + TensorView* in() const { + return input(0)->as(); + } + TensorView* out() const { + return output(0)->as(); + } + Val* sendPeer() const { + return input(1); + } + Val* recvPeer() const { + return input(2); + } + const Team& team() const { + return attribute(1); + } + int64_t team_size() const { + return static_cast(team().size()); + } + CommunicatorBackend backend() const { + return attribute(2); + } +}; + enum class P2PCommunicationType { SEND, RECV }; std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type); diff --git a/csrc/multidevice/post_communication.cpp b/csrc/multidevice/post_communication.cpp index fd04ad26eda..0cd275b03fd 100644 --- a/csrc/multidevice/post_communication.cpp +++ b/csrc/multidevice/post_communication.cpp @@ -447,6 +447,33 @@ c10::intrusive_ptr postRecv( return backend->recv(packed_buffer, static_cast(peer), /*tag=*/0); } +c10::intrusive_ptr postCollectivePermute( + CollectivePermute* communication, + DeviceIdxType my_device_index, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { +if (my_device_index == send_peer_index && + my_device_index == recv_peer_index) { + doLocalCopy(output_tensor, input_tensor); + return nullptr; +} +backend->startCoalescing(); +std::vector send_tensors = {input_tensor}; +backend->send( + send_tensors, + send_peer_index, + /*tag=*/0); +std::vector recv_tensors = {output_tensor}; +backend->recv( + recv_tensors, + recv_peer_index, + /*tag=*/0); +return backend->endCoalescing(); +} + } // namespace c10::intrusive_ptr postSingleCommunication( @@ -561,4 +588,37 @@ c10::intrusive_ptr postSingleCommunication( } } +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index) { +const Team& team = communication->team(); +if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { + return nullptr; +} +NVF_CHECK(backend != nullptr); + +if (isDebugDumpEnabled(DebugDumpOption::Communication) && + my_device_index == 0) { + debug() << "Posting " << communication->toInlineString() + << " with input_tensor " << input_tensor.sizes() + << " and output_tensor " << output_tensor.sizes() + << " send_peer=" << send_peer_index + << " recv_peer=" << recv_peer_index << std::endl; +} + +return postCollectivePermute( + communication, + my_device_index, + send_peer_index, + recv_peer_index, + backend, + input_tensor, + output_tensor); +} + } // namespace nvfuser diff --git a/csrc/multidevice/post_communication.h b/csrc/multidevice/post_communication.h index 18be2b537af..1ff2c983485 100644 --- a/csrc/multidevice/post_communication.h +++ b/csrc/multidevice/post_communication.h @@ -85,4 +85,13 @@ c10::intrusive_ptr postSingleCommunication( c10d::Backend* backend, at::Tensor buffer); +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index); + } // namespace nvfuser diff --git a/csrc/multidevice/resharding.cpp b/csrc/multidevice/resharding.cpp index 245927fc4be..bfdb04c152c 100644 --- a/csrc/multidevice/resharding.cpp +++ b/csrc/multidevice/resharding.cpp @@ -59,7 +59,8 @@ const std::vector& getDomainOf( std::pair computeLoopIndex( IterDomain* id, const std::vector& sources, - std::unordered_map>& id_to_index) { + std::unordered_map>& id_to_index, + const std::unordered_map& pt_to_index) { if (id == nullptr) { return {nullptr, false}; } @@ -96,6 +97,18 @@ std::pair computeLoopIndex( id_to_index[out] = { add(mul(outer_info.first, inner->extent()), inner_info.first), outer_info.second || inner_info.second}; + } else if (auto* swizzle = dynamic_cast(transform)) { + auto* in = swizzle->in()->as(); + auto* out = swizzle->out()->as(); + + const auto& in_info = id_to_index.at(in); + Val* extent = out->extent(); + Val* pt_val = pt_to_index.at(swizzle->parallelType()); + // Inverse of the swizzle formula in_idx = (out_idx + pt_val) % extent: + // out_idx = (in_idx - pt_val + extent) % extent + id_to_index[out] = { + mod(add(sub(in_info.first, pt_val), extent), extent), + in_info.second}; } else { NVF_THROW("Unexpected transform: ", transform); } @@ -241,9 +254,28 @@ bool haveDifferentShardings( std::vector assumptions; assumptions.reserve( (producer->getLogicalDomain().size() + - consumer->getMaybeRootDomain().size()) * + consumer->getMaybeRootDomain().size() + + kParallelTypeDIDs.size()) * 2); + // Create symbolic Vals for each device parallel type present in the mesh, + // representing the device's index within the team for that type. These are + // used by computeLoopIndex to symbolically compute Swizzle1D outputs. + std::unordered_map pt_to_index; + const DeviceMesh& mesh = producer->getDeviceMesh(); + for (ParallelType pt : kParallelTypeDIDs) { + if (!mesh.hasParallelType(pt)) { + continue; + } + Val* device_idx = IrBuilder::create(DataType::Index); + pt_to_index[pt] = device_idx; + Val* team_size = IrBuilder::create(mesh.size(pt), DataType::Index); + assumptions.push_back( + SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx)); + assumptions.push_back( + SimplifyingIrBuilder::ltExpr(device_idx, team_size)); + } + auto create_index = [&](IterDomain* id, bool mapped) { auto* index = IrBuilder::create(DataType::Index); NVF_ERROR(id_to_index.emplace(id, std::make_pair(index, mapped)).second); @@ -311,7 +343,10 @@ bool haveDifferentShardings( Val* p_index = nullptr; bool p_mapped = false; std::tie(p_index, p_mapped) = computeLoopIndex( - p_id, getDomainOf(producer, DomainType::kLogical), id_to_index); + p_id, + getDomainOf(producer, DomainType::kLogical), + id_to_index, + pt_to_index); if (!p_mapped) { p_index = nullptr; } @@ -320,7 +355,10 @@ bool haveDifferentShardings( Val* c_index = nullptr; bool c_mapped = false; std::tie(c_index, c_mapped) = computeLoopIndex( - c_id, getDomainOf(consumer, DomainType::kRoot), id_to_index); + c_id, + getDomainOf(consumer, DomainType::kRoot), + id_to_index, + pt_to_index); if (!c_mapped) { c_index = nullptr; } diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 95240795da2..4a0b036fe5f 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -14,8 +14,10 @@ #include #include "compute_at_map.h" +#include "ir/builder.h" #include "ir/internal_base_nodes.h" #include "ir/internal_nodes.h" +#include "ops/arith.h" #include "transform_replay.h" #include "type.h" @@ -346,4 +348,21 @@ int64_t getRelativeIndex(const Team& team, DeviceIdxType rank) { return std::distance(team.begin(), i); } +std::pair dispatchSwizzle1D( + Val* host_loop_index, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh) { + int64_t team_size = mesh.size(pt); + at::Tensor md_index = mesh.multiDimensionalIndexOf(device_id); + auto pt_axis = mesh.parallelTypeToAxis(pt); + int64_t team_index = md_index[pt_axis].item(); + Val* team_size_val = IrBuilder::create(team_size, DataType::Index); + Val* team_index_val = IrBuilder::create(team_index, DataType::Index); + return std::make_pair( + mod(add(host_loop_index, team_index_val), team_size_val), + mod(add(team_size_val, sub(team_index_val, host_loop_index)), + team_size_val)); +} + } // namespace nvfuser diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 467d9fa61e6..44eba74292e 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -91,4 +91,10 @@ int64_t getRFactorDeviceDimensionIndex(const TensorView* tv); // Returns the relative index of the rank in the team. int64_t getRelativeIndex(const Team& team, DeviceIdxType rank); +std::pair dispatchSwizzle1D( + Val* my_rank, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh); + } // namespace nvfuser diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index a6d372f33b4..d94be664678 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -9,6 +9,7 @@ #include "runtime/allocations.h" #include "expr_evaluator.h" +#include "ir/internal_nodes.h" #include "instrumentation.h" #include "multidevice/execution_utils.h" #include "multidevice/utils.h" @@ -595,11 +596,23 @@ class ForwardTraverseFromAllocToLogical { } } + void handle(Swizzle1D* swizzle1d) { + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto in_it = std::find(frontier_.begin(), frontier_.end(), in); + if (in_it == frontier_.end()) { + return; + } + *in_it = out; + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } @@ -728,11 +741,23 @@ class BackwardTraverseFromAllocToLogical { frontier_.erase(out_it); } + void handle(Swizzle1D* swizzle1d) { + auto out = swizzle1d->out(); + auto in = swizzle1d->in(); + auto out_it = std::find(frontier_.begin(), frontier_.end(), out); + if (out_it == frontier_.end()) { + return; + } + *out_it = in; + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 2b2df8097ed..f4458598a1c 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -94,11 +94,24 @@ class ForwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto in_it = active_ids_.find(in); + auto [in_size, in_stride] = in_it->second; + NVF_ERROR(active_ids_.erase(in) == 1); + NVF_ERROR( + active_ids_.emplace(out, std::make_pair(in_size, in_stride)).second); + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } @@ -189,11 +202,24 @@ class BackwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto out_it = active_ids_.find(out); + auto [out_size, out_stride] = out_it->second; + NVF_ERROR(active_ids_.erase(out) == 1); + NVF_ERROR( + active_ids_.emplace(in, std::make_pair(out_size, out_stride)).second); + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } diff --git a/python/python_direct/ir.cpp b/python/python_direct/ir.cpp index 93c032e4ef6..c149663da38 100644 --- a/python/python_direct/ir.cpp +++ b/python/python_direct/ir.cpp @@ -501,6 +501,27 @@ Returns ------- TensorView A TensorView with the swizzled axes in its loop domain. +)") + .def( + "swizzle1d", + [](TensorView* self, int64_t x, ParallelType parallel_type) { + return self->swizzle1d(x, parallel_type); + }, + py::return_value_policy::reference, + py::arg("x"), + py::arg("parallel_type"), + R"( +Swizzle the specified axis with the device index corresponding to the given parallel type. +Parameters +---------- +x : int +The axis to swizzle. +parallel_type : ParallelType +The device parallel type for the 1D swizzle. +Returns +------- +TensorView +A TensorView with the swizzled axis in its loop domain. )") .def( "rfactor", diff --git a/tests/cpp/test_resharding.cpp b/tests/cpp/test_resharding.cpp index 7bcd950a1a2..a20db841652 100644 --- a/tests/cpp/test_resharding.cpp +++ b/tests/cpp/test_resharding.cpp @@ -631,4 +631,58 @@ TEST_F(ReshardingSelectOpTest, ReshardingSelectIntoNonDeviceDim) { EXPECT_TRUE(isResharding(tv1->definition())); } +TEST_F(ReshardingTest, Swizzle1D_DIDToStream) { + Fusion fusion; + FusionGuard fg(&fusion); + const int d = 2; + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + in->outer_split(0, d); + in->axis(0)->parallelize(ParallelType::DIDx); + + TensorView* out = set(in); + out->setDeviceMesh(mesh); + out->outer_split(0, d); + out->swizzle1d(0, ParallelType::DIDx); + out->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_TRUE(haveDifferentShardings( + in, + DomainType::kLoop, + out, + DomainType::kLoop, + {ParallelType::Stream})); + + EXPECT_TRUE(haveDifferentShardings( + in, + DomainType::kLoop, + out, + DomainType::kLoop, + {ParallelType::DIDx})); +} + +TEST_F(ReshardingTest, Swizzle1D_ConsistentSwizzle) { + Fusion fusion; + FusionGuard fg(&fusion); + const int d = 2; + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + in->outer_split(0, d); + in->swizzle1d(0, ParallelType::DIDx); + in->axis(0)->parallelize(ParallelType::Stream); + + TensorView* out = set(in); + out->setDeviceMesh(mesh); + out->outer_split(0, d); + out->swizzle1d(0, ParallelType::DIDx); + out->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_FALSE(haveDifferentShardings( + in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream})); +} + } // namespace nvfuser diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 212029712d4..3b2e0716d2c 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -202,3 +202,32 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) + + +def test_collective_permute(multidevice_test): + d = multidevice_test.size + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Float) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + inp_tv.set_device_mesh(mesh) + inp_tv.outer_split(0, d) + inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + out_tv.set_device_mesh(mesh) + out_tv.outer_split(0, d) + out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x) + out_tv.axis(0).parallelize(nvfuser.ParallelType.stream) + + inp_ref = torch.randn(d * 3) + inp = multidevice_test.shard_tensor(inp_ref, inp_tv) + with torch.profiler.profile() as prof: + (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"]) + torch.testing.assert_close(out.cpu(), inp_ref) + collective_permute_events = [ + event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name + ] + assert len(collective_permute_events) == (d - 1) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 258d3fe9d31..017df1fe683 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -417,7 +417,7 @@ def test_column_parallel_linear_forward_reference_benchmark( benchmark.pedantic(benchmark_fn, rounds=5) -def column_parallel_linear_forward(h: int, d: int): +def column_parallel_linear_forward(h: int, d: int, parallelism: str): with FusionDefinition() as fd: inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16) weight_tv = fd.define_tensor( @@ -436,6 +436,8 @@ def column_parallel_linear_forward(h: int, d: int): ag_out.set_device_mesh(mesh) ag_out.outer_split(0, d) + if parallelism == "collective_permute": + ag_out.swizzle1d(0, nvfuser.ParallelType.mesh_x) ag_out.axis(0).parallelize(nvfuser.ParallelType.stream) # Fusion IR before segmentation will look like this: @@ -444,11 +446,15 @@ def column_parallel_linear_forward(h: int, d: int): # d # (deviceIdx.x) # | - # | set (lowered to Broadcast. This decomposition is done manually in the definition above. It will later be done by preseg) + # | set (lowered to Broadcast/CollectivePermute. This decomposition is done + # | manually in the definition above. It will later be done + # | by preseg.) # | # [t, h] [4h, h] # /\ /\. - # s d + # d d + # | swizzle1d (if parallelism == "collective_permute") + # s # (streamIdx) # | # | linear @@ -461,21 +467,22 @@ def column_parallel_linear_forward(h: int, d: int): @pytest.mark.mpi -def test_column_parallel_linear_forward(multidevice_test): +@pytest.mark.parametrize("parallelism", ["collective_permute", "broadcast"]) +def test_column_parallel_linear_forward(multidevice_test, parallelism: str): # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward. # The difference is we are using broadcast based overlapping instead of send/recv. h, t = 2, 24 d = multidevice_test.size if (h * 4) % d != 0: pytest.skip( - f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + f"Column-parallel linear requires {h * 4} to be divisible by world size {d}." ) if t % d != 0: pytest.skip( f"Column-parallel linear requires {t} to be divisible by world size {d}." ) - fd = column_parallel_linear_forward(h, d) + fd = column_parallel_linear_forward(h, d, parallelism) inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to( torch.bfloat16 @@ -492,15 +499,23 @@ def test_column_parallel_linear_forward(multidevice_test): with torch.profiler.profile(record_shapes=True) as prof: (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) torch.testing.assert_close(out, out_ref) - broadcast_events = [ - event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name - ] - assert len(broadcast_events) == (d if d > 1 else 0) + + if parallelism == "collective_permute": + collective_permute_events = [ + event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name + ] + assert len(collective_permute_events) == (d - 1) + else: + broadcast_events = [ + event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name + ] + assert len(broadcast_events) == (d if d > 1 else 0) @pytest.mark.mpi @pytest.mark.benchmark -def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): +@pytest.mark.parametrize("parallelism", ["collective_permute", "broadcast"]) +def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark, parallelism: str): # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. h, t = 8192, 8192 d = multidevice_test.size @@ -513,7 +528,7 @@ def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): f"Column-parallel linear requires {t} to be divisible by world size {d}." ) - fd = column_parallel_linear_forward(h, d) + fd = column_parallel_linear_forward(h, d, parallelism) inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu") weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu")