Conversation

Review updated until commit d49ac87

Description

Relevant files:
- Enhancement
- Bug fix
PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review
Function implementation correctness

The `inputsHaveContiguousInnerDim` implementation looks correct: it iterates from the innermost dimension outward, skips broadcast dimensions, and checks for contiguity. However, verify that the function handles these edge cases: (1) inputs whose dimensions are all broadcast (should return false), and (2) inputs with multiple consecutive broadcast dimensions followed by a contiguous dimension. The current logic with the `found_inner` flag appears to handle both correctly.
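The described logic can be sketched as follows. This is a hedged, self-contained model, not nvfuser's actual code: `InputDims` and its two flag vectors are hypothetical stand-ins for a TensorView's allocation domain (ordered outermost to innermost), with `std::nullopt` marking contiguity that is only known at runtime.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical model of one TensorView input (names illustrative, not
// nvfuser's real types). Both vectors run outermost -> innermost.
struct InputDims {
  std::vector<bool> is_broadcast;           // broadcast dims carry no data
  std::vector<std::optional<bool>> contig;  // nullopt = symbolic (runtime-only)
};

// Sketch of the check described above: every input must have an innermost
// non-broadcast dimension that is contiguous at compile time.
bool inputsHaveContiguousInnerDim(const std::vector<InputDims>& inputs) {
  for (const auto& tv : inputs) {
    if (tv.contig.empty()) {
      return false;  // 0-D input: no inner dimension at all
    }
    bool found_inner = false;
    // Walk from the innermost dimension outward.
    for (int i = static_cast<int>(tv.contig.size()) - 1;
         i >= 0 && !found_inner;
         --i) {
      if (tv.is_broadcast[i]) {
        continue;  // skip broadcast dimensions
      }
      // First non-broadcast dim decides: it must be known contiguous.
      if (!tv.contig[i].has_value() || !*tv.contig[i]) {
        return false;
      }
      found_inner = true;
    }
    if (!found_inner) {
      return false;  // edge case (1): only broadcast dimensions
    }
  }
  return true;
}
```

Both edge cases fall out naturally: an all-broadcast input exhausts the loop with `found_inner` still false, and consecutive broadcast dims ahead of a contiguous dim are simply skipped.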
Greptile Summary

This PR fixes a correctness bug where TMA schedulers (pointwise, inner/outer reduction, inner/outer-persistent normalization) would be selected for fusions whose input tensors have no compile-time contiguity annotation on the innermost dimension. Because TMA requires contiguity to be known at compile time, such fusions would fail to compile even though the underlying data might actually be contiguous at runtime. The fix introduces a single helper, `inputsHaveContiguousInnerDim`.

Key observations:
Confidence Score: 4/5
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[computeHeuristics / preferWarpSpecialized] --> B{Hardware check\nSM >= 9 / 10?}
    B -- No --> Z[Return false / fallback]
    B -- Yes --> C{inputsHaveContiguousInnerDim\nnew compile-time check}
    C -- false --> Z
    C -- true --> D{Other runtime checks\nvectorize_factor, smem, etc.}
    D -- fail --> Z
    D -- pass --> E[Use TMA scheduler]
    subgraph inputsHaveContiguousInnerDim
        F[For each TensorView input]
        F --> G{contig.empty?}
        G -- Yes --> H[return false ⚠️ 0-D tensors blocked]
        G -- No --> I[Walk alloc_dom from innermost]
        I --> J{isBroadcast?}
        J -- Yes --> I
        J -- No --> K{contig has_value && *contig == true?}
        K -- No --> L[return false]
        K -- Yes --> M[found_inner = true, break]
        M --> N{All TVs passed?}
        N -- Yes --> O[return true]
        N -- No --> L
    end
```
Last reviewed commit: d49ac87
Additional Comments (1)

Consider clarifying the comment to reflect its actual purpose. The same stale comment exists in another location as well.
```cpp
if (contig.empty()) {
  return false;
}
```
0-D tensor inputs unnecessarily block TMA for the entire fusion.
When a fusion has a 0-D (scalar) TensorView input — e.g. a scale factor — contig.empty() is true and the function immediately returns false, disabling TMA for the entire fusion. A 0-D tensor has no innermost memory dimension and would never be loaded via TMA, so it should be skipped rather than treated as a failure.
The same effect recurs later: when contig is empty, alloc_dom is also empty, the inner loop never executes, found_inner stays false, and return false is hit again — so the early guard is redundant with the existing !found_inner check.
Consider skipping 0-D inputs instead:
```diff
-if (contig.empty()) {
-  return false;
-}
+if (contig.empty()) {
+  // 0-D tensor has no inner dim to load via TMA; skip it.
+  continue;
+}
```
Additional Comments (1)

Now that compile-time contiguity is verified up-front by `inputsHaveContiguousInnerDim`, update the comment to reflect what this guard actually checks.
TMA requires contiguity to be set at compile time. The TMA schedulers were not checking for this; instead, they sometimes check `vectorization >= 2`, which is a runtime value. Symbolic tensors (not marked as contiguous) will fail to compile, even if they are contiguous at runtime. `ReductionWithIterDim` has many such cases, for example:

```
Error from segmentation group 0: Expected (definition() ->isStrictlyOneOf< ReductionOp, MmaOp, MatmulOp, LinearOp, GroupedMmaOp, ScaledMmaOp>()) . Error rfactoring T2_l_float[rblockIdx.y82{8}, rS83{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, rUR88{8}, rthreadIdx.y89{16}, rblockIdx.x84{( ceilDiv(i2, 128) )}, rthreadIdx.x86{32}, rV87{4}, rS10{i3}, iS11{i4}] because its definition is not a reduction. Exception raised from rFactor at /opt/pytorch/nvfuser/csrc/tensor_view.cpp:804 (most recent call first):
```
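The gap between the two checks can be illustrated with a toy model. This is a hypothetical sketch, not nvfuser's API: `InnerDim` stands in for the innermost dimension of an input, with `std::nullopt` modeling a symbolic (runtime-only) contiguity flag.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for an input's innermost dimension.
struct InnerDim {
  std::optional<bool> compile_time_contig;  // nullopt = symbolic
  int runtime_vectorize_factor;             // known only at runtime
};

// Old-style heuristic: runtime vectorization treated as "contiguous enough".
bool oldTmaCheck(const InnerDim& d) {
  return d.runtime_vectorize_factor >= 2;
}

// Fixed-style heuristic: contiguity must be known true at compile time.
bool newTmaCheck(const InnerDim& d) {
  return d.compile_time_contig.value_or(false);
}
```

A symbolic input that happens to be contiguous at runtime passes the old check but fails the new one, which is exactly the case that previously led to the compilation error above.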