add auto tma transpose scheduler #6018

Open
liqiangxl wants to merge 7 commits into main from llu/transpose_output_smem_auto

Conversation
Conversation

@liqiangxl
Collaborator

To reduce the number of transpose ops, is_output_smem_transpose is added to control whether the transpose happens on the input or the output side:

1. When there are more inputs than outputs, is_output_smem_transpose = true:
TMA load without swizzle, TMA store with swizzle, transpose at regs --> output cached smem

2. When there are fewer inputs than outputs, is_output_smem_transpose = false:
TMA load with swizzle, register store, transpose at input cached smem --> regs

Current performance is in this doc.

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

@liqiangxl liqiangxl marked this pull request as ready for review February 27, 2026 15:40
@greptile-apps
Contributor

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR implements the auto TMA transpose scheduler, introducing a new TmaTranspose enable option and replacing the previous stub with a fully functional getTransposeHeuristics / scheduleTranspose pair in transpose_tma.cpp. The scheduler selects between two transpose strategies based on input/output tensor counts: when there are more inputs than outputs (is_output_smem_transpose = true), TMA store with swizzle is used on the output side; otherwise TMA load with swizzle operates on the input side. Supporting changes include new fields on TransposeParams, a C++20 ranges refactor in tma.cpp, and std::uint8_t underlying types on all option enums.

Key issues found:

  • Infinite loop in getTransposeHeuristics: if bytes_per_cta / n_input < kTmaSwizzleBytes (128), integer division produces estimated_tile_size1 = 0; the while loop multiplying by 2 never converges (0 * 2 == 0). This can be triggered when n_input > 64 on certain GPU configs.
  • Missing guard in scheduleTranspose: tma_store_tvs.at(0) is accessed whenever is_output_smem_transpose == true without first asserting use_tma_store == true, unlike the symmetric check already present for the input-smem path.
  • Debug std::cout left in OutputTransposeBankconflict test — fires on test failure and should use GTest facilities instead.
  • Typos ("tranapose" → "transpose") in comment lines describing test parameter combinations.

Confidence Score: 3/5

  • PR is functionally correct for typical 1–2 input transpose fusions, but contains an infinite-loop risk and a potential crash for less common parameter combinations.
  • The PR introduces the TMA transpose scheduler with correct overall logic for standard use cases (1–2 inputs). However, two logic issues lower confidence:
  1. Infinite loop when estimated_tile_size1 = 0: This occurs when n_input > 64 on SM90, a realistic edge case for more complex fusion patterns. The while loop condition guarantees termination only if estimated_tile_size1 > 0.

  2. Unguarded tma_store_tvs.at(0) access: While the default heuristic ties is_output_smem_transpose and use_tma_store together, the API allows independent override (as shown in parameterized tests). Missing the guard that exists in the symmetric input-smem path is an asymmetry that could cause crashes if those params are misused.

Neither issue is triggered by the test suite as written, but both represent real failure modes for edge-case inputs or manually-constructed params. The code quality issues (test cout and typos) are minor.

  • csrc/scheduler/transpose_tma.cpp — review the tile-size heuristic initialization and the is_output_smem_transpose / use_tma_store constraint in scheduleTranspose.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose enabled?}
    B -- Yes --> C[tma::getTransposeHeuristics]
    B -- No --> E[non_tma::getTransposeHeuristics]
    C --> D{n_input > n_output?}
    D -- Yes\nis_output_smem_transpose=true --> F[use_tma_load=true\nuse_tma_store=true\nswizzle on output smem]
    D -- No\nis_output_smem_transpose=false --> G[use_tma_load=true\nuse_tma_store=false\nswizzle on input smem]
    F --> H[Return TransposeParams]
    G --> H
    C -- returns null --> E
    E --> H
    H --> I[TransposeScheduler::schedule]
    I --> J{use_tma_load OR use_tma_store?}
    J -- Yes --> K[tma::scheduleTranspose]
    J -- No --> L[non_tma::scheduleTranspose]
    K --> M{is_output_smem_transpose?}
    M -- true --> N[TMA load w/o swizzle\nTMA store w/ MmaSwizzle\ntranspose at regs→output smem]
    M -- false --> O[TMA load w/ MmaSwizzle\nregister store\ntranspose at input smem→regs]

Last reviewed commit: bc772db

Contributor

@greptile-apps greptile-apps bot left a comment

7 files reviewed, 2 comments

NVF_ERROR(grouped_inputs_outputs.size() >= 2);

// When there are more inputs than outputs, output smem transpose should be
// used, however, if it is not, then input smem tranpose will be used, to
Contributor

tranpose should be transpose

const int64_t cta_per_sm =
dev_props->maxThreadsPerMultiProcessor / threads_per_cta;
const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;
const int64_t bytes_per_tile = bytes_per_cta / n_input;
Contributor

Add check that n_input > 0 before this division. While the scheduler validation should prevent this, defensive programming would make the code more robust.

Suggested change
const int64_t bytes_per_tile = bytes_per_cta / n_input;
NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose");
const int64_t bytes_per_tile = bytes_per_cta / n_input;


@liqiangxl liqiangxl requested a review from rdspring1 February 27, 2026 17:24
@github-actions

github-actions bot commented Mar 2, 2026

Review updated until commit bc772db

Description

  • Implements automatic TMA (Tensor Memory Accelerator) transpose scheduler with two paths: input smem transpose (swizzle on input) and output smem transpose (swizzle on output)

  • Adds new TmaTranspose enable option to toggle the feature; scheduler falls back to non-TMA when disabled

  • Introduces new parameters: use_tma_store, is_output_smem_transpose, chunks_per_thread, elements_per_chunk for flexible TMA configuration

  • Adds comprehensive tests covering different dtypes, transpose dimensions, and TMA parameter combinations

Changes walkthrough

Relevant files

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Potential TMA load restriction

The new code filters loop domains to only include non-trivial IDs (extent > 1 or non-const) before checking for thread/serial dims.
This is more restrictive than the original which checked all loop domains. This could potentially exclude valid TMA loads
where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break
existing TMA use cases.

auto non_trivial_ids =
    tv->getLoopDomain() | std::views::filter([](const IterDomain* id) {
      return !id->extent()->isConstScalar() ||
          id->extent()->evaluate().as<int64_t>() > 1;
    });
if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
  return {};
}
Missing null check

In scheduleTranspose, when setting up TMA store (lines 165-172), the code accesses fusion->outputs()[output_idx] without
checking if output_idx is within bounds. While cached_outputs should correspond to outputs, a bounds check would be safer.

for (auto [cached_output, output_idx] : cached_outputs) {
  auto output = fusion->outputs()[output_idx]->as<TensorView>();
  output->definition()->as<LoadStoreOp>()->setOpType(
      LoadStoreOpType::CpAsyncBulkTensorTile);
  cached_output->setMemoryType(MemoryType::Shared);
  cached_output->cacheBefore();
  tma_store_tvs.push_back(cached_output);
}
Thread safety consideration

The copy constructor was modified to use a lambda that captures other.mutex_ and returns other.options_. While this appears
correct, the original implementation directly assigned options_. The new approach should be verified to maintain the same
thread-safety semantics under concurrent access patterns.

Options(const Options& other)
    : options_([&other]() {
        std::lock_guard<std::mutex> lock_other(other.mutex_);
        return other.options_;
      }()) {}

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (2)

csrc/scheduler/transpose_tma.cpp, line 106
Infinite loop when estimated_tile_size1 starts at zero

If bytes_per_tile < kTmaSwizzleBytes (line 91-92), integer division yields estimated_tile_size1 = 0. The while loop (line 104) then spins forever because 0 * 2 == 0 and get_chunks_per_thread() (line 98-102) stays at 0, which is always less than min_chunks_per_thread = 4.

On an H100 (maxThreadsPerMultiProcessor = 2048, cta_per_sm = 8, bytes_per_cta = 8192), this triggers when n_input > 64. Add an initialization guard before the loop:

  // Ensure we start from at least 1 to avoid multiplying 0 forever.
  if (estimated_tile_size1 == 0) {
    estimated_tile_size1 = 1;
  }
  while (get_chunks_per_thread() < min_chunks_per_thread) {
    estimated_tile_size1 *= 2;
  }

tests/cpp/test_transpose.cpp, line 1947
Unconditional debug output will pollute test logs

The std::cout block (lines 1945–1947) prints every bank conflict unconditionally. This makes test runner output noisy, especially since the BFloat16 path is expected to have bank conflicts. Consider wrapping the print in a debug flag or removing it:

      if (auto* ke = dynamic_cast<KernelExecutor*>(executor.get())) {
        auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel());
        if (dtype == DataType::Float) {
          EXPECT_TRUE(bank_conflicts.empty());
        } else {
          // TODO: update to EXPECT_TRUE once bf16 bank conflicts are resolved.
          EXPECT_FALSE(bank_conflicts.empty());
        }
      }

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (4)

csrc/scheduler/transpose_tma.cpp, line 107
Potential infinite loop when estimated_tile_size1 initializes to zero

If bytes_per_tile < kTmaSwizzleBytes (128), integer division yields estimated_tile_size1 = 0. The while loop then evaluates get_chunks_per_thread() as 0 (because the numerator is 0 * tile_size2 = 0) and multiplies: 0 * 2 = 0 — the loop never terminates.

This happens when bytes_per_cta / n_input < 128. With an SM90 GPU (maxThreadsPerMultiProcessor = 2048), cta_per_sm = 8, giving bytes_per_cta = 8192. So the loop infinite-hangs when n_input > 64.

While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise estimated_tile_size1 to at least 1:

int64_t estimated_tile_size1 =
    std::max(int64_t(1), bytes_per_tile / kTmaSwizzleBytes);

csrc/scheduler/transpose_tma.cpp, line 267
Missing guard before accessing tma_store_tvs when use_tma_store may be false

tma_store_tvs is only populated when tparams->use_tma_store == true (lines 164–173), but this block checks only tparams->is_output_smem_transpose. If is_output_smem_transpose = true but use_tma_store = false, then tma_store_tvs will be empty and .at(0) throws std::out_of_range.

Note the asymmetry: Step 3 already guards the analogous constraint with an explicit NVF_ERROR(tparams->use_tma_load, ...) at line 286-288. Adding the same guard here would be consistent:

if (tparams->is_output_smem_transpose) {
    NVF_ERROR(
        tparams->use_tma_store,
        "TMA store must be used when output smem is transposed");
    MmaInputSmemSwizzle swizzle =
        mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0));

tests/cpp/test_transpose.cpp, line 1949
Debug std::cout in test code — use GTest facilities instead

These std::cout lines will only fire when bank conflicts are detected (when the test is already failing). However, raw std::cout in tests is unconventional — GTest's ADD_FAILURE() / SCOPED_TRACE or just the EXPECT_TRUE failure message would be more idiomatic:

      for (auto& [expr, ways] : bank_conflicts) {
        auto [read_ways, write_ways] = ways;
        ADD_FAILURE() << "Bank conflict in: " << expr->toString()
                      << "  read=" << read_ways << "-way"
                      << ", write=" << write_ways << "-way";
      }



tests/cpp/test_transpose.cpp, line 1969
Typo "tranapose" should be "transpose" in multiple lines

// Test different combinations of TMA transpose parameters:
// (is_output_smem, use_tma_load, use_tma_store)
//   (false, true, false)  - input smem transpose, TMA load only
//   (false, true, true)   - input smem transpose, TMA load + TMA store
//   (true,  true, true)   - output smem transpose, TMA load + TMA store
//   (true,  false, true)  - output smem transpose, TMA store only
