From 343651adf7659442ed02e12ccfba975e15605f32 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 9 May 2025 10:37:23 -0700 Subject: [PATCH 1/4] WIP --- csrc/id_model/id_model.cpp | 141 +++++++++++++++++- csrc/id_model/id_model.h | 2 + .../scheduler/tools/loop_domain_scheduler.cpp | 26 ++-- csrc/scheduler/tools/loop_domain_scheduler.h | 3 +- 4 files changed, 160 insertions(+), 12 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 48b83347fc4..74dece5a66f 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -12,7 +12,6 @@ #include #include "device_lower/analysis/circular_buffer.h" -#include "device_lower/analysis/trivial_broadcast.h" #include "device_lower/lower2device.h" #include "device_lower/utils.h" #include "disjoint_set.h" @@ -25,7 +24,6 @@ #include "iter_visitor.h" #include "logical_domain_map.h" #include "transform_iter.h" -#include "val_graph_visitor.h" namespace nvfuser { @@ -1426,4 +1424,143 @@ ValGraph buildPermissiveResizeGraph(const ValGraph& permissive_graph) { return resize_graph; } +// https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations +ValGraph mapAlmostExactSplits(const ValGraph& graph) { + auto new_graph = graph; + + // vg: I0 + auto get_l1r2_splits = + [&new_graph]( + const ValGroup& vg) -> std::vector> { + std::vector> l1_r2_splits; + + if (!new_graph.hasUses(vg)) { + return {}; + } + + for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) { + auto split_of_vg = dynamic_cast(use_of_vg->front()); + if (split_of_vg == nullptr) { + continue; + } + + // mn + const ValGroup& inner_group = new_graph.toGroup(split_of_vg->inner()); + + if (!new_graph.hasUses(inner_group)) { + return {}; + } + + for (const ExprGroup& use_of_inner_group : + new_graph.getUses(inner_group)) { + auto split_of_inner_group = + dynamic_cast(use_of_inner_group->front()); + if (split_of_inner_group == nullptr) { + continue; + } + + // This split needs to be divisible + auto extent = split_of_inner_group->in()->extent(); + auto factor = split_of_inner_group->factor(); + if (extent->isConstScalar() && factor->isConstScalar() && + (extent->evaluate().as() % + factor->evaluate().as() != + 0)) { + continue; + } + + l1_r2_splits.emplace_back(use_of_vg, use_of_inner_group); + + std::cerr << "L1R2 found: " << split_of_vg->toString() + << split_of_inner_group->toString(); + } + } + + return l1_r2_splits; + }; + + auto get_matching_l2r1_splits = + [&new_graph]( + const ValGroup& vg, const std::pair& l1_r2) + -> std::optional> { + auto m = l1_r2.second->front()->as()->outer()->extent(); + auto n = l1_r2.second->front()->as()->inner()->extent(); + + for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) { + auto split_of_vg = dynamic_cast(use_of_vg->front()); + if (split_of_vg == nullptr) { + continue; + } + + if (!split_of_vg->inner()->extent()->sameAs(n)) { + continue; + } + + // I0/n + const ValGroup& outer_group = new_graph.toGroup(split_of_vg->outer()); + + if (!new_graph.hasUses(outer_group)) { + return {}; + } + + for (const ExprGroup& use_of_outer_group : + new_graph.getUses(outer_group)) { + auto split_of_outer_group = + dynamic_cast(use_of_outer_group->front()); + if (split_of_outer_group == nullptr) { + continue; + } + + if (!split_of_outer_group->inner()->extent()->sameAs(m)) { + continue; + } + + std::cerr << "Matching L2R1 found: " << split_of_vg->toString() + << split_of_outer_group->toString(); + return std::make_pair(use_of_vg, use_of_outer_group); + } + } + + return std::nullopt; + }; + + std::vector> groups_to_map; + + for (const ValGroup& vg : new_graph.disjointValSets().disjointSets()) { + const auto all_l1r2_splits = get_l1r2_splits(vg); + for (const auto& l1r2 : all_l1r2_splits) { + std::cerr << "L1R2: " << l1r2.first->front()->toString() + << l1r2.second->front()->toString(); + auto l2r1 = get_matching_l2r1_splits(vg, l1r2); + if (!l2r1.has_value()) { + continue; + } + + std::cerr << "Found\n"; + + auto l1r2_first_outputs = new_graph.outputGroups(l1r2.first); + auto l1r2_second_outputs = new_graph.outputGroups(l1r2.second); + + auto l2r1_first_outputs = new_graph.outputGroups(l2r1->first); + auto l2r1_second_outputs = new_graph.outputGroups(l2r1->second); + + groups_to_map.emplace_back( + l1r2_first_outputs.at(0), l2r1_second_outputs.at(0)); + groups_to_map.emplace_back( + l1r2_second_outputs.at(0), l2r1_second_outputs.at(1)); + groups_to_map.emplace_back( + l1r2_second_outputs.at(1), l2r1_first_outputs.at(1)); + } + } + + for (const auto& [vg1, vg2] : groups_to_map) { + std::cerr << "Mapping " << nvfuser::toString(vg1) << ", " + << vg1->front()->toString() << " and " << nvfuser::toString(vg2) + << ", " << vg2->front()->toString() << "\n"; + new_graph.mapVals(vg1->front(), vg2->front()); + } + + return new_graph; +} + } // namespace nvfuser diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index c6d26836751..6945f3c70a0 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -375,4 +375,6 @@ std::unordered_map updateValGroupIdMap( // This adds additional mappings for resize operations. ValGraph buildPermissiveResizeGraph(const ValGraph& permissive_graph); +ValGraph mapAlmostExactSplits(const ValGraph& graph); + } // namespace nvfuser diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 6f76b27e6bc..a6fa28e9a4f 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -178,14 +178,19 @@ class LoopDomainScheduler { public: LoopDomainScheduler( std::vector ref_loop_dom, - bool update_loop_domain_only = false) + bool update_loop_domain_only = false, + const ValGraph* scheduling_graph = nullptr) : ref_loop_dom_(std::move(ref_loop_dom)), - update_loop_domain_only_(update_loop_domain_only) { + update_loop_domain_only_(update_loop_domain_only), + graph_(scheduling_graph) { NVF_ERROR(!ref_loop_dom_.empty()); - Fusion* fusion = ref_loop_dom_.front()->fusion(); - id_model_ = std::make_unique(fusion, /*build_graphs=*/false); - id_model_->buildExactGraph(); + if (graph_ == nullptr) { + Fusion* fusion = ref_loop_dom_.front()->fusion(); + id_model_ = std::make_unique(fusion, /*build_graphs=*/false); + id_model_->buildExactGraph(); + graph_ = &(id_model_->idGraph(IdMappingMode::EXACT)); + } ref_id_groups_ = graph().toGroups(ref_loop_dom_); @@ -203,8 +208,9 @@ class LoopDomainScheduler { void schedule(TensorView* tv) const; private: - ValGraph& graph() const { - return id_model_->idGraph(IdMappingMode::EXACT); + const ValGraph& graph() const { + NVF_ERROR(graph_ != nullptr); + return *graph_; } ValGraphBFS::ExprPath getReplayPath(TensorView* tv) const; @@ -248,6 +254,7 @@ class LoopDomainScheduler { // updates it to make it look like the given reference loop domain bool update_loop_domain_only_ = false; std::unique_ptr id_model_; + const ValGraph* graph_ = nullptr; ValGroups ref_id_groups_; ValGroups all_ancestors_of_ref_; }; @@ -477,12 +484,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only) { + bool update_loop_domain_only, + const ValGraph* graph) { if (tvs.empty()) { return; } - LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only); + LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only, graph); for (auto tv : tvs) { // Loop domain of fusion inputs should have no meaning, diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index d132ac08a1c..92c2cadac67 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -31,7 +31,8 @@ namespace scheduler_tools { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only = false); + bool update_loop_domain_only = false, + const ValGraph* graph = nullptr); // Replay a series of transform exprs on the loop domain of each of the given // tensors. If the replay direction is specified, the exprs are replayed From 0b4c417073df6a9709c9a348ba19d65dbc7c4434 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 9 May 2025 10:41:33 -0700 Subject: [PATCH 2/4] tests --- tests/cpp/test_id_model.cpp | 146 +++++++++++++++++++++++++++++++++--- 1 file changed, 136 insertions(+), 10 deletions(-) diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 5a950502810..da6ef12199b 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -6,8 +6,6 @@ */ // clang-format on -#include - #include #include @@ -17,9 +15,9 @@ #include "id_model/loop_promotion.h" #include "id_model/schedule.h" #include "id_model/to_string.h" -#include "ir/graphviz.h" #include "ops/all_ops.h" #include "scheduler/tools/inlining.h" +#include "scheduler/tools/loop_domain_scheduler.h" #include "scheduler/tools/resize_utils.h" #include "tests/cpp/utils.h" #include "transform_iter.h" @@ -235,8 +233,7 @@ void validateIELResolution( auto promotion_id = iel_promotion_map_it->second; ASSERT_TRUE( exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id)) - << "Unexpected promotion. " - << "Expected: " << ref_id->toString() + << "Unexpected promotion. " << "Expected: " << ref_id->toString() << ". Actual: " << promotion_id->toString(); ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id)) << "Promotion of " << id->toString() @@ -376,9 +373,9 @@ void checkStep4Results( const auto& iel_promotion_map = tester.s4_iel_promotion_map; EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size()) - << "Mismatched Step-4 result map. " - << "Expected to have " << ref_promotion_map.size() - << " mappings but found " << iel_promotion_map.size(); + << "Mismatched Step-4 result map. " << "Expected to have " + << ref_promotion_map.size() << " mappings but found " + << iel_promotion_map.size(); for (const auto& ref_promotion_pair : ref_promotion_map) { const auto& ref_promotion_group = ref_promotion_pair.first; @@ -3131,8 +3128,8 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) { // Scatter output uses unique mapping schemes TEST_F(IdModelTest, ScatterLoopMapping) { auto fusion_ptr = std::make_unique(); - auto& fusion = *fusion_ptr; - FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); auto tv0 = makeContigTensor(1); fusion.addInput(tv0); @@ -3337,4 +3334,133 @@ TEST_F(IdModelTest, ReproIssue5803Minimal) { IdModel id_model(&fusion, true); } +TEST_F(IdModelTest, AlmostExactSplitGraph1) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({3 * 4 * 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5}); + // Outer split 3*4*5 by 3 + // Outer split 4*5 by 4 + + fusion.addOutput(tv2); + + tv0->split(0, 5); + // [3*4, 5] + tv0->split(0, 4); + // [3, 4, 5] + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto almost_exact_split_graph = + mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT)); + + std::cerr << almost_exact_split_graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &almost_exact_split_graph); + + fusion.print(); +} + +TEST_F(IdModelTest, AlmostExactSplitGraph2) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({3 * 4 * 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5}); + // Outer split 3*4*5 by 3 + // Outer split 4*5 by 4 + + fusion.addOutput(tv2); + + tv0->split(0, 5); + // [3*4, 5] + tv0->split(0, 4); + // [3, 4, 5] + + tv0->merge(1, 2); + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto almost_exact_split_graph = + mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT)); + + std::cerr << almost_exact_split_graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &almost_exact_split_graph); + + fusion.print(); +} + +TEST_F(IdModelTest, AlmostExactSplitGraph3) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({3 * 4 * 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5}); + // Outer split 3*4*5 by 3 + // Outer split 4*5 by 4 + + fusion.addOutput(tv2); + + tv0->split(0, 5); + // [3*4, 5] + tv0->split(0, 4); + // [3, 4, 5] + + tv0->merge(1, 2); + + tv1->split(0, 5); + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto almost_exact_split_graph = + mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT)); + + std::cerr << almost_exact_split_graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &almost_exact_split_graph); + + fusion.print(); +} + } // namespace nvfuser From 93841af3905072417e8e85095b64394c8f429300 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 9 May 2025 13:51:18 -0700 Subject: [PATCH 3/4] another example --- tests/cpp/test_id_model.cpp | 49 +++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index da6ef12199b..0a7f14ebf67 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -3463,4 +3463,53 @@ TEST_F(IdModelTest, AlmostExactSplitGraph3) { fusion.print(); } +TEST_F(IdModelTest, AlmostExactSplitGraph4) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({6, 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {6, 5}, {30}); + // Merge 6, 5 -> 30 + + auto tv3 = set(tv2); + + fusion.addOutput(tv3); + + tv0->outer_split(0, 2); + // [2, 3, 5] + + tv2->outer_split(0, 2); + // [2, 15] + tv2->outer_split(1, 3); + // [2, 3, 5] + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto graph = id_model.maybeBuildGraph(IdMappingMode::EXACT); + + for (const auto i: arange(tv0->nDims())) { + graph.mapVals(tv0->axis(i), tv2->axis(i)); + } + + std::cerr << graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2, tv3}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &graph); + + fusion.print(); +} + + } // namespace nvfuser From a1101099ec52e41a044813abb6e0b27fb220b47e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 9 May 2025 17:04:13 -0700 Subject: [PATCH 4/4] example --- tests/cpp/test_id_model.cpp | 52 ++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 0a7f14ebf67..20fa9595dfa 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -3496,7 +3496,7 @@ TEST_F(IdModelTest, AlmostExactSplitGraph4) { auto graph = id_model.maybeBuildGraph(IdMappingMode::EXACT); - for (const auto i: arange(tv0->nDims())) { + for (const auto i : arange(tv0->nDims())) { graph.mapVals(tv0->axis(i), tv2->axis(i)); } @@ -3511,5 +3511,55 @@ TEST_F(IdModelTest, AlmostExactSplitGraph4) { fusion.print(); } +TEST_F(IdModelTest, AlmostExactSplitGraph5) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + int64_t h = 60; + int64_t a = 6; + int64_t d = 2; + + auto tv0 = makeContigConcreteTensor({h}); + fusion.addInput(tv0); + + auto tv1 = reshape(tv0, {h}, {a, h / a}); + auto tv2 = set(tv1); + auto tv3 = reshape(tv2, {a, h / a}, {h}); + + fusion.addOutput(tv3); + + tv0->outer_split(0, d); + // [d, h/d] + tv0->split(1, h / a); + // [d, a/d, h/a] + + tv1->outer_split(0, d); + // [d, a/d, h/a] + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto graph = id_model.maybeBuildGraph(IdMappingMode::EXACT); + + for (const auto i : arange(tv0->nDims())) { + graph.mapVals(tv0->axis(i), tv1->axis(i)); + } + + graph.mapVals(tv0->getLogicalDomain().at(0), tv3->getLogicalDomain().at(0)); + + std::cerr << graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2, tv3}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &graph); + + fusion.print(); +} } // namespace nvfuser