
Add contiguity checks to TMA schedulers#6024

Open
tbqh wants to merge 3 commits into main from tbqh/tma_check_contiguity

Conversation

@tbqh
Collaborator

@tbqh tbqh commented Mar 3, 2026

TMA requires contiguity to be set at compile time, but the TMA schedulers were not checking for this. Instead, they sometimes check vectorization >= 2, which is a runtime value. As a result, symbolic tensors (not marked as contiguous) fail to compile even when they are actually contiguous at runtime. ReductionWithIterDim has many such cases; for example:

Error from segmentation group 0: Expected (definition() ->isStrictlyOneOf< ReductionOp, MmaOp, MatmulOp, LinearOp, GroupedMmaOp, ScaledMmaOp>()) . Error rfactoring T2_l_float[rblockIdx.y82{8}, rS83{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, rUR88{8}, rthreadIdx.y89{16}, rblockIdx.x84{( ceilDiv(i2, 128) )}, rthreadIdx.x86{32}, rV87{4}, rS10{i3}, iS11{i4}] because its definition is not a reduction. Exception raised from rFactor at /opt/pytorch/nvfuser/csrc/tensor_view.cpp:804 (most recent call first):
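The check this PR adds can be illustrated with a standalone sketch. The `InputInfo` type below, with explicit contiguity and broadcast flags, is a simplified stand-in for nvFuser's TensorView/allocation-domain API, not the real interface; the walk itself mirrors the logic of the new `inputsHaveContiguousInnerDim`:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// One input tensor, reduced to the two properties the check consults:
// per-dimension contiguity flags (allocation order, outermost first) and
// which dimensions are broadcasts. Simplified stand-in, not the real API.
struct InputInfo {
  std::vector<std::optional<bool>> contiguity;
  std::vector<bool> is_broadcast;
};

// Walk each input's dimensions from innermost to outermost, skip
// broadcasts, and require the first real dimension to carry an explicit
// compile-time `true` contiguity flag.
bool inputsHaveContiguousInnerDim(const std::vector<InputInfo>& inputs) {
  for (const auto& in : inputs) {
    bool found_inner = false;
    for (int64_t i = static_cast<int64_t>(in.contiguity.size()) - 1; i >= 0;
         --i) {
      if (in.is_broadcast[i]) {
        continue;  // broadcast dims have no stride to be contiguous over
      }
      if (!in.contiguity[i].has_value() || !*in.contiguity[i]) {
        return false;  // unknown (symbolic) or explicitly non-contiguous
      }
      found_inner = true;
      break;
    }
    if (!found_inner) {
      return false;  // 0-D or all-broadcast input
    }
  }
  return true;
}
```

Note that a `std::nullopt` flag (contiguity unknown at compile time) is treated the same as `false`, which is exactly the case the PR description is about: the data may be contiguous at runtime, but TMA cannot rely on that.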

@tbqh tbqh marked this pull request as ready for review March 3, 2026 11:10
@tbqh tbqh requested a review from liqiangxl March 3, 2026 11:10
@github-actions

github-actions bot commented Mar 3, 2026

Review updated until commit d49ac87

Description

  • Add compile-time contiguity check for TMA (Tensor Memory Accelerator) schedulers

  • New utility function inputsHaveContiguousInnerDim validates innermost dimension contiguity

  • Update normalization_inner, normalization_inner_outer, pointwise, and reduction schedulers

  • TMA requires contiguity at compile time; runtime checks like vectorization >= 2 are insufficient

Changes walkthrough

Relevant files

Enhancement

csrc/scheduler/utils.cpp — Add inputsHaveContiguousInnerDim utility function (+26/-0)

  • Added new function inputsHaveContiguousInnerDim(Fusion* fusion) to check whether all input TensorViews have a compile-time-known contiguous innermost dimension
  • Iterates through fusion inputs, checking contiguity from the innermost to the outermost non-broadcast dimension
  • Returns false if any input lacks a contiguous innermost dimension

csrc/scheduler/utils.h — Declare inputsHaveContiguousInnerDim function (+4/-0)

  • Added declaration for inputsHaveContiguousInnerDim with documentation

Bug fix

csrc/scheduler/normalization_inner.cpp — Add contiguity check to normalization_inner mayUseTma (+5/-0)

  • Added contiguity check call to the mayUseTma function
  • Returns false if inputs lack a contiguous innermost dimension

csrc/scheduler/normalization_inner_outer.cpp — Add contiguity check to normalization_inner_outer (+4/-0)

  • Added contiguity check to the preferWarpSpecialized function
  • Prevents TMA usage for non-contiguous inputs on Blackwell+ GPUs

csrc/scheduler/pointwise.cpp — Add contiguity check to pointwise mayUseTma (+10/-2)

  • Modified mayUseTma to accept a Fusion* fusion parameter
  • Added contiguity check inside mayUseTma
  • Updated the call site in computeHeuristics to pass fusion

csrc/scheduler/reduction.cpp — Add contiguity check to reduction mayUseTma functions (+12/-2)

  • Modified mayUseTma and mayUseTmaOuter to accept a Fusion* fusion parameter
  • Added contiguity check to both functions
  • Updated call sites in computeHeuristics to pass fusion

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Function implementation correctness

    The inputsHaveContiguousInnerDim function implementation looks correct. It properly iterates from the innermost dimension outward, skips broadcast dimensions, and checks for contiguity. However, verify that the function handles edge cases: (1) inputs with only broadcast dimensions (should return false), (2) inputs with multiple consecutive broadcast dimensions followed by a contiguous dimension. The current logic with the found_inner flag appears to handle these correctly.

    bool inputsHaveContiguousInnerDim(Fusion* fusion) {
      for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
        const auto& contig = tv->domain()->contiguity();
        if (contig.empty()) {
          return false;
        }
        const auto& alloc_dom = tv->getMaybeAllocationDomain();
        NVF_ERROR(contig.size() == alloc_dom.size());
        bool found_inner = false;
        for (int64_t i = static_cast<int64_t>(alloc_dom.size()) - 1; i >= 0; --i) {
          if (alloc_dom[i]->isBroadcast()) {
            continue;
          }
          if (!contig[i].has_value() || !*contig[i]) {
            return false;
          }
          found_inner = true;
          break;
        }
        if (!found_inner) {
          return false;
        }
      }
      return true;
    }
    Function signature change

    The mayUseTma function signature changed from taking only pointwise_utils::FusionRuntimeProperties& prop to also taking Fusion* fusion. Ensure all call sites are updated correctly. The diff shows the call site at lines 440-441 is updated properly. Verify that no other call sites exist that would break.

    bool mayUseTma(
        Fusion* fusion,
        const pointwise_utils::FusionRuntimeProperties& prop) {
      // Hardware requirement: Don't use TMA for pre-Hopper GPUs
      if (at::cuda::getCurrentDeviceProperties()->major < 9) {
        return false;
      }
    
      // TMA requires compile-time known contiguous innermost dimension on inputs
      if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
        return false;
      }
      // Check if there are TMA-compatible inputs
      if (!mayHaveTmaCompatibleInputs(prop)) {
        return false;
      }
    Function signature consistency

    Both mayUseTma and mayUseTmaOuter function signatures were updated to include Fusion* fusion parameter. The call sites at lines 321 and 327 are updated correctly. Ensure these changes are consistent with the header file declarations if any.

    bool mayUseTma(
        Fusion* fusion,
        const reduction_scheduler_utils::FusionRuntimeProperties& props) {
      auto dev_prop = at::cuda::getCurrentDeviceProperties();
    
      if (dev_prop->major < 9) {
        return false;
      }
    
      if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
        return false;
      }
    
      // Require the reduction shape is 2D inner reduction: [I, R]
      if (!props.fastest_dim_reduction ||
          props.total_reduction_numel != props.inner_most_dimension_numel) {
        return false;
      }
    
      int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
      uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
    
      // Minimum TMA transfer size, below which it seems much slower than non-TMA.
      uint64_t min_tma_bytes = 16384;
    
      if (total_reduction_bytes < min_tma_bytes) {
        return false;
      }
    
      // Require reduction dim fits into smem, until we add iteration over large
      // reduction dim.
      const int64_t smem_elems =
          (static_cast<int64_t>(dev_prop->sharedMemPerBlockOptin) * 8) /
          props.max_dtype_size_bit_for_vectorization;
    
      if (props.inner_most_dimension_numel > smem_elems) {
        return false;
      }
    
      // Smem check assumes only one input tensor.
      if (props.n_tensor_inputs != 1) {
        return false;
      }
    
      // Require that the innermost dim is contiguous.
      if (props.vectorize_factor <= 1) {
        return false;
      }
    
      uint64_t vect_bits =
          props.vectorize_factor * props.max_dtype_size_bit_for_vectorization;
    
      // TMA requires 16-byte alignment (128 bits) for memory transactions
      if (vect_bits % 128 != 0) {
        return false;
      }
    
      return true;
    }
    
    bool mayUseTmaOuter(
        Fusion* fusion,
        const reduction_scheduler_utils::FusionRuntimeProperties& props) {
      auto dev_prop = at::cuda::getCurrentDeviceProperties();
    
      if (dev_prop->major < 9) {
        return false;
      }
    
      if (!scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
        return false;
      }

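The vectorization guard in mayUseTma reduces to a small arithmetic condition: the per-transfer width in bits must be a multiple of 128 (16 bytes), and the vectorize factor must exceed 1. A minimal sketch of just that arithmetic, using plain integers rather than the scheduler's FusionRuntimeProperties (the function name here is illustrative, not from the PR):

```cpp
#include <cstdint>

// TMA memory transactions require 16-byte (128-bit) alignment. Given a
// runtime vectorization factor and the element width in bits, this mirrors
// the two guards in mayUseTma: vectorize_factor > 1 and
// (vectorize_factor * dtype_size_bit) % 128 == 0.
bool tmaAlignmentOk(uint64_t vectorize_factor, uint64_t dtype_size_bit) {
  if (vectorize_factor <= 1) {
    return false;  // no vectorization implies the access may be unaligned
  }
  const uint64_t vect_bits = vectorize_factor * dtype_size_bit;
  return vect_bits % 128 == 0;
}
```

For example, fp32 (32 bits per element) needs a vectorize factor of at least 4 to reach 128 bits, and fp16 needs 8. This is why the reviewers flag the "Require that the innermost dim is contiguous" comment as stale: the guard enforces alignment, while compile-time contiguity is now checked separately by inputsHaveContiguousInnerDim.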
    @greptile-apps
    Contributor

    greptile-apps bot commented Mar 3, 2026

    Greptile Summary

    This PR fixes a correctness bug where TMA schedulers (pointwise, inner/outer reduction, inner/outer-persistent normalization) would be selected for fusions whose input tensors have no compile-time contiguity annotation on the innermost dimension. Because TMA requires contiguity to be known at compile time, such fusions would fail to compile even though the underlying data might actually be contiguous at runtime.

    The fix introduces a single helper scheduler_utils::inputsHaveContiguousInnerDim that walks each input TensorView's allocation domain from the innermost position outward (skipping broadcast dimensions) and confirms that the first real dimension carries an explicit true contiguity flag. This guard is added at the top of every mayUseTma / preferWarpSpecialized entry point in all four affected schedulers, ensuring fallback to non-TMA paths before any heuristics are computed.

    Key observations:

    • The helper inputsHaveContiguousInnerDim has a logic issue: 0-D tensor inputs (scalar TensorViews with an empty contiguity vector) cause an immediate return false, blocking TMA for the entire fusion even though a 0-D tensor would never be loaded via TMA. Using continue for such tensors would be more correct.
    • The comment "Require that the innermost dim is contiguous." on the vectorize_factor <= 1 checks in both mayUseTma and mayUseTmaOuter is now misleading; those checks are about 128-bit alignment, while the new inputsHaveContiguousInnerDim guard performs the actual compile-time contiguity check.

    Confidence Score: 4/5

    • Safe to merge; the fix is correct and conservative — it only disables TMA when contiguity cannot be verified at compile time. Two minor issues remain unfixed in this PR: 0-D scalar inputs unnecessarily block TMA for mixed scalar+tensor fusions, and misleading comments need updating.
    • The core logic is sound and directly addresses the described compile-time failure. The fix correctly adds compile-time contiguity checks that were missing from the TMA schedulers. One logic issue exists: 0-D scalar TensorView inputs cause inputsHaveContiguousInnerDim to return false and disable TMA for the whole fusion. This does not introduce a correctness regression (TMA was broken in these cases before; non-TMA fallback still works), but it is overly conservative — it may prevent TMA from being used for fusions that have scalar inputs alongside normal tensors. A second issue is documentation: the comments at lines 242 and 290 in reduction.cpp are now misleading about what the vectorize_factor check actually enforces.
    • csrc/scheduler/utils.cpp: Review the 0-D tensor handling in inputsHaveContiguousInnerDim — consider using continue instead of return false to allow TMA when scalar inputs are present. csrc/scheduler/reduction.cpp: Update stale comments in mayUseTma and mayUseTmaOuter to clarify that vectorize_factor checks enforce alignment, not contiguity.

    Flowchart

    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A[computeHeuristics / preferWarpSpecialized] --> B{Hardware check\nSM >= 9 / 10?}
        B -- No --> Z[Return false / fallback]
        B -- Yes --> C{inputsHaveContiguousInnerDim\nnew compile-time check}
        C -- false --> Z
        C -- true --> D{Other runtime checks\nvectorize_factor, smem, etc.}
        D -- fail --> Z
        D -- pass --> E[Use TMA scheduler]
    
        subgraph inputsHaveContiguousInnerDim
            F[For each TensorView input]
            F --> G{contig.empty?}
            G -- Yes --> H[return false ⚠️ 0-D tensors blocked]
            G -- No --> I[Walk alloc_dom from innermost]
            I --> J{isBroadcast?}
            J -- Yes --> I
            J -- No --> K{contig has_value && *contig == true?}
            K -- No --> L[return false]
            K -- Yes --> M[found_inner = true, break]
            M --> N{All TVs passed?}
            N -- Yes --> O[return true]
            N -- No --> L
        end
    

    Last reviewed commit: d49ac87

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, 3 comments

    @greptile-apps
    Contributor

    greptile-apps bot commented Mar 3, 2026

    Additional Comments (1)

    csrc/scheduler/reduction.cpp
    The comment "Require that the innermost dim is contiguous." at line 242 is now misleading. The compile-time contiguity check (inputsHaveContiguousInnerDim) is already performed at line 207 earlier in this function. The vectorize_factor <= 1 check here gates on minimum vectorization efficiency for TMA, not on contiguity.

    Consider clarifying the comment to reflect its actual purpose:

      // Require at least 2x vectorization for TMA efficiency.
      if (props.vectorize_factor <= 1) {
        return false;
      }
    

    The same stale comment exists in mayUseTmaOuter at line 290.

    Comment on lines +3773 to +3775

        if (contig.empty()) {
          return false;
        }

    0-D tensor inputs unnecessarily block TMA for the entire fusion.

    When a fusion has a 0-D (scalar) TensorView input — e.g. a scale factor — contig.empty() is true and the function immediately returns false, disabling TMA for the entire fusion. A 0-D tensor has no innermost memory dimension and would never be loaded via TMA, so it should be skipped rather than treated as a failure.

    The same effect recurs later: when contig is empty, alloc_dom is also empty, the inner loop never executes, found_inner stays false, and return false is hit again — so the early guard is redundant with the existing !found_inner check.

    Consider skipping 0-D inputs instead:

    Suggested change

        // Before:
        if (contig.empty()) {
          return false;
        }

        // After:
        if (contig.empty()) {
          // 0-D tensor has no inner dim to load via TMA; skip it.
          continue;
        }
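The behavioral difference the reviewer is proposing can be sketched in isolation. This is a simplified model (contiguity vectors only; broadcast handling omitted for brevity), not the actual nvFuser helper:

```cpp
#include <optional>
#include <vector>

// Sketch of the suggested behavior: a 0-D input (empty contiguity vector)
// is skipped rather than failing the check for the whole fusion, since a
// 0-D tensor has no innermost memory dimension to load via TMA.
bool innerDimsContiguousSkip0D(
    const std::vector<std::vector<std::optional<bool>>>& inputs) {
  for (const auto& contig : inputs) {
    if (contig.empty()) {
      continue;  // 0-D tensor: nothing to transfer via TMA, don't block it
    }
    const auto& inner = contig.back();  // innermost dim (broadcasts ignored)
    if (!inner.has_value() || !*inner) {
      return false;  // unknown or non-contiguous innermost dim
    }
  }
  return true;
}
```

With this variant, a fusion mixing a 0-D scale factor with a contiguous tensor input would still qualify for TMA, whereas the PR's current early return disables it.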

    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Mar 4, 2026

    Additional Comments (1)

    csrc/scheduler/reduction.cpp
    Stale comment after compile-time contiguity check was separated.

    Now that compile-time contiguity is verified up-front by inputsHaveContiguousInnerDim (line 207), this comment is misleading. props.vectorize_factor is a runtime value derived from alignment, not a compile-time contiguity marker. The same misleading comment appears identically in mayUseTmaOuter (line 290).

    Update to reflect what this guard actually checks:

      // Require vectorization for TMA 128-bit alignment.
      if (props.vectorize_factor <= 1) {
        return false;
      }
    

