Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions csrc/scheduler/normalization_inner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ bool mayUseTma(Fusion* fusion, const PersistentKernelProperties& prop) {
return false;
}

// TMA requires compile-time known contiguous innermost dimension on inputs
if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
return false;
}

// TMA requires 16-byte alignment (128 bits) for memory transactions
if (prop.vectorize_factor * prop.max_dtype_size_bit % 128 != 0) {
return false;
Expand Down
4 changes: 4 additions & 0 deletions csrc/scheduler/normalization_inner_outer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ bool preferWarpSpecialized(
if (at::cuda::getCurrentDeviceProperties()->major < 10) {
return false;
}
// TMA requires compile-time known contiguous innermost dimension on inputs
if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
return false;
}
// False, if any of the inputs is dynamically shaped
// TODO: extend to support dynamic inputs, warp specialization requires
// static CTA size
Expand Down
12 changes: 10 additions & 2 deletions csrc/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,18 @@ bool mayHaveTmaCompatibleInputs(
// serves as a fast path to avoid computing full heuristics if TMA is clearly
// not applicable. Passing this check does not guarantee that TMA will be used;
// the final decision is made during heuristics computation.
bool mayUseTma(const pointwise_utils::FusionRuntimeProperties& prop) {
bool mayUseTma(
Fusion* fusion,
const pointwise_utils::FusionRuntimeProperties& prop) {
// Hardware requirement: Don't use TMA for pre-Hopper GPUs
if (at::cuda::getCurrentDeviceProperties()->major < 9) {
return false;
}

// TMA requires compile-time known contiguous innermost dimension on inputs
if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
return false;
}
// Check if there are TMA-compatible inputs
if (!mayHaveTmaCompatibleInputs(prop)) {
return false;
Expand All @@ -430,7 +437,8 @@ std::unique_ptr<HeuristicParams> PointWiseScheduler::computeHeuristics(
}
const auto& prop = prop_opt.value();

bool use_tma = mayUseTma(prop) && isOptionEnabled(EnableOption::TmaPointwise);
bool use_tma = mayUseTma(fusion, prop) &&
isOptionEnabled(EnableOption::TmaPointwise);
std::unique_ptr<HeuristicParams> pparams = nullptr;
if (use_tma) {
pparams = pointwise::tma::getPointwiseHeuristics(
Expand Down
14 changes: 12 additions & 2 deletions csrc/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,18 @@ bool ReductionScheduler::canScheduleRunTime(
namespace {

bool mayUseTma(
Fusion* fusion,
const reduction_scheduler_utils::FusionRuntimeProperties& props) {
auto dev_prop = at::cuda::getCurrentDeviceProperties();

if (dev_prop->major < 9) {
return false;
}

if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
return false;
}

// Require the reduction shape is 2D inner reduction: [I, R]
if (!props.fastest_dim_reduction ||
props.total_reduction_numel != props.inner_most_dimension_numel) {
Expand Down Expand Up @@ -251,13 +256,18 @@ bool mayUseTma(
}

bool mayUseTmaOuter(
Fusion* fusion,
const reduction_scheduler_utils::FusionRuntimeProperties& props) {
auto dev_prop = at::cuda::getCurrentDeviceProperties();

if (dev_prop->major < 9) {
return false;
}

if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
return false;
}

// Require outer reduction
if (props.fastest_dim_reduction) {
return false;
Expand Down Expand Up @@ -308,13 +318,13 @@ std::unique_ptr<HeuristicParams> ReductionScheduler::computeHeuristics(
std::unique_ptr<HeuristicParams> rparams = nullptr;

// Try outer TMA scheduler for outer reductions
if (tma_enabled && mayUseTmaOuter(props)) {
if (tma_enabled && mayUseTmaOuter(fusion, props)) {
rparams = reduction::outer_tma::getReductionHeuristics(
fusion, runtime_info, data_cache, props);
}

// Try inner TMA scheduler for inner reductions
if (rparams == nullptr && tma_enabled && mayUseTma(props)) {
if (rparams == nullptr && tma_enabled && mayUseTma(fusion, props)) {
rparams = reduction::tma::getReductionHeuristics(
fusion, runtime_info, data_cache, props);
}
Expand Down
26 changes: 26 additions & 0 deletions csrc/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3767,4 +3767,30 @@ int64_t countLeadingParallelDimensions(const TensorView* tv) {
return num_parallel_dims;
}

// Returns true if every fusion input TensorView has a compile-time-known
// contiguous innermost (non-broadcast) allocation dimension, a prerequisite
// for loading the input via TMA. 0-D (scalar) inputs are skipped: they have
// no innermost memory dimension and would never be loaded via TMA, so they
// must not block TMA for the rest of the fusion.
bool inputsHaveContiguousInnerDim(Fusion* fusion) {
  for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
    const auto& contig = tv->domain()->contiguity();
    if (contig.empty()) {
      // 0-D tensor (e.g. a scale factor) has no inner dim to load via TMA;
      // skip it rather than rejecting the entire fusion.
      continue;
    }
    const auto& alloc_dom = tv->getMaybeAllocationDomain();
    NVF_ERROR(contig.size() == alloc_dom.size());
    // Walk from the innermost allocation dimension outward; the first
    // non-broadcast dimension encountered must be known-contiguous.
    bool found_inner = false;
    for (int64_t i = static_cast<int64_t>(alloc_dom.size()) - 1; i >= 0; --i) {
      if (alloc_dom[i]->isBroadcast()) {
        // Broadcast dims carry no contiguity info; keep scanning outward.
        continue;
      }
      if (!contig[i].has_value() || !*contig[i]) {
        // Contiguity unknown at compile time, or known non-contiguous.
        return false;
      }
      found_inner = true;
      break;
    }
    if (!found_inner) {
      // Every dim is broadcast: there is no materialized inner dimension,
      // so conservatively reject TMA for this fusion.
      return false;
    }
  }
  return true;
}

} // namespace nvfuser::scheduler_utils
4 changes: 4 additions & 0 deletions csrc/scheduler/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1025,5 +1025,9 @@ std::pair<int64_t, int64_t> getRegisterSharing(
// and non-stream IterDomains.
int64_t countLeadingParallelDimensions(const TensorView*);

// Check if all fusion input TensorViews have compile-time known contiguous
// innermost dimension. TMA requires contiguity to be known at compile time.
bool inputsHaveContiguousInnerDim(Fusion* fusion);

} // namespace scheduler_utils
} // namespace nvfuser
Loading