-
Notifications
You must be signed in to change notification settings - Fork 79
[IR Container] Phase 2.7 Segmenter Container Sharing #5983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: md/phase2-thread-safety
Are you sure you want to change the base?
Changes from all commits
216e0d8
0c32f22
4da0378
3d03155
deb0a38
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -181,18 +181,19 @@ struct Fusion::ContainerMutator { | |
| } | ||
| } | ||
|
|
||
| static int64_t numValsExcludingShortcuts(const Fusion* self) noexcept { | ||
| auto* c = self->ir_container(); | ||
| // Use direct field access. Avoids re-entering valsOwnedBy() which acquires | ||
| // shared_lock. | ||
| const auto it = c->per_fusion_vals_.find(self); | ||
| int64_t count = it != c->per_fusion_vals_.end() | ||
| ? static_cast<int64_t>(it->second.size()) | ||
| : 0; | ||
| count -= (self->zero_val_ != nullptr) + (self->one_val_ != nullptr) + | ||
| (self->true_val_ != nullptr) + (self->false_val_ != nullptr) + | ||
| (self->magic_zero_val_ != nullptr); | ||
| return count; | ||
| // Null out self's shortcut-val pointer cache if v is one of them. | ||
| static void nullOutShortcutIfNeeded(Fusion* self, Val* v) { | ||
| if (v == self->zero_val_) { | ||
| self->zero_val_ = nullptr; | ||
| } else if (v == self->one_val_) { | ||
| self->one_val_ = nullptr; | ||
| } else if (v == self->true_val_) { | ||
| self->true_val_ = nullptr; | ||
| } else if (v == self->false_val_) { | ||
| self->false_val_ = nullptr; | ||
| } else if (v == self->magic_zero_val_) { | ||
| self->magic_zero_val_ = nullptr; | ||
| } | ||
| } | ||
|
|
||
| static void removeStatementsCreatedAfter( | ||
|
|
@@ -201,42 +202,84 @@ struct Fusion::ContainerMutator { | |
| int64_t num_vals_before) { | ||
| auto* c = self->ir_container(); | ||
|
|
||
| // Remove expressions before values because we need to change Val::uses_. | ||
| while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) { | ||
| // Pop from global deque back — statements created by this Fusion during | ||
| // the guard scope are at the tail (LIFO invariant). | ||
| Expr* e = c->exprs_up_.back().get(); | ||
| NVF_ERROR( | ||
| c->per_fusion_exprs_[self].count(e) > 0, | ||
| "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); | ||
| for (Val* in : e->inputs()) { | ||
| in->removeUse(e); | ||
| // Use direct field access — hasMultipleFusions() acquires shared_lock which | ||
| // deadlocks since the caller already holds unique_lock on mutex_. | ||
| if (c->sharing_fusions_.size() <= 1) { | ||
| // Fast path: single Fusion owns this container, so the LIFO invariant | ||
| // holds — self's newest statements are always at the global deque tail. | ||
| // Remove expressions before values because we need to change Val::uses_. | ||
| while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) { | ||
| Expr* e = c->exprs_up_.back().get(); | ||
| NVF_ERROR( | ||
| c->per_fusion_exprs_[self].count(e) > 0, | ||
| "removeStatementsCreatedAfter: tail expr belongs to another " | ||
| "Fusion"); | ||
| for (Val* out : e->outputs()) { | ||
| out->setDefinition(nullptr); | ||
| } | ||
| for (Val* in : e->inputs()) { | ||
| in->removeUse(e); | ||
| } | ||
| c->per_fusion_exprs_[self].erase(e); | ||
| c->exprs_.erase(e); | ||
| c->exprs_up_.pop_back(); | ||
| } | ||
| c->per_fusion_exprs_[self].erase(e); | ||
| c->exprs_.erase(e); | ||
| c->exprs_up_.pop_back(); | ||
| } | ||
|
|
||
| while (numValsExcludingShortcuts(self) > num_vals_before) { | ||
| Val* v = c->vals_up_.back().get(); | ||
| NVF_ERROR( | ||
| c->per_fusion_vals_[self].count(v) > 0, | ||
| "removeStatementsCreatedAfter: tail val belongs to another Fusion"); | ||
| // Null out shortcut caches if they point to vals about to be destroyed | ||
| if (v == self->zero_val_) { | ||
| self->zero_val_ = nullptr; | ||
| } else if (v == self->one_val_) { | ||
| self->one_val_ = nullptr; | ||
| } else if (v == self->true_val_) { | ||
| self->true_val_ = nullptr; | ||
| } else if (v == self->false_val_) { | ||
| self->false_val_ = nullptr; | ||
| } else if (v == self->magic_zero_val_) { | ||
| self->magic_zero_val_ = nullptr; | ||
| while (std::ssize(c->per_fusion_vals_[self]) > num_vals_before) { | ||
| Val* v = c->vals_up_.back().get(); | ||
| NVF_ERROR( | ||
| c->per_fusion_vals_[self].count(v) > 0, | ||
| "removeStatementsCreatedAfter: tail val belongs to another Fusion"); | ||
| nullOutShortcutIfNeeded(self, v); | ||
| c->per_fusion_vals_[self].erase(v); | ||
| c->vals_.erase(v); | ||
| c->vals_up_.pop_back(); | ||
| } | ||
| c->per_fusion_vals_[self].erase(v); | ||
| c->vals_.erase(v); | ||
| c->vals_up_.pop_back(); | ||
| } else { | ||
| // Slow path: shared container — other Fusions' statements may be | ||
| // interleaved at the tail of the global deques. Use std::erase_if | ||
| // (C++20) to scan forward: skip the first num_before of self's | ||
| // statements (old, to keep), then erase the remainder (added during | ||
| // the guard scope). Entered whenever the container is shared, | ||
| // regardless of success or failure; if no new statements were added | ||
| // the scan completes trivially. O(total statements in container). | ||
| int64_t exprs_kept = 0; | ||
| std::erase_if(c->exprs_up_, [&](const std::unique_ptr<Expr>& e_up) { | ||
| Expr* e = e_up.get(); | ||
| if (c->per_fusion_exprs_[self].count(e) == 0) { | ||
| return false; // belongs to another Fusion — keep | ||
| } | ||
| if (exprs_kept < num_exprs_before) { | ||
| ++exprs_kept; | ||
| return false; // self's old expr — keep | ||
| } | ||
| // self's new expr — remove (clean up uses and index maps first) | ||
| for (Val* out : e->outputs()) { | ||
| out->setDefinition(nullptr); | ||
| } | ||
| for (Val* in : e->inputs()) { | ||
| in->removeUse(e); | ||
| } | ||
| c->per_fusion_exprs_[self].erase(e); | ||
| c->exprs_.erase(e); | ||
| return true; | ||
| }); | ||
mdavis36 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| int64_t vals_kept = 0; | ||
| std::erase_if(c->vals_up_, [&](const std::unique_ptr<Val>& v_up) { | ||
| Val* v = v_up.get(); | ||
| if (c->per_fusion_vals_[self].count(v) == 0) { | ||
| return false; // belongs to another Fusion — keep | ||
|
Comment on lines
+248
to
+271
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Both slow-path predicates access the inner per-fusion sets via if (c->per_fusion_exprs_[self].count(e) == 0) { … }
// and later:
if (c->per_fusion_vals_[self].count(v) == 0) { … }
The behavior is still logically correct (count returns 0, predicate returns false, nothing is erased), but it silently re-populates a key that auto exprs_it = c->per_fusion_exprs_.find(self);
std::erase_if(c->exprs_up_, [&](const std::unique_ptr<Expr>& e_up) {
Expr* e = e_up.get();
if (exprs_it == c->per_fusion_exprs_.end() ||
exprs_it->second.count(e) == 0) {
return false; // belongs to another Fusion — keep
}
…
});And similarly for the |
||
| } | ||
| if (vals_kept < num_vals_before) { | ||
| ++vals_kept; | ||
| return false; // self's old val — keep | ||
| } | ||
| // self's new val — remove (null shortcut cache pointer if applicable) | ||
| nullOutShortcutIfNeeded(self, v); | ||
| c->per_fusion_vals_[self].erase(v); | ||
| c->vals_.erase(v); | ||
| return true; | ||
| }); | ||
| } | ||
| } | ||
| }; | ||
|
|
@@ -520,6 +563,12 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) { | |
| ir_container_->addFusion(this); | ||
| } | ||
|
|
||
| // Shared-container constructor -- creates empty Fusion using existing container | ||
| Fusion::Fusion(std::shared_ptr<IrContainer> container) | ||
| : ir_container_(std::move(container)) { | ||
| ir_container_->addFusion(this); | ||
| } | ||
|
|
||
| // Copy constructor -- shares the source's container | ||
| Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) { | ||
| FUSER_PERF_SCOPE("Fusion copy"); | ||
|
|
@@ -620,10 +669,6 @@ void Fusion::removeStatementsCreatedAfter( | |
| this, num_exprs_before, num_vals_before); | ||
| } | ||
|
|
||
| int64_t Fusion::numValsExcludingShortcuts() const noexcept { | ||
| return ContainerMutator::numValsExcludingShortcuts(this); | ||
| } | ||
|
|
||
| void Fusion::addInput(Val* input) { | ||
| assertInContainer(input, "Cannot register input "); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1801,7 +1801,8 @@ std::pair<IrCloner, std::unique_ptr<Fusion>> SegmentedFusion::makeFusion( | |||||||||||||
| SegmentedGroup* sg) const { | ||||||||||||||
| // TODO Optimize cloning step by only copying values and expressions between | ||||||||||||||
| // the fusion segment's inputs and outputs. | ||||||||||||||
| auto fusion_segment = std::make_unique<Fusion>(); | ||||||||||||||
| auto fusion_segment = | ||||||||||||||
| std::unique_ptr<Fusion>(new Fusion(completeFusion()->ir_container_ptr())); | ||||||||||||||
|
Comment on lines
+1804
to
+1805
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Raw Using
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| IrCloner complete_to_segment_map = | ||||||||||||||
| Fusion::copy(completeFusion(), fusion_segment.get()); | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -153,6 +153,8 @@ int64_t IrContainer::numVals() const noexcept { | |||||||||||||
| void IrContainer::addFusion(Fusion* fusion) { | ||||||||||||||
| std::unique_lock lock(mutex_); | ||||||||||||||
| sharing_fusions_.insert(fusion); | ||||||||||||||
| per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration | ||||||||||||||
| per_fusion_exprs_[fusion]; | ||||||||||||||
|
Comment on lines
+156
to
+157
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pre-allocation immediately negated by The comment claims this pre-allocation prevents outer-map rehash during concurrent val/expr registration. However, the call chain in After The race window is: A minimal fix that preserves the pre-allocation invariant is to clear the inner set in // In removeStatementsOwnedBy, instead of:
per_fusion_vals_.erase(vals_it);
// use:
vals_it->second.clear(); // keep the key; prevents re-insertion + rehashThis does require a separate cleanup in
Comment on lines
+156
to
+157
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Overly long inline comment The inline comment on line 156 is very long (well past 80–100 characters), which will likely fail the project's clang-format or line-length lint checks: per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registrationConsider breaking it into a preceding block comment:
Suggested change
|
||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| void IrContainer::removeFusion(Fusion* fusion) { | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1034,4 +1034,88 @@ TEST_F(SegmentationTest, ReshapeWithCrossSegmentExtent) { | |||||||||
| executor_cache.fusion(), outputs, {t_a, t_b}, __LINE__, __FILE__); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Stress test: 8 segments sharing one IrContainer via segmenter container | ||||||||||
| // sharing. Exercises parallel compilation with 8 concurrent threads all | ||||||||||
| // registering vals/exprs into the same shared container. | ||||||||||
| TEST_F(SegmentationTest, SharedContainerStress8Segments) { | ||||||||||
| auto fusion = std::make_unique<Fusion>(); | ||||||||||
| FusionGuard fg(fusion.get()); | ||||||||||
|
|
||||||||||
| // Build a linear chain with 7 segment_set boundaries → 8 segments. | ||||||||||
| // Each segment does a different pointwise op to keep them distinct. | ||||||||||
| auto tv0 = makeContigConcreteTensor({1024, 256}, DataType::Float); | ||||||||||
| fusion->addInput(tv0); | ||||||||||
|
|
||||||||||
| auto tv = relu(tv0); | ||||||||||
| tv = segment_set(tv); | ||||||||||
| tv = neg(tv); | ||||||||||
| tv = segment_set(tv); | ||||||||||
| tv = sin(tv); | ||||||||||
| tv = segment_set(tv); | ||||||||||
| tv = relu(tv); | ||||||||||
| tv = segment_set(tv); | ||||||||||
| tv = neg(tv); | ||||||||||
| tv = segment_set(tv); | ||||||||||
| tv = sin(tv); | ||||||||||
| tv = segment_set(tv); | ||||||||||
| tv = relu(tv); | ||||||||||
| tv = segment_set(tv); | ||||||||||
| tv = neg(tv); | ||||||||||
| fusion->addOutput(tv); | ||||||||||
|
|
||||||||||
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); | ||||||||||
| at::Tensor t0 = at::randn({1024, 256}, options); | ||||||||||
|
|
||||||||||
| FusionExecutorCache executor_cache(std::move(fusion)); | ||||||||||
| auto outputs = executor_cache.runFusionWithInputs({t0}); | ||||||||||
|
|
||||||||||
| FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); | ||||||||||
| EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(8)); | ||||||||||
|
|
||||||||||
| testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Stress test: 12 parallel branches each producing a segment via independent | ||||||||||
| // reductions. Unlike the linear chain above, this creates 12 segments that | ||||||||||
| // can all compile simultaneously, maximizing concurrent lock contention on the | ||||||||||
| // shared IrContainer. | ||||||||||
| TEST_F(SegmentationTest, SharedContainerStress12ParallelBranches) { | ||||||||||
| auto fusion = std::make_unique<Fusion>(); | ||||||||||
| FusionGuard fg(fusion.get()); | ||||||||||
|
|
||||||||||
| // 4 inputs, each feeds 3 independent reductions on different axes. | ||||||||||
| // Reductions on different axes cannot be merged → separate segments. | ||||||||||
| std::vector<TensorView*> inputs; | ||||||||||
| for (int i = 0; i < 4; i++) { | ||||||||||
| auto tv = makeContigConcreteTensor({64, 128, 32}, DataType::Float); | ||||||||||
| fusion->addInput(tv); | ||||||||||
| inputs.push_back(tv); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| for (int i = 0; i < 4; i++) { | ||||||||||
| // Each input → 3 reductions on axes 0, 1, 2 | ||||||||||
| for (int axis = 0; axis < 3; axis++) { | ||||||||||
| auto r = sum(inputs[i], {axis}); | ||||||||||
| fusion->addOutput(r); | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); | ||||||||||
| std::vector<c10::IValue> aten_inputs; | ||||||||||
| for (int i = 0; i < 4; i++) { | ||||||||||
| aten_inputs.push_back(at::randn({64, 128, 32}, options)); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| FusionExecutorCache executor_cache(std::move(fusion)); | ||||||||||
|
Comment on lines
+1107
to
+1109
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The comment itself acknowledges uncertainty: "the segmenter may merge some compatible reductions." With 4 inputs × 3 axes, the segmenter can legally merge all axis-0 reductions (from all 4 inputs) into one segment, all axis-1 reductions into another, and all axis-2 reductions into a third — producing only 3 segments, not 6. If the segmenter is ever improved to merge reductions across inputs (a valid and correct optimisation), this test will Consider removing the segment-count assertion entirely, or replacing it with just
Suggested change
|
||||||||||
| auto outputs = executor_cache.runFusionWithInputs(aten_inputs); | ||||||||||
|
|
||||||||||
| FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); | ||||||||||
| // Expect at least 6 segments (the segmenter may merge some compatible | ||||||||||
| // reductions, but incompatible reduction axes force separate segments) | ||||||||||
| EXPECT_GE(runtime->fusionSegments()->groups().size(), 6); | ||||||||||
|
|
||||||||||
| testValidate( | ||||||||||
| executor_cache.fusion(), outputs, aten_inputs, __LINE__, __FILE__); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| } // namespace nvfuser | ||||||||||
Uh oh!
There was an error while loading. Please reload this page.