Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,24 +427,14 @@ void Fusion::removeStatementsCreatedAfter(
int64_t num_vals_before) {
auto* c = ir_container();

NVF_ERROR(
c->exprs_up_.size() == c->exprs_.size(),
"exprs_up_ (size ",
c->exprs_up_.size(),
") and exprs_ (size ",
c->exprs_.size(),
") are out of sync.");
NVF_ERROR(
std::ssize(c->exprs_up_) >= num_exprs_before,
"exprs_up_ size (",
std::ssize(c->exprs_up_),
") is less than num_exprs_before (",
num_exprs_before,
").");

// Remove expressions before values because we need to change Val::uses_.
while (std::ssize(c->exprs_up_) > num_exprs_before) {
while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) {
// Pop from global deque back — statements created by this Fusion during
// the guard scope are at the tail (LIFO invariant).
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[this].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
for (Val* in : e->inputs()) {
in->removeUse(e);
}
Expand All @@ -453,8 +443,12 @@ void Fusion::removeStatementsCreatedAfter(
c->exprs_up_.pop_back();
}

while (std::ssize(c->vals_up_) > num_vals_before) {
while (numValsExcludingShortcuts() > num_vals_before) {
Val* v = c->vals_up_.back().get();
NVF_ERROR(
c->per_fusion_vals_[this].count(v) > 0,
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
// Null out shortcut caches if they point to vals about to be destroyed
if (v == zero_val_) {
zero_val_ = nullptr;
} else if (v == one_val_) {
Expand Down
19 changes: 16 additions & 3 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,13 +546,26 @@ class NVF_API Fusion : public PolymorphicBase {
return ir_container()->valsOwnedBy(this);
}

// Count queries
// Count queries (per-Fusion: only counts statements owned by this Fusion)
int64_t numExprs() const noexcept {
return ir_container()->numExprs();
return std::ssize(ir_container()->exprsOwnedBy(this));
}

int64_t numVals() const noexcept {
return ir_container()->numVals();
return std::ssize(ir_container()->valsOwnedBy(this));
}

//! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.).
//! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but
//! since they're singletons that should persist across StatementGuard scopes,
//! this count excludes them so the LIFO pop-back in
//! removeStatementsCreatedAfter correctly skips over them.
int64_t numValsExcludingShortcuts() const noexcept {
int64_t count = std::ssize(ir_container()->valsOwnedBy(this));
count -= (zero_val_ != nullptr) + (one_val_ != nullptr) +
(true_val_ != nullptr) + (false_val_ != nullptr) +
(magic_zero_val_ != nullptr);
return count;
}

// Shortcut values (frequently used constants)
Expand Down
28 changes: 14 additions & 14 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,27 @@ void IrContainer::transferStatementOwnership(
void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
auto vals_it = per_fusion_vals_.find(fusion);
if (vals_it != per_fusion_vals_.end()) {
for (auto it = vals_up_.begin(); it != vals_up_.end();) {
if (vals_it->second.count(it->get()) > 0) {
vals_.erase(it->get());
it = vals_up_.erase(it);
} else {
++it;
const auto& owned = vals_it->second;
std::erase_if(vals_up_, [&](const std::unique_ptr<Val>& v) {
if (owned.count(v.get()) > 0) {
vals_.erase(v.get());
return true;
}
}
return false;
});
per_fusion_vals_.erase(vals_it);
}

auto exprs_it = per_fusion_exprs_.find(fusion);
if (exprs_it != per_fusion_exprs_.end()) {
for (auto it = exprs_up_.begin(); it != exprs_up_.end();) {
if (exprs_it->second.count(it->get()) > 0) {
exprs_.erase(it->get());
it = exprs_up_.erase(it);
} else {
++it;
const auto& owned = exprs_it->second;
std::erase_if(exprs_up_, [&](const std::unique_ptr<Expr>& e) {
if (owned.count(e.get()) > 0) {
exprs_.erase(e.get());
return true;
}
}
return false;
});
per_fusion_exprs_.erase(exprs_it);
}
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/statement_guard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion)
return fusion;
}()),
prev_num_exprs_(fusion_->numExprs()),
prev_num_vals_(fusion_->numVals()) {}
prev_num_vals_(fusion_->numValsExcludingShortcuts()) {}

StatementGuard::~StatementGuard() {
fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);
Expand Down
Loading