@@ -190,8 +190,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator {
190190 bool should_reject = false ;
191191};
192192
193- bool hasSmallTransposeDimensions (
194- const std::unique_ptr<TransposeParams>& params) {
193+ bool hasSmallTransposeDimensions (const TransposeParams* params) {
195194 return !params->split_before_tiling .empty () ||
196195 !params->dims_merged_with_1 .empty () ||
197196 !params->dims_merged_with_2 .empty ();
@@ -579,7 +578,7 @@ std::string getTransposeRuntimeRejectReason(
579578 // 1. view op; and
580579 // 2. small transpose transformation
581580 // See note [Supporting small transpose dimensions]
582- if (hasSmallTransposeDimensions (params)) {
581+ if (hasSmallTransposeDimensions (params. get () )) {
583582 return " Small transpose dimensions and view op cannot be currently be "
584583 " handled by transpose scheduler. See: "
585584 " https://github.com/NVIDIA/Fuser/pull/592" ;
@@ -677,7 +676,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
677676 inner_most_pos2_in_ref1);
678677
679678 NVF_ERROR (
680- !hasSmallTransposeDimensions (tparams) ||
679+ !hasSmallTransposeDimensions (tparams. get () ) ||
681680 scheduler_utils::getViewTVs (fusion).empty (),
682681 " combination of view op with small transpose dimensions are not "
683682 " supported by transpose scheduler" );
@@ -867,7 +866,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
867866 << " reference2: " << reference2->toString () << " \n "
868867 << " inner_most_id2 position: " << inner_most_pos2_in_ref1
869868 << " (in reference 1)" << std::endl;
870- if (hasSmallTransposeDimensions (tparams)) {
869+ if (hasSmallTransposeDimensions (tparams. get () )) {
871870 debug () << " small transposed dim, needs virtual inner-most dim"
872871 << std::endl;
873872 }
@@ -953,11 +952,15 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
953952 auto new_cache = tv->cacheAfter ();
954953 new_cache->setMemoryType (MemoryType::Shared);
955954 group2_and_cached_inputs.emplace (new_cache);
956- smem_cached_input_tvs.push_back (existing_cache);
955+ if (!hasSmallTransposeDimensions (tparams)) {
956+ smem_cached_input_tvs.push_back (new_cache);
957+ }
957958 } else {
958959 existing_cache->setMemoryType (MemoryType::Shared);
959960 group2_and_cached_inputs.emplace (existing_cache);
960- smem_cached_input_tvs.push_back (existing_cache);
961+ if (!hasSmallTransposeDimensions (tparams)) {
962+ smem_cached_input_tvs.push_back (existing_cache);
963+ }
961964 }
962965 }
963966 }
0 commit comments