From 71edbbdf3c6e129a600e4d4c58bdd9c2242e4950 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 9 May 2025 10:37:23 -0700 Subject: [PATCH 1/6] WIP --- csrc/id_model/id_model.cpp | 60 ++++++++++++++++++- .../scheduler/tools/loop_domain_scheduler.cpp | 26 +++++--- csrc/scheduler/tools/loop_domain_scheduler.h | 3 +- tests/cpp/test_id_model.cpp | 57 ++++++++++++++---- 4 files changed, 124 insertions(+), 22 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 48b83347fc4..5ee7be9c1c7 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -12,7 +12,6 @@ #include #include "device_lower/analysis/circular_buffer.h" -#include "device_lower/analysis/trivial_broadcast.h" #include "device_lower/lower2device.h" #include "device_lower/utils.h" #include "disjoint_set.h" @@ -25,7 +24,6 @@ #include "iter_visitor.h" #include "logical_domain_map.h" #include "transform_iter.h" -#include "val_graph_visitor.h" namespace nvfuser { @@ -495,6 +493,62 @@ std::vector> getTriviallyMappedIds(Expr* expr) { return mapped_ids; } +void mapSplitOfSplit(ValGraph& graph) { + // The following is a subpattern of + // https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations + // + // outer, _ = split(root) + // outermost_grand, _ = split(outer) + // outer', _ = split(root) + // + // If outermost_grand and outer' have the same extent, map them. + for (const ValGroup& root : graph.disjointValSets().disjointSets()) { + if (!graph.hasUses(root)) { + continue; + } + const ExprGroups& uses_of_root = graph.getUses(root); + std::vector outermost_grands; + for (const ExprGroup& use_of_root : uses_of_root) { + auto* split0 = dynamic_cast(use_of_root->front()); + if (split0 == nullptr) { + continue; + } + // Only follow the outer output of the first split; outer and inner + // must not be conflated. + const ValGroup& outer = graph.toGroup(split0->outer()); + if (!graph.hasUses(outer)) { + continue; + } + for (const ExprGroup& use_of_outer : graph.getUses(outer)) { + auto* split1 = dynamic_cast(use_of_outer->front()); + if (split1 == nullptr) { + continue; + } + const ValGroup& outermost_grand = graph.toGroup(split1->outer()); + outermost_grands.push_back(outermost_grand); + } + } + + for (const ValGroup& outermost_grand : outermost_grands) { + Val* extent_of_grand = + outermost_grand->front()->as()->extent(); + + for (const ExprGroup& use_of_root : uses_of_root) { + auto* split = dynamic_cast(use_of_root->front()); + if (split == nullptr) { + continue; + } + + const ValGroup& outer = graph.toGroup(split->outer()); + if (outer->front()->as()->extent()->sameAs( + extent_of_grand)) { + graph.mapVals(outermost_grand->front(), outer->front()); + } + } + } + } +} + } // namespace ValGraph& IdModel::buildAlmostExactGraph() { @@ -554,6 +608,8 @@ ValGraph& IdModel::buildAlmostExactGraph() { almost_exact_graph.mapVals(id1, id2); } + mapSplitOfSplit(almost_exact_graph); + almost_exact_graph.validateConsistency(); if (!allow_self_mapping_) { diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 6f76b27e6bc..a6fa28e9a4f 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -178,14 +178,19 @@ class LoopDomainScheduler { public: LoopDomainScheduler( std::vector ref_loop_dom, - bool update_loop_domain_only = false) + bool update_loop_domain_only = false, + const ValGraph* scheduling_graph = nullptr) : ref_loop_dom_(std::move(ref_loop_dom)), - update_loop_domain_only_(update_loop_domain_only) { + update_loop_domain_only_(update_loop_domain_only), + graph_(scheduling_graph) { NVF_ERROR(!ref_loop_dom_.empty()); - Fusion* fusion = ref_loop_dom_.front()->fusion(); - id_model_ = std::make_unique(fusion, /*build_graphs=*/false); - id_model_->buildExactGraph(); + if (graph_ == nullptr) { + Fusion* fusion = ref_loop_dom_.front()->fusion(); + id_model_ = std::make_unique(fusion, /*build_graphs=*/false); + id_model_->buildExactGraph(); + graph_ = &(id_model_->idGraph(IdMappingMode::EXACT)); + } ref_id_groups_ = graph().toGroups(ref_loop_dom_); @@ -203,8 +208,9 @@ class LoopDomainScheduler { void schedule(TensorView* tv) const; private: - ValGraph& graph() const { - return id_model_->idGraph(IdMappingMode::EXACT); + const ValGraph& graph() const { + NVF_ERROR(graph_ != nullptr); + return *graph_; } ValGraphBFS::ExprPath getReplayPath(TensorView* tv) const; @@ -248,6 +254,7 @@ class LoopDomainScheduler { // updates it to make it look like the given reference loop domain bool update_loop_domain_only_ = false; std::unique_ptr id_model_; + const ValGraph* graph_ = nullptr; ValGroups ref_id_groups_; ValGroups all_ancestors_of_ref_; }; @@ -477,12 +484,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only) { + bool update_loop_domain_only, + const ValGraph* graph) { if (tvs.empty()) { return; } - LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only); + LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only, graph); for (auto tv : tvs) { // Loop domain of fusion inputs should have no meaning, diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index d132ac08a1c..92c2cadac67 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -31,7 +31,8 @@ namespace scheduler_tools { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only = false); + bool update_loop_domain_only = false, + const ValGraph* graph = nullptr); // Replay a series of transform exprs on the loop domain of each of the given // tensors. If the replay direction is specified, the exprs are replayed diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 5a950502810..ca2348ff5b5 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -6,8 +6,6 @@ */ // clang-format on -#include - #include #include @@ -17,9 +15,9 @@ #include "id_model/loop_promotion.h" #include "id_model/schedule.h" #include "id_model/to_string.h" -#include "ir/graphviz.h" #include "ops/all_ops.h" #include "scheduler/tools/inlining.h" +#include "scheduler/tools/loop_domain_scheduler.h" #include "scheduler/tools/resize_utils.h" #include "tests/cpp/utils.h" #include "transform_iter.h" @@ -235,8 +233,7 @@ void validateIELResolution( auto promotion_id = iel_promotion_map_it->second; ASSERT_TRUE( exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id)) - << "Unexpected promotion. " - << "Expected: " << ref_id->toString() + << "Unexpected promotion. " << "Expected: " << ref_id->toString() << ". Actual: " << promotion_id->toString(); ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id)) << "Promotion of " << id->toString() @@ -376,9 +373,9 @@ void checkStep4Results( const auto& iel_promotion_map = tester.s4_iel_promotion_map; EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size()) - << "Mismatched Step-4 result map. " - << "Expected to have " << ref_promotion_map.size() - << " mappings but found " << iel_promotion_map.size(); + << "Mismatched Step-4 result map. " << "Expected to have " + << ref_promotion_map.size() << " mappings but found " + << iel_promotion_map.size(); for (const auto& ref_promotion_pair : ref_promotion_map) { const auto& ref_promotion_group = ref_promotion_pair.first; @@ -3131,8 +3128,8 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) { // Scatter output uses unique mapping schemes TEST_F(IdModelTest, ScatterLoopMapping) { auto fusion_ptr = std::make_unique(); - auto& fusion = *fusion_ptr; - FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); auto tv0 = makeContigTensor(1); fusion.addInput(tv0); @@ -3337,4 +3334,44 @@ TEST_F(IdModelTest, ReproIssue5803Minimal) { IdModel id_model(&fusion, true); } +TEST_F(IdModelTest, SplitingReshape) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeContigConcreteTensor({2 * 2 * 2}); + fusion.addInput(in); + TensorView* out = reshape(in, {2 * 2 * 2}, {2 * 2, 2}); + fusion.addOutput(out); + + in->outer_split(0, 2); + out->outer_split(0, 2); + + IdModel id_model(&fusion); + const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph(); + EXPECT_TRUE(almost_exact_graph.disjointValSets().strictAreMapped( + in->axis(0), out->axis(0))); + EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped( + in->axis(0), out->axis(1))); + EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped( + in->axis(1), out->axis(2))); +} + +TEST_F(IdModelTest, SplitingReshape_DifferentExtents) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeContigConcreteTensor({12}); + fusion.addInput(in); + TensorView* out = reshape(in, {12}, {6, 2}); + fusion.addOutput(out); + + in->outer_split(0, 2); + out->outer_split(0, 3); + + IdModel id_model(&fusion); + const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph(); + EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped( + in->axis(0), out->axis(0))); +} + } // namespace nvfuser From 97eeaafdf5e5f56b3d10f2c70ec545fe266deb32 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 19 Feb 2026 13:55:23 -0800 Subject: [PATCH 2/6] Fix build --- csrc/scheduler/tools/loop_domain_scheduler.h | 1 + tests/cpp/test_id_model.cpp | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index 92c2cadac67..5d00e40a8d4 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -10,6 +10,7 @@ #include #include "bfs.h" +#include "val_graph.h" namespace nvfuser { diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index ca2348ff5b5..94bb7f569cd 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -17,7 +17,6 @@ #include "id_model/to_string.h" #include "ops/all_ops.h" #include "scheduler/tools/inlining.h" -#include "scheduler/tools/loop_domain_scheduler.h" #include "scheduler/tools/resize_utils.h" #include "tests/cpp/utils.h" #include "transform_iter.h" From f8c1090a4dc91e8c38fce99d4b7fe82e2e57b7a1 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 19 Feb 2026 14:18:06 -0800 Subject: [PATCH 3/6] Revert changes --- .../scheduler/tools/loop_domain_scheduler.cpp | 26 +++++++------------ csrc/scheduler/tools/loop_domain_scheduler.h | 4 +-- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index a6fa28e9a4f..6f76b27e6bc 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -178,19 +178,14 @@ class LoopDomainScheduler { public: LoopDomainScheduler( std::vector ref_loop_dom, - bool update_loop_domain_only = false, - const ValGraph* scheduling_graph = nullptr) + bool update_loop_domain_only = false) : ref_loop_dom_(std::move(ref_loop_dom)), - update_loop_domain_only_(update_loop_domain_only), - graph_(scheduling_graph) { + update_loop_domain_only_(update_loop_domain_only) { NVF_ERROR(!ref_loop_dom_.empty()); - if (graph_ == nullptr) { - Fusion* fusion = ref_loop_dom_.front()->fusion(); - id_model_ = std::make_unique(fusion, /*build_graphs=*/false); - id_model_->buildExactGraph(); - graph_ = &(id_model_->idGraph(IdMappingMode::EXACT)); - } + Fusion* fusion = ref_loop_dom_.front()->fusion(); + id_model_ = std::make_unique(fusion, /*build_graphs=*/false); + id_model_->buildExactGraph(); ref_id_groups_ = graph().toGroups(ref_loop_dom_); @@ -208,9 +203,8 @@ class LoopDomainScheduler { void schedule(TensorView* tv) const; private: - const ValGraph& graph() const { - NVF_ERROR(graph_ != nullptr); - return *graph_; + ValGraph& graph() const { + return id_model_->idGraph(IdMappingMode::EXACT); } ValGraphBFS::ExprPath getReplayPath(TensorView* tv) const; @@ -254,7 +248,6 @@ class LoopDomainScheduler { // updates it to make it look like the given reference loop domain bool update_loop_domain_only_ = false; std::unique_ptr id_model_; - const ValGraph* graph_ = nullptr; ValGroups ref_id_groups_; ValGroups all_ancestors_of_ref_; }; @@ -484,13 +477,12 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only, - const ValGraph* graph) { + bool update_loop_domain_only) { if (tvs.empty()) { return; } - LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only, graph); + LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only); for (auto tv : tvs) { // Loop domain of fusion inputs should have no meaning, diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index 5d00e40a8d4..d132ac08a1c 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -10,7 +10,6 @@ #include #include "bfs.h" -#include "val_graph.h" namespace nvfuser { @@ -32,8 +31,7 @@ namespace scheduler_tools { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only = false, - const ValGraph* graph = nullptr); + bool update_loop_domain_only = false); // Replay a series of transform exprs on the loop domain of each of the given // tensors. If the replay direction is specified, the exprs are replayed From 21d03153c94becde200f57b79a18732045db6a68 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 19 Feb 2026 14:25:35 -0800 Subject: [PATCH 4/6] Some cleanups --- tests/cpp/test_id_model.cpp | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 94bb7f569cd..411e8a1f7d7 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -232,7 +232,7 @@ void validateIELResolution( auto promotion_id = iel_promotion_map_it->second; ASSERT_TRUE( exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id)) - << "Unexpected promotion. " << "Expected: " << ref_id->toString() + << "Unexpected promotion. Expected: " << ref_id->toString() << ". Actual: " << promotion_id->toString(); ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id)) << "Promotion of " << id->toString() @@ -372,7 +372,7 @@ void checkStep4Results( const auto& iel_promotion_map = tester.s4_iel_promotion_map; EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size()) - << "Mismatched Step-4 result map. " << "Expected to have " + << "Mismatched Step-4 result map. Expected to have " << ref_promotion_map.size() << " mappings but found " << iel_promotion_map.size(); @@ -2933,9 +2933,8 @@ TEST_F(IdModelTest, LoopPromotionCyclicGraphWar) { // Test to verify the split-aware covered group analysis. See // also https://github.com/NVIDIA/Fuser/pull/3877. TEST_F(IdModelTest, CoveredGroups) { - auto fusion_ptr = std::make_unique(); - auto& fusion = *fusion_ptr; - FusionGuard fg(fusion_ptr.get()); + Fusion fusion; + FusionGuard fg(&fusion); auto tv0 = makeContigConcreteTensor({-1, 1}); fusion.addInput(tv0); @@ -2996,7 +2995,7 @@ TEST_F(IdModelTest, CoveredGroups) { TEST_F(IdModelTest, InvalidLoopPromotion) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; - FusionGuard fg(fusion_ptr.get()); + FusionGuard fg(&fusion); auto T0 = makeContigConcreteTensor({1, 32, 6}); fusion.addInput(T0); @@ -3082,9 +3081,8 @@ TEST_F(IdModelTest, InvalidLoopPromotion) { // When a loop group only includes broadcast IDs, the group should not // need to be promoted TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) { - auto fusion_ptr = std::make_unique(); - auto& fusion = *fusion_ptr; - FusionGuard fg(fusion_ptr.get()); + Fusion fusion; + FusionGuard fg(&fusion); auto tv0 = makeContigConcreteTensor({-1, 1}); fusion.addInput(tv0); @@ -3126,8 +3124,7 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) { // Scatter output uses unique mapping schemes TEST_F(IdModelTest, ScatterLoopMapping) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); + Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(1); @@ -3181,8 +3178,7 @@ TEST_F(IdModelTest, ScatterLoopMapping) { // required but is a WAR for special ops like // PreprocessGroupedMatmulInputSf. See also issue #5391. TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); + Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); @@ -3215,8 +3211,7 @@ TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) { } TEST_F(IdModelTest, PermissiveResizeGraph) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); + Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeConcreteTensor({36}); @@ -3266,8 +3261,7 @@ TEST_F(IdModelTest, PermissiveResizeGraph) { // This is the failing segment of the reproducer of // https://github.com/NVIDIA/Fuser/issues/5803. TEST_F(IdModelTest, ReproIssue5803) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); + Fusion fusion; FusionGuard fg(&fusion); auto tv2 = makeContigConcreteTensor({4}, DataType::Int); @@ -3308,8 +3302,7 @@ TEST_F(IdModelTest, ReproIssue5803) { // This is a minimal fusion pattern to trigger the loop promotion // issue as reported in https://github.com/NVIDIA/Fuser/issues/5803 TEST_F(IdModelTest, ReproIssue5803Minimal) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); + Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeConcreteTensor({4, 8}); From e130ebee8e7a8e9e607871ba654a496343ce340c Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 19 Feb 2026 15:54:03 -0800 Subject: [PATCH 5/6] Map after traversal --- csrc/id_model/id_model.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 5ee7be9c1c7..182471ad92c 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -502,6 +502,7 @@ void mapSplitOfSplit(ValGraph& graph) { // outer', _ = split(root) // // If outermost_grand and outer' have the same extent, map them. + std::vector> ids_to_map; for (const ValGroup& root : graph.disjointValSets().disjointSets()) { if (!graph.hasUses(root)) { continue; @@ -542,11 +543,15 @@ void mapSplitOfSplit(ValGraph& graph) { const ValGroup& outer = graph.toGroup(split->outer()); if (outer->front()->as()->extent()->sameAs( extent_of_grand)) { - graph.mapVals(outermost_grand->front(), outer->front()); + ids_to_map.emplace_back(outermost_grand->front(), outer->front()); } } } } + + for (const auto& [id1, id2] : ids_to_map) { + graph.mapVals(id1, id2); + } } } // namespace From 42937d37be02b0a008f9f4b1dee02fc17f6f1d3b Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 19 Feb 2026 16:19:55 -0800 Subject: [PATCH 6/6] Allow getUses to be called without having to check hasUses --- csrc/id_model/id_model.cpp | 6 ------ csrc/val_graph.cpp | 9 +++++---- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 182471ad92c..73480ccf83b 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -504,9 +504,6 @@ void mapSplitOfSplit(ValGraph& graph) { // If outermost_grand and outer' have the same extent, map them. std::vector> ids_to_map; for (const ValGroup& root : graph.disjointValSets().disjointSets()) { - if (!graph.hasUses(root)) { - continue; - } const ExprGroups& uses_of_root = graph.getUses(root); std::vector outermost_grands; for (const ExprGroup& use_of_root : uses_of_root) { @@ -517,9 +514,6 @@ void mapSplitOfSplit(ValGraph& graph) { // Only follow the outer output of the first split; outer and inner // must not be conflated. const ValGroup& outer = graph.toGroup(split0->outer()); - if (!graph.hasUses(outer)) { - continue; - } for (const ExprGroup& use_of_outer : graph.getUses(outer)) { auto* split1 = dynamic_cast(use_of_outer->front()); if (split1 == nullptr) { diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 9c7ea98528d..5fed001e940 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -397,11 +397,12 @@ const ExprGroups& ValGraph::getDefinitions(const ValGroup& val_group) const { const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const { NVF_ERROR(val_group, "Nullptr not allowed"); + + static ExprGroups empty_expr_groups; const auto it = unique_uses_.find(val_group); - NVF_ERROR( - it != unique_uses_.end(), - "Use group not found for ", - nvfuser::toString(val_group)); + if (it == unique_uses_.end()) { + return empty_expr_groups; + } return it->second; }