Skip to content

Commit 5bbb6fa

Browse files
committed
fix tests
1 parent ff62c9e commit 5bbb6fa

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

csrc/scheduler/transpose.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator {
190190
bool should_reject = false;
191191
};
192192

193-
bool hasSmallTransposeDimensions(
194-
const std::unique_ptr<TransposeParams>& params) {
193+
bool hasSmallTransposeDimensions(const TransposeParams* params) {
195194
return !params->split_before_tiling.empty() ||
196195
!params->dims_merged_with_1.empty() ||
197196
!params->dims_merged_with_2.empty();
@@ -579,7 +578,7 @@ std::string getTransposeRuntimeRejectReason(
579578
// 1. view op; and
580579
// 2. small transpose transformation
581580
// See note [Supporting small transpose dimensions]
582-
if (hasSmallTransposeDimensions(params)) {
581+
if (hasSmallTransposeDimensions(params.get())) {
583582
return "Small transpose dimensions and view op cannot be currently be "
584583
"handled by transpose scheduler. See: "
585584
"https://github.com/NVIDIA/Fuser/pull/592";
@@ -677,7 +676,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
677676
inner_most_pos2_in_ref1);
678677

679678
NVF_ERROR(
680-
!hasSmallTransposeDimensions(tparams) ||
679+
!hasSmallTransposeDimensions(tparams.get()) ||
681680
scheduler_utils::getViewTVs(fusion).empty(),
682681
"combination of view op with small transpose dimensions are not "
683682
"supported by transpose scheduler");
@@ -867,7 +866,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
867866
<< "reference2: " << reference2->toString() << "\n"
868867
<< "inner_most_id2 position: " << inner_most_pos2_in_ref1
869868
<< " (in reference 1)" << std::endl;
870-
if (hasSmallTransposeDimensions(tparams)) {
869+
if (hasSmallTransposeDimensions(tparams.get())) {
871870
debug() << "small transposed dim, needs virtual inner-most dim"
872871
<< std::endl;
873872
}
@@ -953,11 +952,15 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
953952
auto new_cache = tv->cacheAfter();
954953
new_cache->setMemoryType(MemoryType::Shared);
955954
group2_and_cached_inputs.emplace(new_cache);
956-
smem_cached_input_tvs.push_back(existing_cache);
955+
if (!hasSmallTransposeDimensions(tparams)) {
956+
smem_cached_input_tvs.push_back(new_cache);
957+
}
957958
} else {
958959
existing_cache->setMemoryType(MemoryType::Shared);
959960
group2_and_cached_inputs.emplace(existing_cache);
960-
smem_cached_input_tvs.push_back(existing_cache);
961+
if (!hasSmallTransposeDimensions(tparams)) {
962+
smem_cached_input_tvs.push_back(existing_cache);
963+
}
961964
}
962965
}
963966
}

0 commit comments

Comments
 (0)