From 0bfea77c4d84a6630bf0d761958ece9dc31312fc Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Thu, 12 Feb 2026 14:06:47 -0800 Subject: [PATCH 1/2] Per-Fusion statement tracking and ownership-filtered accessors Add per_fusion_vals_ / per_fusion_exprs_ maps to IrContainer so each Fusion can efficiently query only its own statements in a shared container. Fusion forwarding methods (vals(), unordered_exprs(), deterministic_vals(), etc.) now return per-Fusion filtered results. Fusion::clear() uses removeStatementsOwnedBy(this) instead of ir_container()->clear(). --- csrc/fusion.cpp | 24 ++++---- csrc/fusion.h | 22 ++++--- csrc/ir/container.cpp | 131 ++++++++++++++++++++++++++++++++++++++++++ csrc/ir/container.h | 18 ++++++ 4 files changed, 173 insertions(+), 22 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 01f1c300f09..61d950995e2 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -114,7 +114,11 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { // After swapping container contents, update Statement::ir_container_ // pointers so each Statement points to the Fusion whose container now - // holds it. + // holds it. Also fix per-Fusion tracking keys since a's container had + // b's entries and vice versa. + a.ir_container()->transferStatementOwnership(&b, &a); + b.ir_container()->transferStatementOwnership(&a, &b); + if (a.ir_container_) { for (auto val : a.vals()) { val->ir_container_ = &a; @@ -295,10 +299,9 @@ void Fusion::clear() noexcept { // constructor of Trace, which could throw an exception. // FUSER_PERF_SCOPE("Fusion clear"); - // Clear container contents instead of destroying it - // This preserves the container object so Statement pointers don't become - // dangling - ir_container()->clear(); + if (ir_container_) { + ir_container_->removeStatementsOwnedBy(this); + } inputs_.clear(); outputs_.clear(); @@ -308,8 +311,6 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); - // Reset per-Fusion special value caches (the vals themselves are owned by - // ir_container and were already destroyed by ir_container()->clear() above). zero_val_ = nullptr; one_val_ = nullptr; true_val_ = nullptr; @@ -353,6 +354,7 @@ void Fusion::removeExpr(Expr* expr) { NVF_ERROR( expr_in_deque != c->exprs_up_.end(), "Wanted to remove an expression but its unique ptr is missing."); + c->per_fusion_exprs_[this].erase(expr); c->exprs_.erase(expr); c->exprs_up_.erase(expr_in_deque); } @@ -412,6 +414,7 @@ void Fusion::removeVal(Val* val) { NVF_ERROR( val_in_deque != c->vals_up_.end(), "Wanted to remove a value but its unique ptr is missing."); + c->per_fusion_vals_[this].erase(val); c->vals_.erase(val); c->vals_up_.erase(val_in_deque); @@ -444,13 +447,11 @@ void Fusion::removeStatementsCreatedAfter( for (Val* in : e->inputs()) { in->removeUse(e); } + c->per_fusion_exprs_[this].erase(e); c->exprs_.erase(e); c->exprs_up_.pop_back(); } - // Null out any special value caches that point to vals about to be destroyed. - // This prevents dangling pointers when special vals are lazily created inside - // a StatementGuard scope. while (std::ssize(c->vals_up_) > num_vals_before) { Val* v = c->vals_up_.back().get(); if (v == zero_val_) { @@ -464,6 +465,7 @@ void Fusion::removeStatementsCreatedAfter( } else if (v == magic_zero_val_) { magic_zero_val_ = nullptr; } + c->per_fusion_vals_[this].erase(v); c->vals_.erase(v); c->vals_up_.pop_back(); } @@ -927,6 +929,7 @@ void Fusion::registerVal(Val* val) { auto* c = ir_container(); c->vals_up_.emplace_back(val); c->vals_.insert(val); + c->per_fusion_vals_[this].insert(val); val->setName(IrContainerPasskey(), c->getValName(val->vtype())); } @@ -943,6 +946,7 @@ void Fusion::registerExpr(Expr* expr) { auto* c = ir_container(); c->exprs_up_.emplace_back(expr); c->exprs_.insert(expr); + c->per_fusion_exprs_[this].insert(expr); expr->setName(IrContainerPasskey(), c->getExprName()); for (Val* input : expr->inputs()) { diff --git a/csrc/fusion.h b/csrc/fusion.h index 6ba7a3788df..07353bfe08c 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -522,31 +522,29 @@ class NVF_API Fusion : public PolymorphicBase { } // Collections access (return values in insertion order) - const std::deque deterministic_vals() const noexcept { - return ir_container()->deterministic_vals(); + std::deque deterministic_vals() const noexcept { + return ir_container()->deterministicValsOwnedBy(this); } - const std::deque deterministic_exprs() const noexcept { - return ir_container()->deterministic_exprs(); + std::deque deterministic_exprs() const noexcept { + return ir_container()->deterministicExprsOwnedBy(this); } - const std::unordered_map deterministic_vals_map() - const noexcept { - return ir_container()->deterministic_vals_map(); + std::unordered_map deterministic_vals_map() const noexcept { + return ir_container()->deterministicValsMapOwnedBy(this); } - const std::unordered_map deterministic_exprs_map() - const noexcept { - return ir_container()->deterministic_exprs_map(); + std::unordered_map deterministic_exprs_map() const noexcept { + return ir_container()->deterministicExprsMapOwnedBy(this); } // Collections access (unordered sets) const std::unordered_set& unordered_exprs() const noexcept { - return ir_container()->unordered_exprs(); + return ir_container()->exprsOwnedBy(this); } const std::unordered_set& vals() const noexcept { - return ir_container()->vals(); + return ir_container()->valsOwnedBy(this); } // Count queries diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index f0537714ce0..4765fab0662 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -75,6 +75,9 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); + + std::swap(a.per_fusion_vals_, b.per_fusion_vals_); + std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_); } IrCloner IrContainer::copy( @@ -119,6 +122,8 @@ void IrContainer::clear() noexcept { exprs_up_.clear(); val_type_name_map_.clear(); expr_name_counter_ = 0; + per_fusion_vals_.clear(); + per_fusion_exprs_.clear(); } bool IrContainer::inContainer(const Statement* const_stmt) const { @@ -178,4 +183,130 @@ const std::unordered_set& IrContainer::sharingFusions() const { return sharing_fusions_; } +const std::unordered_set& IrContainer::valsOwnedBy( + const Fusion* fusion) const { + static const std::unordered_set empty; + auto it = per_fusion_vals_.find(fusion); + return it != per_fusion_vals_.end() ? it->second : empty; +} + +const std::unordered_set& IrContainer::exprsOwnedBy( + const Fusion* fusion) const { + static const std::unordered_set empty; + auto it = per_fusion_exprs_.find(fusion); + return it != per_fusion_exprs_.end() ? it->second : empty; +} + +void IrContainer::transferStatementOwnership( + const Fusion* from, + const Fusion* to) { + auto vals_it = per_fusion_vals_.find(from); + if (vals_it != per_fusion_vals_.end()) { + auto& to_vals = per_fusion_vals_[to]; + to_vals.insert(vals_it->second.begin(), vals_it->second.end()); + per_fusion_vals_.erase(vals_it); + } + + auto exprs_it = per_fusion_exprs_.find(from); + if (exprs_it != per_fusion_exprs_.end()) { + auto& to_exprs = per_fusion_exprs_[to]; + to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end()); + per_fusion_exprs_.erase(exprs_it); + } +} + +void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { + auto vals_it = per_fusion_vals_.find(fusion); + if (vals_it != per_fusion_vals_.end()) { + for (auto it = vals_up_.begin(); it != vals_up_.end();) { + if (vals_it->second.count(it->get()) > 0) { + vals_.erase(it->get()); + it = vals_up_.erase(it); + } else { + ++it; + } + } + per_fusion_vals_.erase(vals_it); + } + + auto exprs_it = per_fusion_exprs_.find(fusion); + if (exprs_it != per_fusion_exprs_.end()) { + for (auto it = exprs_up_.begin(); it != exprs_up_.end();) { + if (exprs_it->second.count(it->get()) > 0) { + exprs_.erase(it->get()); + it = exprs_up_.erase(it); + } else { + ++it; + } + } + per_fusion_exprs_.erase(exprs_it); + } +} + +std::deque IrContainer::deterministicValsOwnedBy( + const Fusion* fusion) const noexcept { + std::deque result; + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; + } + const auto& owned = it->second; + for (const auto& val_up : vals_up_) { + if (owned.count(val_up.get()) > 0) { + result.push_back(val_up.get()); + } + } + return result; +} + +std::deque IrContainer::deterministicExprsOwnedBy( + const Fusion* fusion) const noexcept { + std::deque result; + auto it = per_fusion_exprs_.find(fusion); + if (it == per_fusion_exprs_.end()) { + return result; + } + const auto& owned = it->second; + for (const auto& expr_up : exprs_up_) { + if (owned.count(expr_up.get()) > 0) { + result.push_back(expr_up.get()); + } + } + return result; +} + +std::unordered_map IrContainer::deterministicValsMapOwnedBy( + const Fusion* fusion) const noexcept { + std::unordered_map result; + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; + } + const auto& owned = it->second; + int64_t count = 0; + for (const auto& val_up : vals_up_) { + if (owned.count(val_up.get()) > 0) { + result[val_up.get()] = count++; + } + } + return result; +} + +std::unordered_map IrContainer::deterministicExprsMapOwnedBy( + const Fusion* fusion) const noexcept { + std::unordered_map result; + auto it = per_fusion_exprs_.find(fusion); + if (it == per_fusion_exprs_.end()) { + return result; + } + const auto& owned = it->second; + int64_t count = 0; + for (const auto& expr_up : exprs_up_) { + if (owned.count(expr_up.get()) > 0) { + result[expr_up.get()] = count++; + } + } + return result; +} + } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index e3738be7349..ed0b4504840 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -137,8 +137,26 @@ class IrContainer { bool hasMultipleFusions() const; const std::unordered_set& sharingFusions() const; + NVF_API const std::unordered_set& valsOwnedBy( + const Fusion* fusion) const; + const std::unordered_set& exprsOwnedBy(const Fusion* fusion) const; + void transferStatementOwnership(const Fusion* from, const Fusion* to); + void removeStatementsOwnedBy(const Fusion* fusion); + + std::deque deterministicValsOwnedBy( + const Fusion* fusion) const noexcept; + std::deque deterministicExprsOwnedBy( + const Fusion* fusion) const noexcept; + std::unordered_map deterministicValsMapOwnedBy( + const Fusion* fusion) const noexcept; + std::unordered_map deterministicExprsMapOwnedBy( + const Fusion* fusion) const noexcept; + private: std::unordered_set sharing_fusions_; + std::unordered_map> per_fusion_vals_; + std::unordered_map> + per_fusion_exprs_; }; } // namespace nvfuser From 50cb886bfe45d49ea2ad012e90e6fd0125b7f167 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Thu, 26 Feb 2026 07:09:02 -0800 Subject: [PATCH 2/2] =?UTF-8?q?PR=20#5961=20Review=20Fixes=20=E2=80=94=20P?= =?UTF-8?q?er-Fusion=20Statement=20Tracking=20(#6015)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Review fixes for PR #5961 (Per-Fusion statement tracking): - **O(n²) → O(n)**: Optimize `removeStatementsOwnedBy` with `std::erase_if` - **Per-Fusion counts**: Convert `numExprs()`/`numVals()` to return per-Fusion counts instead of global - **StatementGuard fixes**: Snapshot and compare per-Fusion counts for correct LIFO rollback in shared containers - **LIFO assertions**: Verify tail elements belong to this Fusion before popping ## Tests All tests pass: - ✅ StatementGuardTest.ExecuteAfterGuard - ✅ StatementGuardTest.LazySpecialValsNotDangling - ✅ FusionCopy_CUDA - ✅ FusionMove_CUDA --- csrc/fusion.cpp | 28 +++++++++++----------------- csrc/fusion.h | 19 ++++++++++++++++--- csrc/ir/container.cpp | 28 ++++++++++++++-------------- csrc/statement_guard.cpp | 2 +- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 61d950995e2..fb7a6a72b14 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -426,24 +426,14 @@ void Fusion::removeStatementsCreatedAfter( int64_t num_vals_before) { auto* c = ir_container(); - NVF_ERROR( - c->exprs_up_.size() == c->exprs_.size(), - "exprs_up_ (size ", - c->exprs_up_.size(), - ") and exprs_ (size ", - c->exprs_.size(), - ") are out of sync."); - NVF_ERROR( - std::ssize(c->exprs_up_) >= num_exprs_before, - "exprs_up_ size (", - std::ssize(c->exprs_up_), - ") is less than num_exprs_before (", - num_exprs_before, - ")."); - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(c->exprs_up_) > num_exprs_before) { + while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) { + // Pop from global deque back — statements created by this Fusion during + // the guard scope are at the tail (LIFO invariant). Expr* e = c->exprs_up_.back().get(); + NVF_ERROR( + c->per_fusion_exprs_[this].count(e) > 0, + "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); for (Val* in : e->inputs()) { in->removeUse(e); } @@ -452,8 +442,12 @@ void Fusion::removeStatementsCreatedAfter( c->exprs_up_.pop_back(); } - while (std::ssize(c->vals_up_) > num_vals_before) { + while (numValsExcludingShortcuts() > num_vals_before) { Val* v = c->vals_up_.back().get(); + NVF_ERROR( + c->per_fusion_vals_[this].count(v) > 0, + "removeStatementsCreatedAfter: tail val belongs to another Fusion"); + // Null out shortcut caches if they point to vals about to be destroyed if (v == zero_val_) { zero_val_ = nullptr; } else if (v == one_val_) { diff --git a/csrc/fusion.h b/csrc/fusion.h index 07353bfe08c..e44817405e7 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -547,13 +547,26 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container()->valsOwnedBy(this); } - // Count queries + // Count queries (per-Fusion: only counts statements owned by this Fusion) int64_t numExprs() const noexcept { - return ir_container()->numExprs(); + return std::ssize(ir_container()->exprsOwnedBy(this)); } int64_t numVals() const noexcept { - return ir_container()->numVals(); + return std::ssize(ir_container()->valsOwnedBy(this)); + } + + //! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.). + //! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but + //! since they're singletons that should persist across StatementGuard scopes, + //! this count excludes them so the LIFO pop-back in + //! removeStatementsCreatedAfter correctly skips over them. + int64_t numValsExcludingShortcuts() const noexcept { + int64_t count = std::ssize(ir_container()->valsOwnedBy(this)); + count -= (zero_val_ != nullptr) + (one_val_ != nullptr) + + (true_val_ != nullptr) + (false_val_ != nullptr) + + (magic_zero_val_ != nullptr); + return count; } // Shortcut values (frequently used constants) diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 4765fab0662..5c0b2e0f5a6 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -218,27 +218,27 @@ void IrContainer::transferStatementOwnership( void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { auto vals_it = per_fusion_vals_.find(fusion); if (vals_it != per_fusion_vals_.end()) { - for (auto it = vals_up_.begin(); it != vals_up_.end();) { - if (vals_it->second.count(it->get()) > 0) { - vals_.erase(it->get()); - it = vals_up_.erase(it); - } else { - ++it; + const auto& owned = vals_it->second; + std::erase_if(vals_up_, [&](const std::unique_ptr& v) { + if (owned.count(v.get()) > 0) { + vals_.erase(v.get()); + return true; } - } + return false; + }); per_fusion_vals_.erase(vals_it); } auto exprs_it = per_fusion_exprs_.find(fusion); if (exprs_it != per_fusion_exprs_.end()) { - for (auto it = exprs_up_.begin(); it != exprs_up_.end();) { - if (exprs_it->second.count(it->get()) > 0) { - exprs_.erase(it->get()); - it = exprs_up_.erase(it); - } else { - ++it; + const auto& owned = exprs_it->second; + std::erase_if(exprs_up_, [&](const std::unique_ptr& e) { + if (owned.count(e.get()) > 0) { + exprs_.erase(e.get()); + return true; } - } + return false; + }); per_fusion_exprs_.erase(exprs_it); } } diff --git a/csrc/statement_guard.cpp b/csrc/statement_guard.cpp index 4575bb59076..15a3b4159c3 100644 --- a/csrc/statement_guard.cpp +++ b/csrc/statement_guard.cpp @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion) return fusion; }()), prev_num_exprs_(fusion_->numExprs()), - prev_num_vals_(fusion_->numVals()) {} + prev_num_vals_(fusion_->numValsExcludingShortcuts()) {} StatementGuard::~StatementGuard() { fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);