diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index fb7a6a72b14..b20f1f23355 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -105,51 +105,90 @@ bool Fusion::sameDefinition(const Fusion& other) const { return true; } -void Fusion::swap(Fusion& a, Fusion& b) noexcept { +void Fusion::swap(Fusion& a, Fusion& b) { FUSER_PERF_SCOPE("Fusion swap"); - // We need to be careful to call IrContainer swap not unique_ptr swap, which - // will only swap the ptrs NOT the contents. - IrContainer::swap(*(a.ir_container()), *(b.ir_container())); + if (&a == &b) { + return; + } - // After swapping container contents, update Statement::ir_container_ - // pointers so each Statement points to the Fusion whose container now - // holds it. Also fix per-Fusion tracking keys since a's container had - // b's entries and vice versa. - a.ir_container()->transferStatementOwnership(&b, &a); - b.ir_container()->transferStatementOwnership(&a, &b); + NVF_ERROR( + a.ir_container_ != nullptr, "Fusion::swap: a has null ir_container_"); + NVF_ERROR( + b.ir_container_ != nullptr, "Fusion::swap: b has null ir_container_"); - if (a.ir_container_) { - for (auto val : a.vals()) { - val->ir_container_ = &a; - } - for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = &a; - } - } - if (b.ir_container_) { - for (auto val : b.vals()) { - val->ir_container_ = &b; - } - for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = &b; - } + // Collect statements owned by each Fusion BEFORE swap so we can update + // Statement::ir_container_ pointers afterward. + std::vector a_owned_vals, b_owned_vals; + std::vector a_owned_exprs, b_owned_exprs; + + const auto& av = a.ir_container_->valsOwnedBy(&a); + const auto& ae = a.ir_container_->exprsOwnedBy(&a); + a_owned_vals.assign(av.begin(), av.end()); + a_owned_exprs.assign(ae.begin(), ae.end()); + + const auto& bv = b.ir_container_->valsOwnedBy(&b); + const auto& be = b.ir_container_->exprsOwnedBy(&b); + b_owned_vals.assign(bv.begin(), bv.end()); + b_owned_exprs.assign(be.begin(), be.end()); + + // Transfer Fusion registrations between containers before pointer swap. + // After swap, a will own b's container and b will own a's container. + if (a.ir_container_.get() != b.ir_container_.get()) { + a.ir_container_->transferFusion(&a, &b); + b.ir_container_->transferFusion(&b, &a); } + // Swap container pointers + std::swap(a.ir_container_, b.ir_container_); + + // Swap all Fusion-level members std::swap(a.inputs_, b.inputs_); std::swap(a.outputs_, b.outputs_); - std::swap(a.io_alias_, b.io_alias_); - - // Swap per-Fusion special values (Phase 2) + std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_); + std::swap(a.is_during_update_uses_, b.is_during_update_uses_); + std::swap(a.managed_data_, b.managed_data_); + std::swap(a.managed_named_data_, b.managed_named_data_); + std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_); + std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_); std::swap(a.zero_val_, b.zero_val_); std::swap(a.one_val_, b.one_val_); std::swap(a.true_val_, b.true_val_); std::swap(a.false_val_, b.false_val_); std::swap(a.magic_zero_val_, b.magic_zero_val_); - std::swap(a.axioms_, b.axioms_); std::swap(a.metadata_, b.metadata_); + std::swap(a.val_type_name_map_, b.val_type_name_map_); + std::swap(a.expr_name_counter_, b.expr_name_counter_); + + // Update Statement::ir_container_ pointers: a's old statements now belong + // to b, and b's old statements now belong to a + for (auto* val : a_owned_vals) { + val->ir_container_ = &b; + } + for (auto* expr : a_owned_exprs) { + expr->ir_container_ = &b; + } + for (auto* val : b_owned_vals) { + val->ir_container_ = &a; + } + for (auto* expr : b_owned_exprs) { + expr->ir_container_ = &a; + } + + // Update per-Fusion tracking keys in containers. At this point, both + // a and b are guaranteed to have non-null ir_container_ (verified above). + if (a.ir_container_.get() == b.ir_container_.get()) { + // Same container: directly swap per-Fusion tracking entries + auto* c = a.ir_container_.get(); + std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]); + std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]); + } else { + // Different containers: rename tracking keys to match new owners + a.ir_container_->transferStatementOwnership(&b, &a); + b.ir_container_->transferStatementOwnership(&a, &b); + } } std::unique_ptr Fusion::segment( @@ -161,10 +200,33 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = - IrContainer::copy(from->ir_container(), to->ir_container(), to); + IrCloner ir_cloner(to); - // Remap cached special val pointers through the cloner + // Clone from's vals in insertion order + for (auto val : from->deterministic_vals()) { + ir_cloner.clone(val); + } + + // Wire up definitions and uses on cloned vals in deterministic order + // to ensure exprs are inserted into exprs_up_ deterministically + for (auto val : from->deterministic_vals()) { + ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); + ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); + } + + // Sync per-Fusion name counters from source to dest. + // Must be AFTER all cloning (vals and exprs) so that registerVal/registerExpr + // increments during cloning do not inflate the final counter values. + // During cloning, registerVal increments the dest Fusion's counter for each + // val, then IrBuilder::clone overrides the name with setName(src->name()). + // If source names are non-sequential (e.g., {0..10, 22..27} from segmenter + // creating intermediate TVs), the dest counter ends up at N (number of vals) + // instead of max(name)+1. Copying the source's counter state ensures new + // vals created post-copy won't collide with existing names. + to->val_type_name_map_ = from->val_type_name_map_; + to->expr_name_counter_ = from->expr_name_counter_; + + // Remap cached special val pointers if (from->zero_val_) { to->zero_val_ = ir_cloner.clone(from->zero_val_); } @@ -182,11 +244,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { ir_cloner.clone(from->magic_zero_val_)->as(); } - for (auto val : from->vals()) { - ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); - ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); - } - to->inputs_ = ir_cloner.clone(from->inputs_); to->outputs_ = ir_cloner.clone(from->outputs_); for (auto inp : to->inputs_) { @@ -196,7 +253,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { out->setIsFusionOutput(true); } - // TODO: put this into ir_cloner instead for (Val* out : from->outputs_) { const AliasInfo& alias = from->io_alias_.get(out); if (alias.type == AllocationType::New) { @@ -209,14 +265,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } to->all_tv_uses_valid_ = from->all_tv_uses_valid_; - // This should never be true on copy, but copying for completeness. to->is_during_update_uses_ = from->is_during_update_uses_; for (const auto& i : from->managed_data_) { if (i.first.has_value()) { to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second); } else { - // Don't clone managed data if it has been reset to->managed_data_.emplace_back(i.first, i.second); } } @@ -259,30 +313,43 @@ Fusion::Fusion() : ir_container_(std::make_shared()) { ir_container_->addFusion(this); } -// Copy constructor -Fusion::Fusion(const Fusion& other) : Fusion() { +// Copy constructor -- shares the source's container +Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) { FUSER_PERF_SCOPE("Fusion copy"); + ir_container_->addFusion(this); Fusion::copy(&other, this); } // Move constructor -Fusion::Fusion(Fusion&& other) noexcept : Fusion() { +// Not marked noexcept: Fusion::swap allocates local std::vectors to collect +// statement ownership before the swap, which can throw. Since Fusions are not +// expected to be moved into containers, the performance trade-off is +// acceptable. +// NOLINTNEXTLINE(cppcoreguidelines-noexcept-move-operations) +Fusion::Fusion(Fusion&& other) : Fusion() { FUSER_PERF_SCOPE("Fusion move"); swap(*this, other); } +// Copy Assignment -- shares the source's container Fusion& Fusion::operator=(const Fusion& other) { FUSER_PERF_SCOPE("Fusion copy assign"); - Fusion copy(other); - clear(); - swap(*this, copy); + if (this != &other) { + Fusion copy(other); + clear(); + swap(*this, copy); + } return *this; } -Fusion& Fusion::operator=(Fusion&& other) noexcept { +// Not marked noexcept: See move constructor above. +// NOLINTNEXTLINE(cppcoreguidelines-noexcept-move-operations) +Fusion& Fusion::operator=(Fusion&& other) { FUSER_PERF_SCOPE("Fusion move assign"); - clear(); - swap(*this, other); + if (this != &other) { + clear(); + swap(*this, other); + } return *this; } @@ -320,6 +387,9 @@ void Fusion::clear() noexcept { axioms_.reset(); metadata_.clear(); + val_type_name_map_.clear(); + expr_name_counter_ = 0; + invalidateTvsAndUses(); is_during_update_uses_ = false; @@ -924,7 +994,7 @@ void Fusion::registerVal(Val* val) { c->vals_up_.emplace_back(val); c->vals_.insert(val); c->per_fusion_vals_[this].insert(val); - val->setName(IrContainerPasskey(), c->getValName(val->vtype())); + val->setName(IrContainerPasskey(), getValName(val->vtype())); } void Fusion::registerExpr(Expr* expr) { @@ -941,7 +1011,7 @@ void Fusion::registerExpr(Expr* expr) { c->exprs_up_.emplace_back(expr); c->exprs_.insert(expr); c->per_fusion_exprs_[this].insert(expr); - expr->setName(IrContainerPasskey(), c->getExprName()); + expr->setName(IrContainerPasskey(), getExprName()); for (Val* input : expr->inputs()) { assertInContainer(input, "Input to expr is invalid, "); diff --git a/csrc/fusion.h b/csrc/fusion.h index e44817405e7..34be84be28d 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -180,14 +180,23 @@ class NVF_API Fusion : public PolymorphicBase { Fusion(); Fusion(const Fusion& other); - Fusion(Fusion&& other) noexcept; + + // Not marked noexcept: Fusion::swap allocates local std::vectors to collect + // statement ownership before the swap, which can throw. Since Fusions are not + // expected to be moved into containers, the performance trade-off is + // acceptable. + // NOLINTNEXTLINE(cppcoreguidelines-noexcept-move-operations) + Fusion(Fusion&& other); Fusion& operator=(const Fusion& other); - Fusion& operator=(Fusion&& other) noexcept; + + // Not marked noexcept: See move constructor above. + // NOLINTNEXTLINE(cppcoreguidelines-noexcept-move-operations) + Fusion& operator=(Fusion&& other); ~Fusion() override; - static void swap(Fusion& a, Fusion& b) noexcept; + static void swap(Fusion& a, Fusion& b); void clear() noexcept; @@ -661,6 +670,21 @@ class NVF_API Fusion : public PolymorphicBase { std::unique_ptr> axioms_; std::unordered_map> metadata_; + + // Per-Fusion name counters. Each Fusion independently tracks name assignment + // so that cloned Fusions get matching names (T0→T0) regardless of whether + // they share an IrContainer. This is required by downstream consumers that + // use tv->name() as a map key (alias_memory, GreedyParams, etc.). + std::unordered_map val_type_name_map_; + StmtNameType expr_name_counter_ = 0; + + StmtNameType getValName(ValType vtype) { + return val_type_name_map_[vtype]++; + } + + StmtNameType getExprName() { + return expr_name_counter_++; + } }; // Template implementations for Fusion::manage() that use IrCloner