diff --git a/csrc/scheduler/normalization_inner.cpp b/csrc/scheduler/normalization_inner.cpp index e35a143eeb3..dc8c06c4827 100644 --- a/csrc/scheduler/normalization_inner.cpp +++ b/csrc/scheduler/normalization_inner.cpp @@ -33,6 +33,11 @@ bool mayUseTma(Fusion* fusion, const PersistentKernelProperties& prop) { return false; } + // TMA requires compile-time known contiguous innermost dimension on inputs + if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) { + return false; + } + // TMA requires 16-byte alignment (128 bits) for memory transactions if (prop.vectorize_factor * prop.max_dtype_size_bit % 128 != 0) { return false; diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 823af5f12f9..3b4c39281e0 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -32,6 +32,10 @@ bool preferWarpSpecialized( if (at::cuda::getCurrentDeviceProperties()->major < 10) { return false; } + // TMA requires compile-time known contiguous innermost dimension on inputs + if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) { + return false; + } // False, if any of the inputs is dynamically shaped // TODO: extend to support dynamic inputs, warp specialization requires // static CTA size diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index bb03ccaf64e..8072fd6f354 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -400,11 +400,18 @@ bool mayHaveTmaCompatibleInputs( // serves as a fast path to avoid computing full heuristics if TMA is clearly // not applicable. Passing this check does not guarantee that TMA will be used; // the final decision is made during heuristics computation. 
-bool mayUseTma(const pointwise_utils::FusionRuntimeProperties& prop) { +bool mayUseTma( + Fusion* fusion, + const pointwise_utils::FusionRuntimeProperties& prop) { // Hardware requirement: Don't use TMA for pre-Hopper GPUs if (at::cuda::getCurrentDeviceProperties()->major < 9) { return false; } + + // TMA requires compile-time known contiguous innermost dimension on inputs + if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) { + return false; + } // Check if there are TMA-compatible inputs if (!mayHaveTmaCompatibleInputs(prop)) { return false; @@ -430,7 +437,8 @@ std::unique_ptr<HeuristicParams> PointWiseScheduler::computeHeuristics( } const auto& prop = prop_opt.value(); - bool use_tma = mayUseTma(prop) && isOptionEnabled(EnableOption::TmaPointwise); + bool use_tma = mayUseTma(fusion, prop) && + isOptionEnabled(EnableOption::TmaPointwise); std::unique_ptr<PointwiseParams> pparams = nullptr; if (use_tma) { pparams = pointwise::tma::getPointwiseHeuristics( diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index dec667b819c..831b2ae80b3 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -196,6 +196,7 @@ bool ReductionScheduler::canScheduleRunTime( namespace { bool mayUseTma( + Fusion* fusion, const reduction_scheduler_utils::FusionRuntimeProperties& props) { auto dev_prop = at::cuda::getCurrentDeviceProperties(); @@ -203,6 +204,10 @@ bool mayUseTma( return false; } + if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) { + return false; + } + // Require the reduction shape is 2D inner reduction: [I, R] if (!props.fastest_dim_reduction || props.total_reduction_numel != props.inner_most_dimension_numel) { @@ -251,6 +256,7 @@ bool mayUseTma( } bool mayUseTmaOuter( + Fusion* fusion, const reduction_scheduler_utils::FusionRuntimeProperties& props) { auto dev_prop = at::cuda::getCurrentDeviceProperties(); @@ -258,6 +264,10 @@ bool mayUseTmaOuter( return false; } + if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) { + return 
false; + } + // Require outer reduction if (props.fastest_dim_reduction) { return false; } @@ -308,13 +318,13 @@ std::unique_ptr<HeuristicParams> ReductionScheduler::computeHeuristics( std::unique_ptr<ReductionParams> rparams = nullptr; // Try outer TMA scheduler for outer reductions - if (tma_enabled && mayUseTmaOuter(props)) { + if (tma_enabled && mayUseTmaOuter(fusion, props)) { rparams = reduction::outer_tma::getReductionHeuristics( fusion, runtime_info, data_cache, props); } // Try inner TMA scheduler for inner reductions - if (rparams == nullptr && tma_enabled && mayUseTma(props)) { + if (rparams == nullptr && tma_enabled && mayUseTma(fusion, props)) { rparams = reduction::tma::getReductionHeuristics( fusion, runtime_info, data_cache, props); } diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 541f2d1c358..abb2a69ef32 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -3767,4 +3767,30 @@ int64_t countLeadingParallelDimensions(const TensorView* tv) { return num_parallel_dims; } +bool inputsHaveContiguousInnerDim(Fusion* fusion) { + for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) { + const auto& contig = tv->domain()->contiguity(); + if (contig.empty()) { + return false; + } + const auto& alloc_dom = tv->getMaybeAllocationDomain(); + NVF_ERROR(contig.size() == alloc_dom.size()); + bool found_inner = false; + for (int64_t i = static_cast<int64_t>(alloc_dom.size()) - 1; i >= 0; --i) { + if (alloc_dom[i]->isBroadcast()) { + continue; + } + if (!contig[i].has_value() || !*contig[i]) { + return false; + } + found_inner = true; + break; + } + if (!found_inner) { + return false; + } + } + return true; +} + } // namespace nvfuser::scheduler_utils diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index e6da652fdda..509762961a0 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -1025,5 +1025,9 @@ std::pair<int64_t, int64_t> getRegisterSharing( // and non-stream IterDomains. 
int64_t countLeadingParallelDimensions(const TensorView*); +// Check if all fusion input TensorViews have compile-time known contiguous +// innermost dimension. TMA requires contiguity to be known at compile time. +bool inputsHaveContiguousInnerDim(Fusion* fusion); + } // namespace scheduler_utils } // namespace nvfuser