Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 95 additions & 50 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,19 @@ struct Fusion::ContainerMutator {
}
}

static int64_t numValsExcludingShortcuts(const Fusion* self) noexcept {
auto* c = self->ir_container();
// Use direct field access. Avoids re-entering valsOwnedBy() which acquires
// shared_lock.
const auto it = c->per_fusion_vals_.find(self);
int64_t count = it != c->per_fusion_vals_.end()
? static_cast<int64_t>(it->second.size())
: 0;
count -= (self->zero_val_ != nullptr) + (self->one_val_ != nullptr) +
(self->true_val_ != nullptr) + (self->false_val_ != nullptr) +
(self->magic_zero_val_ != nullptr);
return count;
// Null out self's shortcut-val pointer cache if v is one of them.
static void nullOutShortcutIfNeeded(Fusion* self, Val* v) {
if (v == self->zero_val_) {
self->zero_val_ = nullptr;
} else if (v == self->one_val_) {
self->one_val_ = nullptr;
} else if (v == self->true_val_) {
self->true_val_ = nullptr;
} else if (v == self->false_val_) {
self->false_val_ = nullptr;
} else if (v == self->magic_zero_val_) {
self->magic_zero_val_ = nullptr;
}
}

static void removeStatementsCreatedAfter(
Expand All @@ -201,42 +202,84 @@ struct Fusion::ContainerMutator {
int64_t num_vals_before) {
auto* c = self->ir_container();

// Remove expressions before values because we need to change Val::uses_.
while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) {
// Pop from global deque back — statements created by this Fusion during
// the guard scope are at the tail (LIFO invariant).
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[self].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
for (Val* in : e->inputs()) {
in->removeUse(e);
// Use direct field access — hasMultipleFusions() acquires shared_lock which
// deadlocks since the caller already holds unique_lock on mutex_.
if (c->sharing_fusions_.size() <= 1) {
// Fast path: single Fusion owns this container, so the LIFO invariant
// holds — self's newest statements are always at the global deque tail.
// Remove expressions before values because we need to change Val::uses_.
while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) {
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[self].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another "
"Fusion");
for (Val* out : e->outputs()) {
out->setDefinition(nullptr);
}
for (Val* in : e->inputs()) {
in->removeUse(e);
}
c->per_fusion_exprs_[self].erase(e);
c->exprs_.erase(e);
c->exprs_up_.pop_back();
}
c->per_fusion_exprs_[self].erase(e);
c->exprs_.erase(e);
c->exprs_up_.pop_back();
}

while (numValsExcludingShortcuts(self) > num_vals_before) {
Val* v = c->vals_up_.back().get();
NVF_ERROR(
c->per_fusion_vals_[self].count(v) > 0,
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
// Null out shortcut caches if they point to vals about to be destroyed
if (v == self->zero_val_) {
self->zero_val_ = nullptr;
} else if (v == self->one_val_) {
self->one_val_ = nullptr;
} else if (v == self->true_val_) {
self->true_val_ = nullptr;
} else if (v == self->false_val_) {
self->false_val_ = nullptr;
} else if (v == self->magic_zero_val_) {
self->magic_zero_val_ = nullptr;
while (std::ssize(c->per_fusion_vals_[self]) > num_vals_before) {
Val* v = c->vals_up_.back().get();
NVF_ERROR(
c->per_fusion_vals_[self].count(v) > 0,
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
nullOutShortcutIfNeeded(self, v);
c->per_fusion_vals_[self].erase(v);
c->vals_.erase(v);
c->vals_up_.pop_back();
}
c->per_fusion_vals_[self].erase(v);
c->vals_.erase(v);
c->vals_up_.pop_back();
} else {
// Slow path: shared container — other Fusions' statements may be
// interleaved at the tail of the global deques. Use std::erase_if
// (C++20) to scan forward: skip the first num_before of self's
// statements (old, to keep), then erase the remainder (added during
// the guard scope). Entered whenever the container is shared,
// regardless of success or failure; if no new statements were added
// the scan completes trivially. O(total statements in container).
int64_t exprs_kept = 0;
std::erase_if(c->exprs_up_, [&](const std::unique_ptr<Expr>& e_up) {
Expr* e = e_up.get();
if (c->per_fusion_exprs_[self].count(e) == 0) {
return false; // belongs to another Fusion — keep
}
if (exprs_kept < num_exprs_before) {
++exprs_kept;
return false; // self's old expr — keep
}
// self's new expr — remove (clean up uses and index maps first)
for (Val* out : e->outputs()) {
out->setDefinition(nullptr);
}
for (Val* in : e->inputs()) {
in->removeUse(e);
}
c->per_fusion_exprs_[self].erase(e);
c->exprs_.erase(e);
return true;
});

int64_t vals_kept = 0;
std::erase_if(c->vals_up_, [&](const std::unique_ptr<Val>& v_up) {
Val* v = v_up.get();
if (c->per_fusion_vals_[self].count(v) == 0) {
return false; // belongs to another Fusion — keep
Comment on lines +248 to +271
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

operator[] on outer map can insert spurious empty entries

Both slow-path predicates access the inner per-fusion sets via operator[] on the outer unordered_map:

if (c->per_fusion_exprs_[self].count(e) == 0) { … }
// and later:
if (c->per_fusion_vals_[self].count(v) == 0) { … }

operator[] on an unordered_map performs a key insertion (with a default-constructed value) when the key is absent. The key for self can legitimately be absent here because Fusion::copy()to->clear()removeStatementsOwnedBy(to) erases it, and it is only re-inserted by the first registerExpr/registerVal call. If a StatementGuard is destroyed in the zero-statements case (guard scope added nothing and the fusion was just cleared), the slow path will be entered — c->sharing_fusions_.size() > 1 — and the predicate will call operator[] on a missing key, inserting a spurious empty set.

The behavior is still logically correct (count returns 0, predicate returns false, nothing is erased), but it silently re-populates a key that removeStatementsOwnedBy purposely removed. Use .find() to avoid the side-effect:

auto exprs_it = c->per_fusion_exprs_.find(self);
std::erase_if(c->exprs_up_, [&](const std::unique_ptr<Expr>& e_up) {
    Expr* e = e_up.get();
    if (exprs_it == c->per_fusion_exprs_.end() ||
        exprs_it->second.count(e) == 0) {
        return false; // belongs to another Fusion — keep
    }
    …
});

And similarly for the vals_up_ pass.

}
if (vals_kept < num_vals_before) {
++vals_kept;
return false; // self's old val — keep
}
// self's new val — remove (null shortcut cache pointer if applicable)
nullOutShortcutIfNeeded(self, v);
c->per_fusion_vals_[self].erase(v);
c->vals_.erase(v);
return true;
});
}
}
};
Expand Down Expand Up @@ -520,6 +563,12 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
ir_container_->addFusion(this);
}

// Shared-container constructor -- creates empty Fusion using existing container
Fusion::Fusion(std::shared_ptr<IrContainer> container)
: ir_container_(std::move(container)) {
ir_container_->addFusion(this);
}

// Copy constructor -- shares the source's container
Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) {
FUSER_PERF_SCOPE("Fusion copy");
Expand Down Expand Up @@ -620,10 +669,6 @@ void Fusion::removeStatementsCreatedAfter(
this, num_exprs_before, num_vals_before);
}

int64_t Fusion::numValsExcludingShortcuts() const noexcept {
return ContainerMutator::numValsExcludingShortcuts(this);
}

void Fusion::addInput(Val* input) {
assertInContainer(input, "Cannot register input ");

Expand Down
11 changes: 4 additions & 7 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,13 +565,6 @@ class NVF_API Fusion : public PolymorphicBase {
return std::ssize(ir_container()->valsOwnedBy(this));
}

//! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.).
//! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but
//! since they're singletons that should persist across StatementGuard scopes,
//! this count excludes them so the LIFO pop-back in
//! removeStatementsCreatedAfter correctly skips over them.
int64_t numValsExcludingShortcuts() const noexcept;

// Shortcut values (frequently used constants)
Val* zeroVal();
Val* oneVal();
Expand Down Expand Up @@ -600,6 +593,10 @@ class NVF_API Fusion : public PolymorphicBase {
friend class TranslateApplicableWelford;
friend Val;

//! Constructor that shares an existing container. Creates an empty Fusion
//! registered with the shared container. Used by makeFusion for sharing.
explicit Fusion(std::shared_ptr<IrContainer> container);

//! Register the Val with this fusion
virtual void registerVal(Val* val);

Expand Down
3 changes: 2 additions & 1 deletion csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1801,7 +1801,8 @@ std::pair<IrCloner, std::unique_ptr<Fusion>> SegmentedFusion::makeFusion(
SegmentedGroup* sg) const {
// TODO Optimize cloning step by only copying values and expressions between
// the fusion segment's inputs and outputs.
auto fusion_segment = std::make_unique<Fusion>();
auto fusion_segment =
std::unique_ptr<Fusion>(new Fusion(completeFusion()->ir_container_ptr()));
Comment on lines +1804 to +1805
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raw new without explanatory comment

Using new Fusion(...) directly instead of std::make_unique<Fusion>(...) is the correct approach here — make_unique cannot access protected constructors even through friendship — but nothing in the code explains this to future maintainers. A one-line comment would prevent a well-intentioned "fix" that breaks compilation:

Suggested change
auto fusion_segment =
std::unique_ptr<Fusion>(new Fusion(completeFusion()->ir_container_ptr()));
// The shared-container constructor is protected; use raw new via friendship
// rather than std::make_unique (which cannot access protected ctors).
auto fusion_segment =
std::unique_ptr<Fusion>(new Fusion(completeFusion()->ir_container_ptr()));


IrCloner complete_to_segment_map =
Fusion::copy(completeFusion(), fusion_segment.get());
Expand Down
2 changes: 2 additions & 0 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ int64_t IrContainer::numVals() const noexcept {
void IrContainer::addFusion(Fusion* fusion) {
std::unique_lock lock(mutex_);
sharing_fusions_.insert(fusion);
per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration
per_fusion_exprs_[fusion];
Comment on lines +156 to +157
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-allocation immediately negated by Fusion::copy -> clear()

The comment claims this pre-allocation prevents outer-map rehash during concurrent val/expr registration. However, the call chain in makeFusion immediately undoes this:

new Fusion(container)  →  addFusion()  →  per_fusion_vals_[fusion_segment] = {}  // pre-allocated ✓
Fusion::copy(completeFusion(), fusion_segment.get())
  →  fusion_segment->clear()
     →  ir_container_->removeStatementsOwnedBy(fusion_segment)
        →  per_fusion_vals_.erase(vals_it)   // key REMOVED! pre-allocation lost

After removeStatementsOwnedBy erases the pre-allocated key, the very first registerVal call inside Fusion::copy executes per_fusion_vals_[fusion_segment].insert(val) with operator[] on a missing key, causing a new key insertion into the outer unordered_map. This new insertion can trigger a rehash, which is exactly what the pre-allocation was supposed to prevent.

The race window is: valsOwnedBy() acquires shared_lock, returns a const& into per_fusion_vals_[X], releases the lock, and the caller then calls std::ssize on that reference. Between the lock release and std::ssize, another thread can insert the fusion_segment key (causing rehash), invalidating the reference — UB.

A minimal fix that preserves the pre-allocation invariant is to clear the inner set in removeStatementsOwnedBy instead of erasing the outer key:

// In removeStatementsOwnedBy, instead of:
per_fusion_vals_.erase(vals_it);
// use:
vals_it->second.clear(); // keep the key; prevents re-insertion + rehash

This does require a separate cleanup in removeFusion (which already holds unique_lock) to erase the now-empty key after the fusion is fully removed from sharing_fusions_.

Comment on lines +156 to +157
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overly long inline comment

The inline comment on line 156 is very long (well past 80–100 characters), which will likely fail the project's clang-format or line-length lint checks:

per_fusion_vals_[fusion];   // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration

Consider breaking it into a preceding block comment:

Suggested change
per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration
per_fusion_exprs_[fusion];
// Pre-insert keys so no outer-map rehash occurs during concurrent
// val/expr registration by segment Fusions sharing this container.
per_fusion_vals_[fusion];
per_fusion_exprs_[fusion];

}

void IrContainer::removeFusion(Fusion* fusion) {
Expand Down
2 changes: 1 addition & 1 deletion csrc/statement_guard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion)
return fusion;
}()),
prev_num_exprs_(fusion_->numExprs()),
prev_num_vals_(fusion_->numValsExcludingShortcuts()) {}
prev_num_vals_(fusion_->numVals()) {}

StatementGuard::~StatementGuard() {
fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);
Expand Down
84 changes: 84 additions & 0 deletions tests/cpp/test_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1034,4 +1034,88 @@ TEST_F(SegmentationTest, ReshapeWithCrossSegmentExtent) {
executor_cache.fusion(), outputs, {t_a, t_b}, __LINE__, __FILE__);
}

// Stress test: 8 segments sharing one IrContainer via segmenter container
// sharing. Exercises parallel compilation with 8 concurrent threads all
// registering vals/exprs into the same shared container.
TEST_F(SegmentationTest, SharedContainerStress8Segments) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

// Build a linear chain with 7 segment_set boundaries → 8 segments.
// Each segment does a different pointwise op to keep them distinct.
auto tv0 = makeContigConcreteTensor({1024, 256}, DataType::Float);
fusion->addInput(tv0);

auto tv = relu(tv0);
tv = segment_set(tv);
tv = neg(tv);
tv = segment_set(tv);
tv = sin(tv);
tv = segment_set(tv);
tv = relu(tv);
tv = segment_set(tv);
tv = neg(tv);
tv = segment_set(tv);
tv = sin(tv);
tv = segment_set(tv);
tv = relu(tv);
tv = segment_set(tv);
tv = neg(tv);
fusion->addOutput(tv);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({1024, 256}, options);

FusionExecutorCache executor_cache(std::move(fusion));
auto outputs = executor_cache.runFusionWithInputs({t0});

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(8));

testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
}

// Stress test: 12 parallel branches each producing a segment via independent
// reductions. Unlike the linear chain above, this creates 12 segments that
// can all compile simultaneously, maximizing concurrent lock contention on the
// shared IrContainer.
TEST_F(SegmentationTest, SharedContainerStress12ParallelBranches) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

// 4 inputs, each feeds 3 independent reductions on different axes.
// Reductions on different axes cannot be merged → separate segments.
std::vector<TensorView*> inputs;
for (int i = 0; i < 4; i++) {
auto tv = makeContigConcreteTensor({64, 128, 32}, DataType::Float);
fusion->addInput(tv);
inputs.push_back(tv);
}

for (int i = 0; i < 4; i++) {
// Each input → 3 reductions on axes 0, 1, 2
for (int axis = 0; axis < 3; axis++) {
auto r = sum(inputs[i], {axis});
fusion->addOutput(r);
}
}

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
std::vector<c10::IValue> aten_inputs;
for (int i = 0; i < 4; i++) {
aten_inputs.push_back(at::randn({64, 128, 32}, options));
}

FusionExecutorCache executor_cache(std::move(fusion));
Comment on lines +1107 to +1109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EXPECT_GE(..., 6) assertion may be too strong and cause a false test failure

The comment itself acknowledges uncertainty: "the segmenter may merge some compatible reductions." With 4 inputs × 3 axes, the segmenter can legally merge all axis-0 reductions (from all 4 inputs) into one segment, all axis-1 reductions into another, and all axis-2 reductions into a third — producing only 3 segments, not 6.

If the segmenter is ever improved to merge reductions across inputs (a valid and correct optimisation), this test will EXPECT_GE(3, 6) and fail, even though the shared-container infrastructure is working perfectly. The stress test's correctness goal is already validated by testValidate; the segment-count assertion adds fragility without adding safety.

Consider removing the segment-count assertion entirely, or replacing it with just EXPECT_GT(runtime->fusionSegments()->groups().size(), 1) to verify at least some segmentation occurred (which is what the test is designed to exercise):

Suggested change
}
FusionExecutorCache executor_cache(std::move(fusion));
EXPECT_GT(runtime->fusionSegments()->groups().size(), 1u);

auto outputs = executor_cache.runFusionWithInputs(aten_inputs);

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
// Expect at least 6 segments (the segmenter may merge some compatible
// reductions, but incompatible reduction axes force separate segments)
EXPECT_GE(runtime->fusionSegments()->groups().size(), 6);

testValidate(
executor_cache.fusion(), outputs, aten_inputs, __LINE__, __FILE__);
}

} // namespace nvfuser
Loading