diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 01f1c300f09..fb7a6a72b14 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -114,7 +114,17 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { // After swapping container contents, update Statement::ir_container_ // pointers so each Statement points to the Fusion whose container now - // holds it. + // holds it. Also fix per-Fusion tracking keys since a's container had + // b's entries and vice versa. + // NOTE(review): assumes a and b hold distinct containers. On a shared + // container these two transfers would merge both sets under &b -- confirm. + if (a.ir_container_) { + a.ir_container_->transferStatementOwnership(&b, &a); + } + if (b.ir_container_) { + b.ir_container_->transferStatementOwnership(&a, &b); + } + if (a.ir_container_) { for (auto val : a.vals()) { val->ir_container_ = &a; @@ -295,10 +305,9 @@ void Fusion::clear() noexcept { // constructor of Trace, which could throw an exception. // FUSER_PERF_SCOPE("Fusion clear"); - // Clear container contents instead of destroying it - // This preserves the container object so Statement pointers don't become - // dangling - ir_container()->clear(); + if (ir_container_) { + ir_container_->removeStatementsOwnedBy(this); + } inputs_.clear(); outputs_.clear(); @@ -308,8 +317,6 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); - // Reset per-Fusion special value caches (the vals themselves are owned by - // ir_container and were already destroyed by ir_container()->clear() above). 
zero_val_ = nullptr; one_val_ = nullptr; true_val_ = nullptr; @@ -353,6 +360,7 @@ void Fusion::removeExpr(Expr* expr) { NVF_ERROR( expr_in_deque != c->exprs_up_.end(), "Wanted to remove an expression but its unique ptr is missing."); + c->per_fusion_exprs_[this].erase(expr); c->exprs_.erase(expr); c->exprs_up_.erase(expr_in_deque); } @@ -412,6 +420,7 @@ void Fusion::removeVal(Val* val) { NVF_ERROR( val_in_deque != c->vals_up_.end(), "Wanted to remove a value but its unique ptr is missing."); + c->per_fusion_vals_[this].erase(val); c->vals_.erase(val); c->vals_up_.erase(val_in_deque); @@ -423,36 +432,28 @@ void Fusion::removeStatementsCreatedAfter( int64_t num_vals_before) { auto* c = ir_container(); - NVF_ERROR( - c->exprs_up_.size() == c->exprs_.size(), - "exprs_up_ (size ", - c->exprs_up_.size(), - ") and exprs_ (size ", - c->exprs_.size(), - ") are out of sync."); - NVF_ERROR( - std::ssize(c->exprs_up_) >= num_exprs_before, - "exprs_up_ size (", - std::ssize(c->exprs_up_), - ") is less than num_exprs_before (", - num_exprs_before, - ")."); - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(c->exprs_up_) > num_exprs_before) { + while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) { + // Pop from global deque back — statements created by this Fusion during + // the guard scope are at the tail (LIFO invariant). Expr* e = c->exprs_up_.back().get(); + NVF_ERROR( + c->per_fusion_exprs_[this].count(e) > 0, + "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); for (Val* in : e->inputs()) { in->removeUse(e); } + c->per_fusion_exprs_[this].erase(e); c->exprs_.erase(e); c->exprs_up_.pop_back(); } - // Null out any special value caches that point to vals about to be destroyed. - // This prevents dangling pointers when special vals are lazily created inside - // a StatementGuard scope. 
- while (std::ssize(c->vals_up_) > num_vals_before) { + while (numValsExcludingShortcuts() > num_vals_before) { Val* v = c->vals_up_.back().get(); + NVF_ERROR( + c->per_fusion_vals_[this].count(v) > 0, + "removeStatementsCreatedAfter: tail val belongs to another Fusion"); + // Null out shortcut caches if they point to vals about to be destroyed if (v == zero_val_) { zero_val_ = nullptr; } else if (v == one_val_) { @@ -464,6 +465,7 @@ void Fusion::removeStatementsCreatedAfter( } else if (v == magic_zero_val_) { magic_zero_val_ = nullptr; } + c->per_fusion_vals_[this].erase(v); c->vals_.erase(v); c->vals_up_.pop_back(); } @@ -927,6 +929,7 @@ void Fusion::registerVal(Val* val) { auto* c = ir_container(); c->vals_up_.emplace_back(val); c->vals_.insert(val); + c->per_fusion_vals_[this].insert(val); val->setName(IrContainerPasskey(), c->getValName(val->vtype())); } @@ -943,6 +946,7 @@ void Fusion::registerExpr(Expr* expr) { auto* c = ir_container(); c->exprs_up_.emplace_back(expr); c->exprs_.insert(expr); + c->per_fusion_exprs_[this].insert(expr); expr->setName(IrContainerPasskey(), c->getExprName()); for (Val* input : expr->inputs()) { diff --git a/csrc/fusion.h b/csrc/fusion.h index 6ba7a3788df..e44817405e7 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -522,40 +522,51 @@ class NVF_API Fusion : public PolymorphicBase { } // Collections access (return values in insertion order) - const std::deque deterministic_vals() const noexcept { - return ir_container()->deterministic_vals(); + std::deque deterministic_vals() const noexcept { + return ir_container()->deterministicValsOwnedBy(this); } - const std::deque deterministic_exprs() const noexcept { - return ir_container()->deterministic_exprs(); + std::deque deterministic_exprs() const noexcept { + return ir_container()->deterministicExprsOwnedBy(this); } - const std::unordered_map deterministic_vals_map() - const noexcept { - return ir_container()->deterministic_vals_map(); + std::unordered_map deterministic_vals_map() 
const noexcept { + return ir_container()->deterministicValsMapOwnedBy(this); } - const std::unordered_map deterministic_exprs_map() - const noexcept { - return ir_container()->deterministic_exprs_map(); + std::unordered_map deterministic_exprs_map() const noexcept { + return ir_container()->deterministicExprsMapOwnedBy(this); } // Collections access (unordered sets) const std::unordered_set& unordered_exprs() const noexcept { - return ir_container()->unordered_exprs(); + return ir_container()->exprsOwnedBy(this); } const std::unordered_set& vals() const noexcept { - return ir_container()->vals(); + return ir_container()->valsOwnedBy(this); } - // Count queries + // Count queries (per-Fusion: only counts statements owned by this Fusion) int64_t numExprs() const noexcept { - return ir_container()->numExprs(); + return std::ssize(ir_container()->exprsOwnedBy(this)); } int64_t numVals() const noexcept { - return ir_container()->numVals(); + return std::ssize(ir_container()->valsOwnedBy(this)); + } + + //! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.). + //! Shortcut vals are registered in both per_fusion_vals_ and vals_up_; + //! excluding them keeps the count unaffected by shortcut vals that are lazily + //! created inside a StatementGuard scope. removeStatementsCreatedAfter still + //! pops any shortcut val it finds at the tail and nulls the matching cache. 
+ int64_t numValsExcludingShortcuts() const noexcept { + int64_t count = std::ssize(ir_container()->valsOwnedBy(this)); + count -= (zero_val_ != nullptr) + (one_val_ != nullptr) + + (true_val_ != nullptr) + (false_val_ != nullptr) + + (magic_zero_val_ != nullptr); + return count; } // Shortcut values (frequently used constants) diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index f0537714ce0..5c0b2e0f5a6 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -75,6 +75,9 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); + + std::swap(a.per_fusion_vals_, b.per_fusion_vals_); + std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_); } IrCloner IrContainer::copy( @@ -119,6 +122,8 @@ void IrContainer::clear() noexcept { exprs_up_.clear(); val_type_name_map_.clear(); expr_name_counter_ = 0; + per_fusion_vals_.clear(); + per_fusion_exprs_.clear(); } bool IrContainer::inContainer(const Statement* const_stmt) const { @@ -178,4 +183,132 @@ const std::unordered_set& IrContainer::sharingFusions() const { return sharing_fusions_; } +const std::unordered_set& IrContainer::valsOwnedBy( + const Fusion* fusion) const { + static const std::unordered_set empty; + auto it = per_fusion_vals_.find(fusion); + return it != per_fusion_vals_.end() ? it->second : empty; +} + +const std::unordered_set& IrContainer::exprsOwnedBy( + const Fusion* fusion) const { + static const std::unordered_set empty; + auto it = per_fusion_exprs_.find(fusion); + return it != per_fusion_exprs_.end() ? 
it->second : empty; +} + +void IrContainer::transferStatementOwnership( + const Fusion* from, + const Fusion* to) { + // Move each source set out and erase its entry before touching [to]: + // operator[] may insert and rehash, which would invalidate the iterator. + auto vals_it = per_fusion_vals_.find(from); + if (vals_it != per_fusion_vals_.end()) { + auto from_vals = std::move(vals_it->second); + per_fusion_vals_.erase(vals_it); + per_fusion_vals_[to].merge(from_vals); + } + + auto exprs_it = per_fusion_exprs_.find(from); + if (exprs_it != per_fusion_exprs_.end()) { + auto from_exprs = std::move(exprs_it->second); + per_fusion_exprs_.erase(exprs_it); + per_fusion_exprs_[to].merge(from_exprs); + } +} + +void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { + auto vals_it = per_fusion_vals_.find(fusion); + if (vals_it != per_fusion_vals_.end()) { + const auto& owned = vals_it->second; + std::erase_if(vals_up_, [&](const std::unique_ptr& v) { + if (owned.count(v.get()) > 0) { + vals_.erase(v.get()); + return true; + } + return false; + }); + per_fusion_vals_.erase(vals_it); + } + + auto exprs_it = per_fusion_exprs_.find(fusion); + if (exprs_it != per_fusion_exprs_.end()) { + const auto& owned = exprs_it->second; + std::erase_if(exprs_up_, [&](const std::unique_ptr& e) { + if (owned.count(e.get()) > 0) { + exprs_.erase(e.get()); + return true; + } + return false; + }); + per_fusion_exprs_.erase(exprs_it); + } +} + +std::deque IrContainer::deterministicValsOwnedBy( + const Fusion* fusion) const noexcept { + std::deque result; + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; + } + const auto& owned = it->second; + for (const auto& val_up : vals_up_) { + if (owned.count(val_up.get()) > 0) { + result.push_back(val_up.get()); + } + } + return result; +} + +std::deque IrContainer::deterministicExprsOwnedBy( + const Fusion* fusion) const noexcept { + std::deque result; + auto it = per_fusion_exprs_.find(fusion); + if (it 
== per_fusion_exprs_.end()) { + return result; + } + const auto& owned = it->second; + for (const auto& 
expr_up : exprs_up_) { + if (owned.count(expr_up.get()) > 0) { + result.push_back(expr_up.get()); + } + } + return result; +} + +std::unordered_map IrContainer::deterministicValsMapOwnedBy( + const Fusion* fusion) const noexcept { + std::unordered_map result; + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; + } + const auto& owned = it->second; + int64_t count = 0; + for (const auto& val_up : vals_up_) { + if (owned.count(val_up.get()) > 0) { + result[val_up.get()] = count++; + } + } + return result; +} + +std::unordered_map IrContainer::deterministicExprsMapOwnedBy( + const Fusion* fusion) const noexcept { + std::unordered_map result; + auto it = per_fusion_exprs_.find(fusion); + if (it == per_fusion_exprs_.end()) { + return result; + } + const auto& owned = it->second; + int64_t count = 0; + for (const auto& expr_up : exprs_up_) { + if (owned.count(expr_up.get()) > 0) { + result[expr_up.get()] = count++; + } + } + return result; +} + } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index e3738be7349..ed0b4504840 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -137,8 +137,26 @@ class IrContainer { bool hasMultipleFusions() const; const std::unordered_set& sharingFusions() const; + NVF_API const std::unordered_set& valsOwnedBy( + const Fusion* fusion) const; + const std::unordered_set& exprsOwnedBy(const Fusion* fusion) const; + void transferStatementOwnership(const Fusion* from, const Fusion* to); + void removeStatementsOwnedBy(const Fusion* fusion); + + std::deque deterministicValsOwnedBy( + const Fusion* fusion) const noexcept; + std::deque deterministicExprsOwnedBy( + const Fusion* fusion) const noexcept; + std::unordered_map deterministicValsMapOwnedBy( + const Fusion* fusion) const noexcept; + std::unordered_map deterministicExprsMapOwnedBy( + const Fusion* fusion) const noexcept; + private: std::unordered_set sharing_fusions_; + std::unordered_map> per_fusion_vals_; + 
std::unordered_map> + per_fusion_exprs_; }; } // namespace nvfuser diff --git a/csrc/statement_guard.cpp b/csrc/statement_guard.cpp index 4575bb59076..15a3b4159c3 100644 --- a/csrc/statement_guard.cpp +++ b/csrc/statement_guard.cpp @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion) return fusion; }()), prev_num_exprs_(fusion_->numExprs()), - prev_num_vals_(fusion_->numVals()) {} + prev_num_vals_(fusion_->numValsExcludingShortcuts()) {} StatementGuard::~StatementGuard() { fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);