diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp
index 48b83347fc4..73480ccf83b 100644
--- a/csrc/id_model/id_model.cpp
+++ b/csrc/id_model/id_model.cpp
@@ -12,7 +12,6 @@
 #include
 #include "device_lower/analysis/circular_buffer.h"
-#include "device_lower/analysis/trivial_broadcast.h"
 #include "device_lower/lower2device.h"
 #include "device_lower/utils.h"
 #include "disjoint_set.h"
@@ -25,7 +24,6 @@
 #include "iter_visitor.h"
 #include "logical_domain_map.h"
 #include "transform_iter.h"
-#include "val_graph_visitor.h"
 
 namespace nvfuser {
 
@@ -495,6 +493,61 @@ std::vector<std::vector<Val*>> getTriviallyMappedIds(Expr* expr) {
   return mapped_ids;
 }
 
+void mapSplitOfSplit(ValGraph& graph) {
+  // The following is a subpattern of
+  // https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
+  //
+  // outer, _ = split(root)
+  // outermost_grand, _ = split(outer)
+  // outer', _ = split(root)
+  //
+  // If outermost_grand and outer' have the same extent, map them.
+  std::vector<std::pair<Val*, Val*>> ids_to_map;
+  for (const ValGroup& root : graph.disjointValSets().disjointSets()) {
+    const ExprGroups& uses_of_root = graph.getUses(root);
+    std::vector<ValGroup> outermost_grands;
+    for (const ExprGroup& use_of_root : uses_of_root) {
+      auto* split0 = dynamic_cast<Split*>(use_of_root->front());
+      if (split0 == nullptr) {
+        continue;
+      }
+      // Only follow the outer output of the first split; outer and inner
+      // must not be conflated.
+      const ValGroup& outer = graph.toGroup(split0->outer());
+      for (const ExprGroup& use_of_outer : graph.getUses(outer)) {
+        auto* split1 = dynamic_cast<Split*>(use_of_outer->front());
+        if (split1 == nullptr) {
+          continue;
+        }
+        const ValGroup& outermost_grand = graph.toGroup(split1->outer());
+        outermost_grands.push_back(outermost_grand);
+      }
+    }
+
+    for (const ValGroup& outermost_grand : outermost_grands) {
+      Val* extent_of_grand =
+          outermost_grand->front()->as<IterDomain>()->extent();
+
+      for (const ExprGroup& use_of_root : uses_of_root) {
+        auto* split = dynamic_cast<Split*>(use_of_root->front());
+        if (split == nullptr) {
+          continue;
+        }
+
+        const ValGroup& outer = graph.toGroup(split->outer());
+        if (outer->front()->as<IterDomain>()->extent()->sameAs(
+                extent_of_grand)) {
+          ids_to_map.emplace_back(outermost_grand->front(), outer->front());
+        }
+      }
+    }
+  }
+
+  for (const auto& [id1, id2] : ids_to_map) {
+    graph.mapVals(id1, id2);
+  }
+}
+
 } // namespace
 
 ValGraph& IdModel::buildAlmostExactGraph() {
@@ -554,6 +607,8 @@ ValGraph& IdModel::buildAlmostExactGraph() {
     almost_exact_graph.mapVals(id1, id2);
   }
 
+  mapSplitOfSplit(almost_exact_graph);
+
   almost_exact_graph.validateConsistency();
 
   if (!allow_self_mapping_) {
diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp
index 9c7ea98528d..5fed001e940 100644
--- a/csrc/val_graph.cpp
+++ b/csrc/val_graph.cpp
@@ -397,11 +397,12 @@ const ExprGroups& ValGraph::getDefinitions(const ValGroup& val_group) const {
 const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const {
   NVF_ERROR(val_group, "Nullptr not allowed");
+
+  static ExprGroups empty_expr_groups;
   const auto it = unique_uses_.find(val_group);
-  NVF_ERROR(
-      it != unique_uses_.end(),
-      "Use group not found for ",
-      nvfuser::toString(val_group));
+  if (it == unique_uses_.end()) {
+    return empty_expr_groups;
+  }
   return it->second;
 }
 
diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp
index 5a950502810..411e8a1f7d7 100644
--- a/tests/cpp/test_id_model.cpp
+++
b/tests/cpp/test_id_model.cpp
@@ -6,8 +6,2 @@
 */
 // clang-format on
-#include
-
 #include
 #include
@@ -17,7 +15,6 @@
 #include "id_model/loop_promotion.h"
 #include "id_model/schedule.h"
 #include "id_model/to_string.h"
-#include "ir/graphviz.h"
 #include "ops/all_ops.h"
 #include "scheduler/tools/inlining.h"
 #include "scheduler/tools/resize_utils.h"
@@ -235,8 +232,7 @@ void validateIELResolution(
   auto promotion_id = iel_promotion_map_it->second;
   ASSERT_TRUE(
       exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id))
-      << "Unexpected promotion. "
-      << "Expected: " << ref_id->toString()
+      << "Unexpected promotion. Expected: " << ref_id->toString()
       << ". Actual: " << promotion_id->toString();
   ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id))
       << "Promotion of " << id->toString()
@@ -376,9 +372,9 @@ void checkStep4Results(
   const auto& iel_promotion_map = tester.s4_iel_promotion_map;
 
   EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size())
-      << "Mismatched Step-4 result map. "
-      << "Expected to have " << ref_promotion_map.size()
-      << " mappings but found " << iel_promotion_map.size();
+      << "Mismatched Step-4 result map. Expected to have "
+      << ref_promotion_map.size() << " mappings but found "
+      << iel_promotion_map.size();
 
   for (const auto& ref_promotion_pair : ref_promotion_map) {
     const auto& ref_promotion_group = ref_promotion_pair.first;
@@ -2937,9 +2933,8 @@ TEST_F(IdModelTest, LoopPromotionCyclicGraphWar) {
 
 // Test to verify the split-aware covered group analysis. See
 // also https://github.com/NVIDIA/Fuser/pull/3877.
 TEST_F(IdModelTest, CoveredGroups) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto& fusion = *fusion_ptr;
-  FusionGuard fg(fusion_ptr.get());
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
   auto tv0 = makeContigConcreteTensor({-1, 1});
   fusion.addInput(tv0);
@@ -3000,7 +2995,7 @@ TEST_F(IdModelTest, InvalidLoopPromotion) {
   auto fusion_ptr = std::make_unique<Fusion>();
   auto& fusion = *fusion_ptr;
-  FusionGuard fg(fusion_ptr.get());
+  FusionGuard fg(&fusion);
 
   auto T0 = makeContigConcreteTensor({1, 32, 6});
   fusion.addInput(T0);
@@ -3086,9 +3081,8 @@
 // When a loop group only includes broadcast IDs, the group should not
 // need to be promoted
 TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto& fusion = *fusion_ptr;
-  FusionGuard fg(fusion_ptr.get());
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
   auto tv0 = makeContigConcreteTensor({-1, 1});
   fusion.addInput(tv0);
@@ -3130,9 +3124,8 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {
 
 // Scatter output uses unique mapping schemes
 TEST_F(IdModelTest, ScatterLoopMapping) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto& fusion = *fusion_ptr;
-  FusionGuard fg(fusion_ptr.get());
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
   auto tv0 = makeContigTensor(1);
   fusion.addInput(tv0);
@@ -3185,8 +3178,7 @@
 // required but is a WAR for special ops like
 // PreprocessGroupedMatmulInputSf. See also issue #5391.
 TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
+  Fusion fusion;
   FusionGuard fg(&fusion);
 
   auto tv0 = makeSymbolicTensor(2);
@@ -3219,8 +3211,7 @@
 }
 
 TEST_F(IdModelTest, PermissiveResizeGraph) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
+  Fusion fusion;
   FusionGuard fg(&fusion);
 
   auto tv0 = makeConcreteTensor({36});
@@ -3270,8 +3261,7 @@
 // This is the failing segment of the reproducer of
 // https://github.com/NVIDIA/Fuser/issues/5803.
 TEST_F(IdModelTest, ReproIssue5803) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
+  Fusion fusion;
   FusionGuard fg(&fusion);
 
   auto tv2 = makeContigConcreteTensor({4}, DataType::Int);
@@ -3312,8 +3302,7 @@
 // This is a minimal fusion pattern to trigger the loop promotion
 // issue as reported in https://github.com/NVIDIA/Fuser/issues/5803
 TEST_F(IdModelTest, ReproIssue5803Minimal) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
+  Fusion fusion;
   FusionGuard fg(&fusion);
 
   auto tv0 = makeConcreteTensor({4, 8});
@@ -3337,4 +3326,44 @@
   IdModel id_model(&fusion, true);
 }
 
+TEST_F(IdModelTest, SplitingReshape) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* in = makeContigConcreteTensor({2 * 2 * 2});
+  fusion.addInput(in);
+  TensorView* out = reshape(in, {2 * 2 * 2}, {2 * 2, 2});
+  fusion.addOutput(out);
+
+  in->outer_split(0, 2);
+  out->outer_split(0, 2);
+
+  IdModel id_model(&fusion);
+  const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph();
+  EXPECT_TRUE(almost_exact_graph.disjointValSets().strictAreMapped(
+      in->axis(0), out->axis(0)));
+  EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped(
+      in->axis(0), out->axis(1)));
+  EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped(
+      in->axis(1), out->axis(2)));
+}
+
+TEST_F(IdModelTest, SplitingReshape_DifferentExtents) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* in = makeContigConcreteTensor({12});
+  fusion.addInput(in);
+  TensorView* out = reshape(in, {12}, {6, 2});
+  fusion.addOutput(out);
+
+  in->outer_split(0, 2);
+  out->outer_split(0, 3);
+
+  IdModel id_model(&fusion);
+  const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph();
+  EXPECT_FALSE(almost_exact_graph.disjointValSets().strictAreMapped(
+      in->axis(0), out->axis(0)));
+}
+
 } // namespace nvfuser