Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 25 additions & 27 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {

// After swapping container contents, update Statement::ir_container_
// pointers so each Statement points to the Fusion whose container now
// holds it.
// holds it. Also fix per-Fusion tracking keys since a's container had
// b's entries and vice versa.
a.ir_container()->transferStatementOwnership(&b, &a);
b.ir_container()->transferStatementOwnership(&a, &b);

if (a.ir_container_) {
for (auto val : a.vals()) {
val->ir_container_ = &a;
Expand Down Expand Up @@ -295,10 +299,9 @@ void Fusion::clear() noexcept {
// constructor of Trace, which could throw an exception.
// FUSER_PERF_SCOPE("Fusion clear");

// Clear container contents instead of destroying it
// This preserves the container object so Statement pointers don't become
// dangling
ir_container()->clear();
if (ir_container_) {
ir_container_->removeStatementsOwnedBy(this);
}

inputs_.clear();
outputs_.clear();
Expand All @@ -308,8 +311,6 @@ void Fusion::clear() noexcept {
managed_data_.clear();
managed_named_data_.clear();

// Reset per-Fusion special value caches (the vals themselves are owned by
// ir_container and were already destroyed by ir_container()->clear() above).
zero_val_ = nullptr;
one_val_ = nullptr;
true_val_ = nullptr;
Expand Down Expand Up @@ -353,6 +354,7 @@ void Fusion::removeExpr(Expr* expr) {
NVF_ERROR(
expr_in_deque != c->exprs_up_.end(),
"Wanted to remove an expression but its unique ptr is missing.");
c->per_fusion_exprs_[this].erase(expr);
c->exprs_.erase(expr);
c->exprs_up_.erase(expr_in_deque);
}
Expand Down Expand Up @@ -412,6 +414,7 @@ void Fusion::removeVal(Val* val) {
NVF_ERROR(
val_in_deque != c->vals_up_.end(),
"Wanted to remove a value but its unique ptr is missing.");
c->per_fusion_vals_[this].erase(val);
c->vals_.erase(val);
c->vals_up_.erase(val_in_deque);

Expand All @@ -423,36 +426,28 @@ void Fusion::removeStatementsCreatedAfter(
int64_t num_vals_before) {
auto* c = ir_container();

NVF_ERROR(
c->exprs_up_.size() == c->exprs_.size(),
"exprs_up_ (size ",
c->exprs_up_.size(),
") and exprs_ (size ",
c->exprs_.size(),
") are out of sync.");
NVF_ERROR(
std::ssize(c->exprs_up_) >= num_exprs_before,
"exprs_up_ size (",
std::ssize(c->exprs_up_),
") is less than num_exprs_before (",
num_exprs_before,
").");

// Remove expressions before values because we need to change Val::uses_.
while (std::ssize(c->exprs_up_) > num_exprs_before) {
while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) {
// Pop from global deque back — statements created by this Fusion during
// the guard scope are at the tail (LIFO invariant).
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[this].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
for (Val* in : e->inputs()) {
in->removeUse(e);
}
c->per_fusion_exprs_[this].erase(e);
c->exprs_.erase(e);
c->exprs_up_.pop_back();
}

// Null out any special value caches that point to vals about to be destroyed.
// This prevents dangling pointers when special vals are lazily created inside
// a StatementGuard scope.
while (std::ssize(c->vals_up_) > num_vals_before) {
while (numValsExcludingShortcuts() > num_vals_before) {
Val* v = c->vals_up_.back().get();
NVF_ERROR(
c->per_fusion_vals_[this].count(v) > 0,
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
// Null out shortcut caches if they point to vals about to be destroyed
if (v == zero_val_) {
zero_val_ = nullptr;
} else if (v == one_val_) {
Expand All @@ -464,6 +459,7 @@ void Fusion::removeStatementsCreatedAfter(
} else if (v == magic_zero_val_) {
magic_zero_val_ = nullptr;
}
c->per_fusion_vals_[this].erase(v);
c->vals_.erase(v);
c->vals_up_.pop_back();
}
Expand Down Expand Up @@ -927,6 +923,7 @@ void Fusion::registerVal(Val* val) {
auto* c = ir_container();
c->vals_up_.emplace_back(val);
c->vals_.insert(val);
c->per_fusion_vals_[this].insert(val);
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
}

Expand All @@ -943,6 +940,7 @@ void Fusion::registerExpr(Expr* expr) {
auto* c = ir_container();
c->exprs_up_.emplace_back(expr);
c->exprs_.insert(expr);
c->per_fusion_exprs_[this].insert(expr);
expr->setName(IrContainerPasskey(), c->getExprName());

for (Val* input : expr->inputs()) {
Expand Down
41 changes: 26 additions & 15 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,40 +522,51 @@ class NVF_API Fusion : public PolymorphicBase {
}

// Collections access (return values in insertion order)
const std::deque<Val*> deterministic_vals() const noexcept {
return ir_container()->deterministic_vals();
std::deque<Val*> deterministic_vals() const noexcept {
return ir_container()->deterministicValsOwnedBy(this);
}

const std::deque<Expr*> deterministic_exprs() const noexcept {
return ir_container()->deterministic_exprs();
std::deque<Expr*> deterministic_exprs() const noexcept {
return ir_container()->deterministicExprsOwnedBy(this);
}

const std::unordered_map<Val*, int64_t> deterministic_vals_map()
const noexcept {
return ir_container()->deterministic_vals_map();
std::unordered_map<Val*, int64_t> deterministic_vals_map() const noexcept {
return ir_container()->deterministicValsMapOwnedBy(this);
}

const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
const noexcept {
return ir_container()->deterministic_exprs_map();
std::unordered_map<Expr*, int64_t> deterministic_exprs_map() const noexcept {
return ir_container()->deterministicExprsMapOwnedBy(this);
}

// Collections access (unordered sets)
const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
return ir_container()->unordered_exprs();
return ir_container()->exprsOwnedBy(this);
}

const std::unordered_set<Val*>& vals() const noexcept {
return ir_container()->vals();
return ir_container()->valsOwnedBy(this);
}

// Count queries
// Count queries (per-Fusion: only counts statements owned by this Fusion)
int64_t numExprs() const noexcept {
return ir_container()->numExprs();
return std::ssize(ir_container()->exprsOwnedBy(this));
}

int64_t numVals() const noexcept {
return ir_container()->numVals();
return std::ssize(ir_container()->valsOwnedBy(this));
}

//! Number of Vals tracked for this Fusion in its container, minus one for
//! every shortcut cache pointer (zero_val_, one_val_, true_val_, false_val_,
//! magic_zero_val_) that is currently non-null.
//!
//! StatementGuard snapshots this value and removeStatementsCreatedAfter uses
//! it as its loop bound, so lazily creating a shortcut val inside a guard
//! scope does not change the count being compared.
//! NOTE(review): assumes every non-null shortcut pointer is also present in
//! valsOwnedBy(this) -- confirm against where the shortcuts are registered.
int64_t numValsExcludingShortcuts() const noexcept {
  const bool shortcut_cached[] = {
      zero_val_ != nullptr,
      one_val_ != nullptr,
      true_val_ != nullptr,
      false_val_ != nullptr,
      magic_zero_val_ != nullptr};

  int64_t cached_shortcuts = 0;
  for (const bool is_cached : shortcut_cached) {
    if (is_cached) {
      ++cached_shortcuts;
    }
  }

  const int64_t owned_vals = std::ssize(ir_container()->valsOwnedBy(this));
  return owned_vals - cached_shortcuts;
}

// Shortcut values (frequently used constants)
Expand Down
131 changes: 131 additions & 0 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {

std::swap(a.val_type_name_map_, b.val_type_name_map_);
std::swap(a.expr_name_counter_, b.expr_name_counter_);

std::swap(a.per_fusion_vals_, b.per_fusion_vals_);
std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_);
}

IrCloner IrContainer::copy(
Expand Down Expand Up @@ -119,6 +122,8 @@ void IrContainer::clear() noexcept {
exprs_up_.clear();
val_type_name_map_.clear();
expr_name_counter_ = 0;
per_fusion_vals_.clear();
per_fusion_exprs_.clear();
}

bool IrContainer::inContainer(const Statement* const_stmt) const {
Expand Down Expand Up @@ -178,4 +183,130 @@ const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
return sharing_fusions_;
}

const std::unordered_set<Val*>& IrContainer::valsOwnedBy(
const Fusion* fusion) const {
static const std::unordered_set<Val*> empty;
auto it = per_fusion_vals_.find(fusion);
return it != per_fusion_vals_.end() ? it->second : empty;
}

const std::unordered_set<Expr*>& IrContainer::exprsOwnedBy(
const Fusion* fusion) const {
static const std::unordered_set<Expr*> empty;
auto it = per_fusion_exprs_.find(fusion);
return it != per_fusion_exprs_.end() ? it->second : empty;
}

// Re-keys the per-Fusion tracking sets: every Val and Expr currently tracked
// under `from` becomes tracked under `to`, and the `from` entries are removed.
// Statements already tracked under `to` are kept (the sets are unioned).
//
// Only the tracking maps are touched; the statements themselves and the
// vals_/exprs_ containers are not modified here.
//
// `from == to` is a no-op. Without that guard the entry would be merged into
// itself and then erased, dropping all tracking for that Fusion.
void IrContainer::transferStatementOwnership(
    const Fusion* from,
    const Fusion* to) {
  if (from == to) {
    return;
  }

  // Detach the source entry *before* touching `to`: operator[] may insert a
  // new element and rehash, which invalidates any iterator obtained from an
  // earlier find(). Working on an extracted node avoids holding an iterator
  // across the insertion.
  if (auto vals_node = per_fusion_vals_.extract(from); !vals_node.empty()) {
    // merge() splices the nodes over; anything already present in the target
    // stays behind in the node's set and is released with it.
    per_fusion_vals_[to].merge(vals_node.mapped());
  }

  if (auto exprs_node = per_fusion_exprs_.extract(from);
      !exprs_node.empty()) {
    per_fusion_exprs_[to].merge(exprs_node.mapped());
  }
}

void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
auto vals_it = per_fusion_vals_.find(fusion);
if (vals_it != per_fusion_vals_.end()) {
const auto& owned = vals_it->second;
std::erase_if(vals_up_, [&](const std::unique_ptr<Val>& v) {
if (owned.count(v.get()) > 0) {
vals_.erase(v.get());
return true;
}
return false;
});
per_fusion_vals_.erase(vals_it);
}

auto exprs_it = per_fusion_exprs_.find(fusion);
if (exprs_it != per_fusion_exprs_.end()) {
const auto& owned = exprs_it->second;
std::erase_if(exprs_up_, [&](const std::unique_ptr<Expr>& e) {
if (owned.count(e.get()) > 0) {
exprs_.erase(e.get());
return true;
}
return false;
});
per_fusion_exprs_.erase(exprs_it);
}
}

std::deque<Val*> IrContainer::deterministicValsOwnedBy(
const Fusion* fusion) const noexcept {
std::deque<Val*> result;
auto it = per_fusion_vals_.find(fusion);
if (it == per_fusion_vals_.end()) {
return result;
}
const auto& owned = it->second;
for (const auto& val_up : vals_up_) {
if (owned.count(val_up.get()) > 0) {
result.push_back(val_up.get());
}
}
return result;
}

std::deque<Expr*> IrContainer::deterministicExprsOwnedBy(
const Fusion* fusion) const noexcept {
std::deque<Expr*> result;
auto it = per_fusion_exprs_.find(fusion);
if (it == per_fusion_exprs_.end()) {
return result;
}
const auto& owned = it->second;
for (const auto& expr_up : exprs_up_) {
if (owned.count(expr_up.get()) > 0) {
result.push_back(expr_up.get());
}
}
return result;
}

std::unordered_map<Val*, int64_t> IrContainer::deterministicValsMapOwnedBy(
const Fusion* fusion) const noexcept {
std::unordered_map<Val*, int64_t> result;
auto it = per_fusion_vals_.find(fusion);
if (it == per_fusion_vals_.end()) {
return result;
}
const auto& owned = it->second;
int64_t count = 0;
for (const auto& val_up : vals_up_) {
if (owned.count(val_up.get()) > 0) {
result[val_up.get()] = count++;
}
}
return result;
}

std::unordered_map<Expr*, int64_t> IrContainer::deterministicExprsMapOwnedBy(
const Fusion* fusion) const noexcept {
std::unordered_map<Expr*, int64_t> result;
auto it = per_fusion_exprs_.find(fusion);
if (it == per_fusion_exprs_.end()) {
return result;
}
const auto& owned = it->second;
int64_t count = 0;
for (const auto& expr_up : exprs_up_) {
if (owned.count(expr_up.get()) > 0) {
result[expr_up.get()] = count++;
}
}
return result;
}

} // namespace nvfuser
18 changes: 18 additions & 0 deletions csrc/ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,26 @@ class IrContainer {
bool hasMultipleFusions() const;
const std::unordered_set<Fusion*>& sharingFusions() const;

// Per-Fusion statement tracking. The container records, for each Fusion,
// which Vals and Exprs are tracked as belonging to it (per_fusion_vals_ /
// per_fusion_exprs_).

// Returns the Vals tracked for `fusion`, or a reference to a shared empty set
// when nothing is tracked for it.
NVF_API const std::unordered_set<Val*>& valsOwnedBy(
const Fusion* fusion) const;
// Returns the Exprs tracked for `fusion`, or a reference to a shared empty
// set when nothing is tracked for it.
const std::unordered_set<Expr*>& exprsOwnedBy(const Fusion* fusion) const;
// Merges the Val and Expr tracking entries keyed by `from` into those keyed
// by `to`, then removes the `from` entries. Only the tracking maps change.
void transferStatementOwnership(const Fusion* from, const Fusion* to);
// Erases every statement tracked for `fusion` from vals_up_/vals_ and
// exprs_up_/exprs_, then drops that Fusion's tracking entries.
void removeStatementsOwnedBy(const Fusion* fusion);

// Vals tracked for `fusion`, in vals_up_ order. Returned by value.
std::deque<Val*> deterministicValsOwnedBy(
const Fusion* fusion) const noexcept;
// Exprs tracked for `fusion`, in exprs_up_ order. Returned by value.
std::deque<Expr*> deterministicExprsOwnedBy(
const Fusion* fusion) const noexcept;
// Maps each Val tracked for `fusion` to its zero-based position among that
// Fusion's Vals, following vals_up_ order.
std::unordered_map<Val*, int64_t> deterministicValsMapOwnedBy(
const Fusion* fusion) const noexcept;
// Maps each Expr tracked for `fusion` to its zero-based position among that
// Fusion's Exprs, following exprs_up_ order.
std::unordered_map<Expr*, int64_t> deterministicExprsMapOwnedBy(
const Fusion* fusion) const noexcept;

private:
std::unordered_set<Fusion*> sharing_fusions_;
std::unordered_map<const Fusion*, std::unordered_set<Val*>> per_fusion_vals_;
std::unordered_map<const Fusion*, std::unordered_set<Expr*>>
per_fusion_exprs_;
};

} // namespace nvfuser
2 changes: 1 addition & 1 deletion csrc/statement_guard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion)
return fusion;
}()),
prev_num_exprs_(fusion_->numExprs()),
prev_num_vals_(fusion_->numVals()) {}
prev_num_vals_(fusion_->numValsExcludingShortcuts()) {}

StatementGuard::~StatementGuard() {
fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);
Expand Down