From cd7767934abd05fff1a677c2e43e1a73e7e64001 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 27 Feb 2026 07:26:50 -0800 Subject: [PATCH 1/8] add auto tma transpose scheduler --- csrc/device_lower/analysis/tma.cpp | 16 +- csrc/options.cpp | 3 +- csrc/options.h | 19 +- csrc/scheduler/transpose.cpp | 10 +- csrc/scheduler/transpose_heuristic.h | 32 ++- csrc/scheduler/transpose_tma.cpp | 351 ++++++++++++++++++++++++++- tests/cpp/test_transpose.cpp | 76 ++++++ 7 files changed, 482 insertions(+), 25 deletions(-) diff --git a/csrc/device_lower/analysis/tma.cpp b/csrc/device_lower/analysis/tma.cpp index 3dedd9413f1..b1d45fa2a63 100644 --- a/csrc/device_lower/analysis/tma.cpp +++ b/csrc/device_lower/analysis/tma.cpp @@ -53,13 +53,15 @@ std::unordered_set getBatchableTmaLoads( // We have some tests where TMA load is used in an untraditional way. // e.g. parallelized with threads, serial load, which requires multiple // mbarriers or reuse of the same mbarrier. - if (std::any_of( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - [](const IterDomain* id) { - return id->isThreadDim() || - id->getParallelType() == ParallelType::Serial; - })) { + auto non_trivial_ids = + tv->getLoopDomain() | std::views::filter([](const IterDomain* id) { + return !id->extent()->isConstScalar() || + id->extent()->evaluate().as() > 1; + }); + if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) { + return id->isThreadDim() || + id->getParallelType() == ParallelType::Serial; + })) { return {}; } non_cb_tma_load_exprs.push_back(expr); diff --git a/csrc/options.cpp b/csrc/options.cpp index 14dddd89eec..775f76f26ea 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -40,7 +40,7 @@ auto parseEnvOptions( available_options.end(), std::back_inserter(option_values), [](const auto& kv) { return kv.first; }); - std::sort(option_values.begin(), option_values.end()); + std::ranges::sort(option_values); NVF_CHECK( false, "Parsing ", @@ -174,6 +174,7 @@ const std::unordered_map& 
getEnableOptions() { {"tma_pointwise", EnableOption::TmaPointwise}, {"tma_inner_persistent", EnableOption::TmaInnerPersistent}, {"tma_reduction", EnableOption::TmaReduction}, + {"tma_transpose", EnableOption::TmaTranspose}, {"ws_normalization", EnableOption::WarpSpecializedNormalization}, {"host_ir_lowering", EnableOption::HostIrLowering}, {"host_ir_jit", EnableOption::HostIrJit}, diff --git a/csrc/options.h b/csrc/options.h index 6dad3909997..9a9063c2f3b 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -22,7 +23,7 @@ namespace nvfuser { //! //! These can be set through the `NVFUSER_DUMP` environment variable //! -enum class DebugDumpOption { +enum class DebugDumpOption : std::uint8_t { CutlassCompile, //!< Dump compile commands and compile times for //!< CutlassExecutor FunctionTrace, //!< Dump the function trace of selected internal function. The @@ -97,7 +98,7 @@ enum class DebugDumpOption { //! //! These can be set through the `NVFUSER_ENABLE` environment variable //! -enum class EnableOption { +enum class EnableOption : std::uint8_t { CutlassScheduler, //! Enable the CUTLASS scheduler and executor FuseMatmul, //! Enable automatic fusion of matmul and linear ops FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel @@ -118,6 +119,7 @@ enum class EnableOption { TmaPointwise, //! Enable TMA pointwise kernel TmaInnerPersistent, //! Enable TMA inner persistent kernel TmaReduction, //! Enable TMA reduction kernel + TmaTranspose, //! Enable TMA transpose kernel WarpSpecializedNormalization, //! Enable warp specialized persistent kernel HostIrLowering, //! Enable FusionKernelRuntime lowering to host IR HostIrJit, //! Enable Host IR JIT compilation with LLVM @@ -134,7 +136,7 @@ enum class EnableOption { //! //! These can be set through the `NVFUSER_DISABLE` environment variable //! -enum class DisableOption { +enum class DisableOption : std::uint8_t { CompileToSass, //! 
Disable direct compilation to sass so the ptx can be //! examined ContigIndexing, //! Disable contiguous indexing @@ -176,7 +178,7 @@ enum class DisableOption { //! //! These can be set through the `NVFUSER_PROF` environment variable //! -enum class ProfilerOption { +enum class ProfilerOption : std::uint8_t { Enable, //! Enables the profiler. EnableNocupti, //! Enables the profiler, but disables CUPTI specific //! profiling inorder to measure true host time without @@ -197,10 +199,11 @@ class Options { public: Options() : options_(getOptionsFromEnv()) {} - Options(const Options& other) { - std::lock_guard lock_other(other.mutex_); - options_ = other.options_; - } + Options(const Options& other) + : options_([&other]() { + std::lock_guard lock_other(other.mutex_); + return other.options_; + }()) {} Options& operator=(const Options& other) { std::lock_guard lock_other(other.mutex_); diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 8dacde8d3b1..5487d59493a 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -407,9 +407,11 @@ std::unique_ptr TransposeScheduler::computeHeuristics( std::unique_ptr tparams = nullptr; - // Try TMA path first - tparams = - transpose::tma::getTransposeHeuristics(fusion, runtime_info, data_cache); + // Try TMA path first if enabled + if (isOptionEnabled(EnableOption::TmaTranspose)) { + tparams = transpose::tma::getTransposeHeuristics( + fusion, runtime_info, data_cache); + } // Fallback to non-TMA scheduler if TMA is not applicable if (tparams == nullptr) { @@ -431,7 +433,7 @@ void TransposeScheduler::schedule( "Incorrect parameters sent to TransposeScheduler::schedule", params); - if (tparams->use_tma_load) { + if (tparams->use_tma_load || tparams->use_tma_store) { transpose::tma::scheduleTranspose(fusion, tparams); } else { transpose::non_tma::scheduleTranspose(fusion, tparams); diff --git a/csrc/scheduler/transpose_heuristic.h b/csrc/scheduler/transpose_heuristic.h index 
96db01a6410..838f5df2d25 100644 --- a/csrc/scheduler/transpose_heuristic.h +++ b/csrc/scheduler/transpose_heuristic.h @@ -22,7 +22,7 @@ namespace nvfuser { // are equivelent! class TransposeParams : public HeuristicParams { public: - TransposeParams() : HeuristicParams(SchedulerType::Transpose) {}; + TransposeParams() : HeuristicParams(SchedulerType::Transpose){}; static constexpr int64_t getMaxThreadsPerBlock() { return 128; } @@ -39,6 +39,20 @@ class TransposeParams : public HeuristicParams { // Whether to use TMA for loading inputs bool use_tma_load = false; + bool use_tma_store = false; + + // Which side of shared memory holds the transposed (swizzled) layout. + // false: input smem is swizzled, transpose happens on smem->register read. + // true: output smem is swizzled, transpose happens on register->smem write. + // This is independent of use_tma_load/use_tma_store — TMA can be used for + // either side regardless of where the transpose swizzle lives. + bool is_output_smem_transpose = false; + + // In 128-bytes swizzled tma load, inner most dim is split into 8 chunks each + // with 16 bytes. Each thread may handle multiple chunks along the inner most + // dim. 
+ int64_t chunks_per_thread = 1; + int64_t elements_per_chunk = 1; // Vectorization factor for tensors in the first group int64_t vectorize_factor1 = 1; @@ -65,6 +79,10 @@ class TransposeParams : public HeuristicParams { } bool attr_equal = other->cparams == cparams && other->use_tma_load == use_tma_load && + other->use_tma_store == use_tma_store && + other->is_output_smem_transpose == is_output_smem_transpose && + other->chunks_per_thread == chunks_per_thread && + other->elements_per_chunk == elements_per_chunk && other->split_before_tiling == split_before_tiling && other->dims_merged_with_1 == dims_merged_with_1 && other->dims_merged_with_2 == dims_merged_with_2 && @@ -99,6 +117,14 @@ class TransposeParams : public HeuristicParams { if (unroll_factor2 > 1) { ss << "Unroll group 2, Factor: " << unroll_factor2 << "\n"; } + if (use_tma_load || use_tma_store) { + ss << "TMA: load=" << (use_tma_load ? "true" : "false") + << " store=" << (use_tma_store ? "true" : "false") + << " is_output_smem_transpose=" + << (is_output_smem_transpose ? 
"true" : "false") + << " chunks_per_thread=" << chunks_per_thread + << " elements_per_chunk=" << elements_per_chunk << "\n"; + } if (!split_before_tiling.empty() || !dims_merged_with_1.empty() || !dims_merged_with_2.empty()) { ss << "Virtual inner-most dim:\n"; @@ -146,6 +172,10 @@ class TransposeParams : public HeuristicParams { size_t hash() const override { return c10::get_hash( use_tma_load, + use_tma_store, + is_output_smem_transpose, + chunks_per_thread, + elements_per_chunk, split_before_tiling, dims_merged_with_1, dims_merged_with_2, diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index d4f2a64ca40..caa72c19327 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -8,20 +8,363 @@ #include "scheduler/transpose_tma.h" +#include +#include "ir/utils.h" +#include "scheduler/mma_utils.h" +#include "scheduler/runtime_info.h" +#include "scheduler/tools/domain_map.h" +#include "scheduler/tools/inlining.h" +#include "scheduler/utils.h" +#include "transform_replay.h" +#include "type.h" namespace nvfuser { namespace transpose { namespace tma { - +constexpr int64_t kBytesPerChunk = 16; +constexpr int64_t kTmaSwizzleBytes = 128; std::unique_ptr getTransposeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicDataCache* data_cache) { - // TMA transpose scheduling is not yet implemented. 
- return nullptr; + auto tparams = std::make_unique(); + tparams->tag = "TMA Transpose heuristics"; + tparams->cparams.index_type = runtime_info.getIndexType(); + + int64_t max_input_dtype_size = 1; + int64_t n_input = 0; + for (auto inp : ir_utils::filterByType(fusion->inputs())) { + max_input_dtype_size = std::max( + max_input_dtype_size, + dataTypeSizeByte(valueOrError(inp->getDataType()))); + n_input++; + } + + int64_t max_output_dtype_size = 1; + int64_t n_output = 0; + for (auto out : ir_utils::filterByType(fusion->outputs())) { + max_output_dtype_size = std::max( + max_output_dtype_size, + dataTypeSizeByte(valueOrError(out->getDataType()))); + n_output++; + } + + // Choose between input smem transpose and output smem transpose. + // Input smem transpose: transpose happens when reading from swizzled input + // shared memory to registers. columns to rows. + // Output smem transpose: transpose happens when writing from registers to + // swizzled output shared memory. rows to columns. + // + // Use input smem transpose when inputs <= outputs to reduce smem usage and + // swizzle cost (fewer inputs to swizzle). Use output smem transpose when + // inputs > outputs for the same reason (fewer outputs to swizzle). + // + // TMA load/store are independent of the transpose direction: + // - TMA load stages inputs in shared memory (always beneficial for input + // smem transpose since inputs need swizzled smem anyway). + // - TMA store writes outputs from shared memory (always beneficial for + // output smem transpose since outputs need swizzled smem anyway). + tparams->is_output_smem_transpose = n_input > n_output; + tparams->use_tma_load = true; + tparams->use_tma_store = tparams->is_output_smem_transpose; + + // Inputs and outputs are grouped into two groups based on their inner most + // dim. The group with smaller number of tvs is swizzled in shared memory. TMA + // scheduler assumes all inputs are in the same group and all outputs are in + // the same group. 
+ // tile_size2 is the tile size for the inner most dim of the group with + // swizzled smem (group2), it should follow the restriction of TMA swizzle + // size. + int64_t swizzle_dtype_size = tparams->is_output_smem_transpose + ? max_output_dtype_size + : max_input_dtype_size; + int64_t constrained_tile = kTmaSwizzleBytes / swizzle_dtype_size; + tparams->tile_size2 = constrained_tile; + + // Fixed, 16 bytes per chunk for swizzle, 4 float32 or 8 float16 elements. + tparams->elements_per_chunk = kBytesPerChunk / swizzle_dtype_size; + + // free dim that can be tuned, increase tile size when input count is small + // assuming issue 64KB loading data per sm, 256 threads per cta. + auto dev_props = at::cuda::getCurrentDeviceProperties(); + constexpr int64_t bytes_per_sm = 64 * 1024; + constexpr int64_t threads_per_cta = 256; + const int64_t cta_per_sm = + dev_props->maxThreadsPerMultiProcessor / threads_per_cta; + const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm; + const int64_t bytes_per_tile = bytes_per_cta / n_input; + int64_t estimated_tile_size1 = bytes_per_tile / kTmaSwizzleBytes; + + // tile1 * tile2 = elements_per_chunk * chunks_per_thread * threads_per_cta + // May further increase tile1 to ensure each thread has at least 4 chunks to + // process for better efficiency. 
+ constexpr int64_t min_chunks_per_thread = 4; + auto get_chunks_per_thread = [&]() { + int64_t elements_per_thread = + estimated_tile_size1 * tparams->tile_size2 / threads_per_cta; + return elements_per_thread / tparams->elements_per_chunk; + }; + while (get_chunks_per_thread() < min_chunks_per_thread) { + estimated_tile_size1 *= 2; + } + tparams->tile_size1 = estimated_tile_size1; + tparams->chunks_per_thread = get_chunks_per_thread(); + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== TMA Transpose Stats ========\n" + << "inputs: " << ir_utils::toString(fusion->inputs()) << "\n" + << "outputs: " << ir_utils::toString(fusion->outputs()) << "\n" + << "is_output_smem_transpose: " << tparams->is_output_smem_transpose + << "\n" + << "use_tma_load: " << tparams->use_tma_load << "\n" + << "use_tma_store: " << tparams->use_tma_store << "\n" + << "tile_size1: " << tparams->tile_size1 << "\n" + << "tile_size2: " << tparams->tile_size2 << "\n" + << "chunks_per_thread: " << tparams->chunks_per_thread << "\n" + << "elements_per_chunk: " << tparams->elements_per_chunk << "\n" + << "\n"; + } + return tparams; } void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { - NVF_THROW("TMA transpose scheduling is not yet implemented."); + FusionGuard fg(fusion); + + scheduler_utils::clearMemorySpace(fusion); + + NVF_ERROR( + !ir_utils::hasAnyReductionOps(fusion), + "This scheduler only handles pointwise ops."); + + auto cached_inputs = scheduler_utils::cacheInputs(fusion, true); + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); + + scheduler_utils::prepareForMemoryTypePromotion(fusion); + + // Set up TMA load for inputs. 
+ std::vector tma_load_tvs; + for (auto [cached_input, input_idx] : cached_inputs) { + if (!tparams->use_tma_load) { + continue; + } + auto load_op = dynamic_cast(cached_input->definition()); + if (load_op == nullptr) { + continue; + } + load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile); + cached_input->setMemoryType(MemoryType::Shared); + cached_input->cacheAfter(); + tma_load_tvs.push_back(cached_input); + } + + // Set up output caching with TMA store when enabled. + std::vector tma_store_tvs; + std::vector output_tvs; + for (auto [cached_output, output_idx] : cached_outputs) { + auto output = fusion->outputs()[output_idx]->as(); + output_tvs.push_back(output); + if (!tparams->use_tma_store) { + continue; + } + output->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); + cached_output->setMemoryType(MemoryType::Shared); + cached_output->cacheBefore(); + tma_store_tvs.push_back(cached_output); + } + + // Find transposed ids and positions, two groups, transpose happens in + // group-2's cached smem. + scheduler_tools::TransposeDomainMap domain_map(fusion); + auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); + NVF_ERROR(grouped_inputs_outputs.size() >= 2); + + // When there are more inputs than outputs, output smem transpose should be + // used, however, if it is not, then input smem tranpose will be used, to + // ensure group2 is always the one that is transposed, we should swap group1 + // and group2. 
+ if (!tparams->is_output_smem_transpose && + cached_inputs.size() > cached_outputs.size()) { + std::swap(grouped_inputs_outputs[0], grouped_inputs_outputs[1]); + } + + TensorView* reference1 = + domain_map.findReferenceFor(grouped_inputs_outputs[0]); + TensorView* reference2 = + domain_map.findReferenceFor(grouped_inputs_outputs[1]); + NVF_ERROR( + reference1 != nullptr, "Unable to find reference tensor for group 1"); + NVF_ERROR( + reference2 != nullptr, "Unable to find reference tensor for group 2"); + + // Step 1: Tile two transpose dimensions on reference1, merge all other + // dimensions into BIDx, and propagate tiling to the entire DAG. + auto inner_most_id1 = scheduler_utils::innerMostAllocDim(reference1); + auto inner_most_id2 = scheduler_utils::innerMostAllocDim(reference2); + int64_t inner_most_pos1 = + domain_map.getInnerLeafDim(reference1, inner_most_id1); + int64_t inner_most_pos2 = + domain_map.getInnerLeafDim(reference1, inner_most_id2); + NVF_ERROR( + inner_most_pos1 >= 0 && inner_most_pos2 >= 0 && + inner_most_pos1 != inner_most_pos2, + "Invalid inner dim positions for TMA tiling"); + + TensorView* ref_tv = reference1; + if (reference1->isFusionInput() && tparams->use_tma_load) { + // can't propagate due to tma load + auto smem_consumer = ir_utils::consumerTvsOf(reference1).at(0); + auto regs_consumer = ir_utils::consumerTvsOf(smem_consumer).at(0); + ref_tv = regs_consumer; + } else if (reference1->isFusionOutput() && tparams->use_tma_store) { + // can't propagate due to tma store + auto smem_producer = ir_utils::getSoleProducerTv(reference1); + auto regs_producer = ir_utils::getSoleProducerTv(smem_producer); + ref_tv = regs_producer; + } + + // make tile, group2 is swizzled, its inner most dim is tile2 + // [..., I1, .., I2, ...] 
+ ref_tv->split(inner_most_pos1, tparams->tile_size1); + ref_tv->reorder({{inner_most_pos1 + 1, -1}}); + ref_tv->split(inner_most_pos2, tparams->tile_size2); + ref_tv->reorder({{inner_most_pos2 + 1, -1}}); + // [..., I1/tile1, .., I2/tile2, ..., tile1, tile2] + + // Merge all non-tiled dimensions into a single BIDx dim + int64_t rhs_i = ref_tv->nDims() - 3; + for (int64_t lhs_i = ref_tv->nDims() - 4; lhs_i >= 0; lhs_i--) { + if (ref_tv->axis(lhs_i)->isReduction() || + ref_tv->axis(lhs_i)->isDeviceDim()) { + continue; + } + if (ref_tv->axis(rhs_i)->isReduction() || + ref_tv->axis(rhs_i)->isDeviceDim()) { + rhs_i = lhs_i; + continue; + } + ref_tv->merge(lhs_i, rhs_i); + rhs_i = lhs_i; + } + ref_tv->axis(rhs_i)->parallelize(ParallelType::BIDx); + + // Propagate tiling to all TVs + TransformPropagator propagator(ref_tv); + MaxLogicalDomainInfoSpanningTree entire_dag(ref_tv); + entire_dag.traverse(&propagator); + scheduler_utils::parallelizeAllLike( + ref_tv, + /*selected_tvs=*/{}, + /*selected_parallel_types=*/{}, + /*propagate_padding=*/true, + /*parallelize_inputs_on_did=*/true); + + // Step 2: Schedule output TMA store. + if (tparams->is_output_smem_transpose) { + // Output smem path: swizzle output shared memory with TMA store. + // Reorder so output inner dim is innermost, then apply TMA swizzle. + MmaInputSmemSwizzle swizzle = + mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0)); + for (auto output_smem_cache : tma_store_tvs) { + mma_utils::scheduleTMAStoreForMmaOutput(output_smem_cache, swizzle); + } + for (auto output : output_tvs) { + mma_utils::scheduleTMAStoreForMmaOutput(output, swizzle); + } + } else if (tparams->use_tma_store) { + // Input smem path with TMA store: Bulk parallel on tile dims. 
+ for (auto output_smem_cache : tma_store_tvs) { + // [.., tile1, tile2] + output_smem_cache->reorder({{-1, -2}}); + // [.., tile2, tile1] + output_smem_cache->setAllocationDomain( + output_smem_cache->getLoopDomain(), true); + } + for (auto output : output_tvs) { + output->axis(-1)->parallelize(ParallelType::Bulk); + output->axis(-2)->parallelize(ParallelType::Bulk); + } + } + + // Step 3: Schedule input shared memory. + if (!tparams->is_output_smem_transpose) { + NVF_ERROR( + tparams->use_tma_load, + "TMA load must be used when input smem is transposed"); + // TMA load and swizzle + for (auto input_smem_cache : tma_load_tvs) { + MmaInputSmemSwizzle swizzle_type = + mma_utils::tmaSwizzleSharedMemory(input_smem_cache); + input_smem_cache->applyMmaSwizzleForTMALoad(swizzle_type); + } + } else if (tparams->use_tma_load) { + // TMA load without swizzle, just contiguous load with Bulk parallel. + // Needs to move tile_1 to inner most for contiguous access. + for (auto input_smem_cache : tma_load_tvs) { + // [.., tile1, tile2] + input_smem_cache->reorder({{-1, -2}}); + // [.., tile2, tile1] + input_smem_cache->axis(-1)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(-2)->parallelize(ParallelType::Bulk); + input_smem_cache->setAllocationDomain( + input_smem_cache->getLoopDomain(), true); + } + } + + // Step 4: Schedule register TVs for per-thread access. + // Tile-2 was swizzled in smem, per-thread access should follow the swizzed + // layout. + // 1. split tile2 by elements_per_chunk defined in swizzle pattern + // 2. further split by chunks_per_thread to get the right granularity for each + // thread + // 3. 
merge remainings with tile1 and parallelize with TIDx + // [BIDx, tile1, tile2] + ref_tv->split(-1, tparams->elements_per_chunk); + // [BIDx, tile1, tile2/chunk, chunk] + ref_tv->split(-2, tparams->chunks_per_thread); + // [BIDx, tile1, tile2/chunk/cpt, cpt, chunk] + ref_tv->merge(-4, -3); + // [BIDx, tile1/chunk/cpt * tile2, cpt, chunk] + ref_tv->axis(-3)->parallelize(ParallelType::TIDx); + // ref_tv->axis(-1)->parallelize(ParallelType::Unroll); + + // Propagate to all TVs except smem/output TVs managed by TMA + std::unordered_set skip_tvs; + if (!tma_load_tvs.empty()) { + skip_tvs.insert(tma_load_tvs.begin(), tma_load_tvs.end()); + } + if (tparams->use_tma_store) { + skip_tvs.insert(output_tvs.begin(), output_tvs.end()); + } + auto propagate_tvs = ir_utils::allTvsExcept(fusion, skip_tvs); + std::unordered_set propagate_tvs_set( + propagate_tvs.begin(), propagate_tvs.end()); + SetSelector selector(propagate_tvs_set); + MaxLogicalDomainInfoSpanningTree propagate_dag(ref_tv, &selector); + TransformPropagator tp(ref_tv); + propagate_dag.traverse(&tp); + scheduler_utils::parallelizeAllLike( + ref_tv, + propagate_tvs, + {}, + /*propagate_padding=*/true, + /*parallelize_inputs_on_did=*/true); + + // Vectorize smem access at the transpose boundary. 
+ if (tparams->is_output_smem_transpose) { + // Vectorize writes to output smem + for (auto output_smem_cache : tma_store_tvs) { + output_smem_cache->axis(-1)->parallelize(ParallelType::Vectorize); + } + } else { + // Vectorize reads from input smem + for (auto tma_load_tv : tma_load_tvs) { + for (auto consumer : ir_utils::consumerTvsOf(tma_load_tv)) { + consumer->axis(-1)->parallelize(ParallelType::Vectorize); + } + } + } + + inlineMost(); } } // namespace tma diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index c68032a7e0b..3d823b1fbee 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1821,4 +1821,80 @@ TEST_F(TransposeTMA, TransposeOutputSmem) { testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__); } +// dtype, pair of transpose dimensions, inner most dim of input tensor +using TmaTransposeTestParams = + std::tuple, int64_t>; +class TmaTransposeTestP + : public TransposeTest, + public testing::WithParamInterface { + protected: + void SetUp() override { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + TransposeTest::SetUp(); + EnableOptionsGuard::getCurOptions().set(EnableOption::TmaTranspose); + } +}; + +TEST_P(TmaTransposeTestP, InputTranspose) { + auto dtype = std::get<0>(GetParam()); + auto [dim1, dim2] = std::get<1>(GetParam()); + auto inner_dim = std::get<2>(GetParam()); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + auto tv0 = makeContigTensor(3, dtype); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, dim1, dim2); + fusion.addOutput(tv1); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn({128, 256, inner_dim}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); +} + +TEST_P(TmaTransposeTestP, OutputTranspose) { + auto 
dtype = std::get<0>(GetParam()); + auto [dim1, dim2] = std::get<1>(GetParam()); + auto inner_dim = std::get<2>(GetParam()); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + auto tv0 = makeContigTensor(3, dtype); + auto tv1 = makeContigTensor(3, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = transpose(tv2, dim1, dim2); + fusion.addOutput(tv3); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn({128, 256, inner_dim}, options); + auto t1 = at::randn({128, 256, inner_dim}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeTest, + TmaTransposeTestP, + testing::Combine( + testing::Values(DataType::Float, DataType::BFloat16), + testing::Values( + std::make_pair(int64_t(0), int64_t(2)), + std::make_pair(int64_t(1), int64_t(2))), + testing::Values(1, 2, 32, 64, 128, 256, 512)), + [](const testing::TestParamInfo& info) { + auto dtype = std::get<0>(info.param); + auto dims = std::get<1>(info.param); + auto inner_dim = std::get<2>(info.param); + std::ostringstream os; + os << dtype << "_transpose_" << dims.first << "_" << dims.second + << "_size_" << inner_dim; + return os.str(); + }); } // namespace nvfuser From 6963ca5d36ce31851ded7b665e0edc87612cc732 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 27 Feb 2026 07:38:22 -0800 Subject: [PATCH 2/8] minor opt code --- csrc/scheduler/transpose_tma.cpp | 53 ++++++++++++++++---------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index caa72c19327..e74192f9e3a 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -142,34 
+142,37 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // Set up TMA load for inputs. std::vector tma_load_tvs; - for (auto [cached_input, input_idx] : cached_inputs) { - if (!tparams->use_tma_load) { - continue; - } - auto load_op = dynamic_cast(cached_input->definition()); - if (load_op == nullptr) { - continue; + if (tparams->use_tma_load) { + for (auto [cached_input, input_idx] : cached_inputs) { + auto load_op = dynamic_cast(cached_input->definition()); + if (load_op == nullptr) { + continue; + } + load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile); + cached_input->setMemoryType(MemoryType::Shared); + cached_input->cacheAfter(); + tma_load_tvs.push_back(cached_input); } - load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile); - cached_input->setMemoryType(MemoryType::Shared); - cached_input->cacheAfter(); - tma_load_tvs.push_back(cached_input); } - // Set up output caching with TMA store when enabled. - std::vector tma_store_tvs; + // Collect output TVs. std::vector output_tvs; + output_tvs.reserve(cached_outputs.size()); for (auto [cached_output, output_idx] : cached_outputs) { - auto output = fusion->outputs()[output_idx]->as(); - output_tvs.push_back(output); - if (!tparams->use_tma_store) { - continue; + output_tvs.push_back(fusion->outputs()[output_idx]->as()); + } + + // Set up output caching with TMA store when enabled. 
+ std::vector tma_store_tvs; + if (tparams->use_tma_store) { + for (auto [cached_output, output_idx] : cached_outputs) { + auto output = fusion->outputs()[output_idx]->as(); + output->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); + cached_output->setMemoryType(MemoryType::Shared); + cached_output->cacheBefore(); + tma_store_tvs.push_back(cached_output); } - output->definition()->as()->setOpType( - LoadStoreOpType::CpAsyncBulkTensorTile); - cached_output->setMemoryType(MemoryType::Shared); - cached_output->cacheBefore(); - tma_store_tvs.push_back(cached_output); } // Find transposed ids and positions, two groups, transpose happens in @@ -328,10 +331,8 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // ref_tv->axis(-1)->parallelize(ParallelType::Unroll); // Propagate to all TVs except smem/output TVs managed by TMA - std::unordered_set skip_tvs; - if (!tma_load_tvs.empty()) { - skip_tvs.insert(tma_load_tvs.begin(), tma_load_tvs.end()); - } + std::unordered_set skip_tvs( + tma_load_tvs.begin(), tma_load_tvs.end()); if (tparams->use_tma_store) { skip_tvs.insert(output_tvs.begin(), output_tvs.end()); } From f70775060c5d2dd1a032f380cc8c4eabd14fe4fd Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 2 Mar 2026 11:20:27 -0800 Subject: [PATCH 3/8] fix bank conflict in load --- csrc/scheduler/transpose_tma.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index e74192f9e3a..b7b243c4b00 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -314,21 +314,20 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } // Step 4: Schedule register TVs for per-thread access. - // Tile-2 was swizzled in smem, per-thread access should follow the swizzed - // layout. - // 1. split tile2 by elements_per_chunk defined in swizzle pattern - // 2. 
further split by chunks_per_thread to get the right granularity for each - // thread - // 3. merge remainings with tile1 and parallelize with TIDx + // + // The merge order is critical for bank conflicts on the non-swizzled smem + // side. By merging tile2_outer as the outer dim and tile1 as the inner dim, + // adjacent threads in a warp access adjacent tile1 positions. Since tile1 is + // the contiguous (inner) dimension of the non-swizzled smem layout, this + // means adjacent threads read from adjacent memory addresses. // [BIDx, tile1, tile2] ref_tv->split(-1, tparams->elements_per_chunk); // [BIDx, tile1, tile2/chunk, chunk] ref_tv->split(-2, tparams->chunks_per_thread); // [BIDx, tile1, tile2/chunk/cpt, cpt, chunk] - ref_tv->merge(-4, -3); - // [BIDx, tile1/chunk/cpt * tile2, cpt, chunk] + ref_tv->merge(-3, -4); + // [BIDx, tile2/chunk/cpt * tile1, cpt, chunk] ref_tv->axis(-3)->parallelize(ParallelType::TIDx); - // ref_tv->axis(-1)->parallelize(ParallelType::Unroll); // Propagate to all TVs except smem/output TVs managed by TMA std::unordered_set skip_tvs( From 985f8d52bacde5c30ab89371f9f8db7da8f9445a Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 2 Mar 2026 11:48:55 -0800 Subject: [PATCH 4/8] vectorize load bf16 from smem to regs --- csrc/scheduler/transpose_tma.cpp | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index b7b243c4b00..853486f62d0 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -314,20 +314,27 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } // Step 4: Schedule register TVs for per-thread access. - // + // ref_tv's inner most tile is tile1 // The merge order is critical for bank conflicts on the non-swizzled smem // side. By merging tile2_outer as the outer dim and tile1 as the inner dim, // adjacent threads in a warp access adjacent tile1 positions. 
Since tile1 is // the contiguous (inner) dimension of the non-swizzled smem layout, this // means adjacent threads read from adjacent memory addresses. // [BIDx, tile1, tile2] + const int64_t elem_per_bank = 2; ref_tv->split(-1, tparams->elements_per_chunk); // [BIDx, tile1, tile2/chunk, chunk] ref_tv->split(-2, tparams->chunks_per_thread); // [BIDx, tile1, tile2/chunk/cpt, cpt, chunk] ref_tv->merge(-3, -4); // [BIDx, tile2/chunk/cpt * tile1, cpt, chunk] - ref_tv->axis(-3)->parallelize(ParallelType::TIDx); + if (elem_per_bank > 1) { + ref_tv->split(-3, elem_per_bank); + // [BIDx, tile2/chunk/cpt * tile1/2, 2, cpt, chunk] + ref_tv->axis(-4)->parallelize(ParallelType::TIDx); + } else { + ref_tv->axis(-3)->parallelize(ParallelType::TIDx); + } // Propagate to all TVs except smem/output TVs managed by TMA std::unordered_set skip_tvs( @@ -350,18 +357,24 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { /*parallelize_inputs_on_did=*/true); // Vectorize smem access at the transpose boundary. 
+ auto vectorize_smem2regs = [&tma_load_tvs](int pos) { + for (auto tma_load_tv : tma_load_tvs) { + for (auto consumer : ir_utils::consumerTvsOf(tma_load_tv)) { + consumer->axis(pos)->parallelize(ParallelType::Vectorize); + } + } + }; if (tparams->is_output_smem_transpose) { // Vectorize writes to output smem for (auto output_smem_cache : tma_store_tvs) { output_smem_cache->axis(-1)->parallelize(ParallelType::Vectorize); } - } else { - // Vectorize reads from input smem - for (auto tma_load_tv : tma_load_tvs) { - for (auto consumer : ir_utils::consumerTvsOf(tma_load_tv)) { - consumer->axis(-1)->parallelize(ParallelType::Vectorize); - } + // [BIDx, tile2/chunk/cpt * tile1/2, 2, cpt, chunk] + if (elem_per_bank > 1) { + vectorize_smem2regs(-3); } + } else { + vectorize_smem2regs(-1); } inlineMost(); From 2acf28bd45b8d69e1505374763a649b26257992f Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 3 Mar 2026 07:12:02 -0800 Subject: [PATCH 5/8] test bank conflict --- csrc/scheduler/transpose_tma.cpp | 191 +++++++++++++++---------------- tests/cpp/test_transpose.cpp | 68 +++++++++++ 2 files changed, 160 insertions(+), 99 deletions(-) diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index 853486f62d0..42e847d860b 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -17,11 +17,14 @@ #include "scheduler/utils.h" #include "transform_replay.h" #include "type.h" + namespace nvfuser { namespace transpose { namespace tma { + constexpr int64_t kBytesPerChunk = 16; constexpr int64_t kTmaSwizzleBytes = 128; + std::unique_ptr getTransposeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, @@ -49,42 +52,36 @@ std::unique_ptr getTransposeHeuristics( } // Choose between input smem transpose and output smem transpose. - // Input smem transpose: transpose happens when reading from swizzled input - // shared memory to registers. columns to rows. 
- // Output smem transpose: transpose happens when writing from registers to - // swizzled output shared memory. rows to columns. + // Input smem: swizzle applied to input shared memory, transpose happens + // during smem->register reads. + // Output smem: swizzle applied to output shared memory, transpose happens + // during register->smem writes. // - // Use input smem transpose when inputs <= outputs to reduce smem usage and - // swizzle cost (fewer inputs to swizzle). Use output smem transpose when - // inputs > outputs for the same reason (fewer outputs to swizzle). - // - // TMA load/store are independent of the transpose direction: - // - TMA load stages inputs in shared memory (always beneficial for input - // smem transpose since inputs need swizzled smem anyway). - // - TMA store writes outputs from shared memory (always beneficial for - // output smem transpose since outputs need swizzled smem anyway). + // Pick the side with fewer tensors to minimize smem usage and swizzle cost. tparams->is_output_smem_transpose = n_input > n_output; tparams->use_tma_load = true; tparams->use_tma_store = tparams->is_output_smem_transpose; - // Inputs and outputs are grouped into two groups based on their inner most - // dim. The group with smaller number of tvs is swizzled in shared memory. TMA - // scheduler assumes all inputs are in the same group and all outputs are in - // the same group. - // tile_size2 is the tile size for the inner most dim of the group with - // swizzled smem (group2), it should follow the restriction of TMA swizzle - // size. - int64_t swizzle_dtype_size = tparams->is_output_smem_transpose + // Inputs and outputs are grouped by their innermost dim into two groups. + // Group 2 is the swizzled side. tile_size2 is constrained by the TMA + // swizzle size (128 bytes). + int64_t swizzled_dtype_size = tparams->is_output_smem_transpose ? 
max_output_dtype_size : max_input_dtype_size; - int64_t constrained_tile = kTmaSwizzleBytes / swizzle_dtype_size; + int64_t constrained_tile = kTmaSwizzleBytes / swizzled_dtype_size; tparams->tile_size2 = constrained_tile; - // Fixed, 16 bytes per chunk for swizzle, 4 float32 or 8 float16 elements. - tparams->elements_per_chunk = kBytesPerChunk / swizzle_dtype_size; + // 16 bytes per chunk: 4 float32 or 8 bfloat16 elements. + tparams->elements_per_chunk = kBytesPerChunk / swizzled_dtype_size; + + // Vectorize along tile1 (the non-swizzled dim) to align with the 4-byte + // smem bank width. For bf16 (2 bytes), this groups 2 elements per thread; + // for float (4 bytes), vectorize_factor1 = 1 (already bank-aligned). + tparams->vectorize_factor1 = + scheduler_utils::safeDiv(4L, swizzled_dtype_size); - // free dim that can be tuned, increase tile size when input count is small - // assuming issue 64KB loading data per sm, 256 threads per cta. + // Heuristic for tile_size1 (the non-swizzled, tunable dim). + // Target: 64KB of data loaded per SM, 256 threads per CTA. auto dev_props = at::cuda::getCurrentDeviceProperties(); constexpr int64_t bytes_per_sm = 64 * 1024; constexpr int64_t threads_per_cta = 256; @@ -94,14 +91,15 @@ std::unique_ptr getTransposeHeuristics( const int64_t bytes_per_tile = bytes_per_cta / n_input; int64_t estimated_tile_size1 = bytes_per_tile / kTmaSwizzleBytes; - // tile1 * tile2 = elements_per_chunk * chunks_per_thread * threads_per_cta - // May further increase tile1 to ensure each thread has at least 4 chunks to - // process for better efficiency. + // Ensure each thread processes at least min_chunks_per_thread chunks. 
+ // tile1 * tile2 = elements_per_chunk * chunks_per_thread * + // vectorize_factor1 * threads_per_cta constexpr int64_t min_chunks_per_thread = 4; auto get_chunks_per_thread = [&]() { int64_t elements_per_thread = estimated_tile_size1 * tparams->tile_size2 / threads_per_cta; - return elements_per_thread / tparams->elements_per_chunk; + return elements_per_thread / tparams->elements_per_chunk / + tparams->vectorize_factor1; }; while (get_chunks_per_thread() < min_chunks_per_thread) { estimated_tile_size1 *= 2; @@ -140,7 +138,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { scheduler_utils::prepareForMemoryTypePromotion(fusion); - // Set up TMA load for inputs. + // Set up TMA load: input -> smem_cache (TMA) -> reg_cache std::vector tma_load_tvs; if (tparams->use_tma_load) { for (auto [cached_input, input_idx] : cached_inputs) { @@ -155,14 +153,14 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } } - // Collect output TVs. + // Collect global output TVs (needed for TMA store scheduling). std::vector output_tvs; output_tvs.reserve(cached_outputs.size()); for (auto [cached_output, output_idx] : cached_outputs) { output_tvs.push_back(fusion->outputs()[output_idx]->as()); } - // Set up output caching with TMA store when enabled. + // Set up TMA store: reg_cache -> smem_cache (TMA) -> output std::vector tma_store_tvs; if (tparams->use_tma_store) { for (auto [cached_output, output_idx] : cached_outputs) { @@ -175,63 +173,59 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } } - // Find transposed ids and positions, two groups, transpose happens in - // group-2's cached smem. + // Group tensors by innermost dim. Group 1 = non-swizzled, Group 2 = swizzled. 
scheduler_tools::TransposeDomainMap domain_map(fusion); auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); NVF_ERROR(grouped_inputs_outputs.size() >= 2); - // When there are more inputs than outputs, output smem transpose should be - // used, however, if it is not, then input smem tranpose will be used, to - // ensure group2 is always the one that is transposed, we should swap group1 - // and group2. + // When not using output smem transpose but inputs > outputs, swap groups + // so group 2 remains the swizzled side. if (!tparams->is_output_smem_transpose && cached_inputs.size() > cached_outputs.size()) { std::swap(grouped_inputs_outputs[0], grouped_inputs_outputs[1]); } - TensorView* reference1 = + TensorView* group1_ref = domain_map.findReferenceFor(grouped_inputs_outputs[0]); - TensorView* reference2 = + TensorView* group2_ref = domain_map.findReferenceFor(grouped_inputs_outputs[1]); NVF_ERROR( - reference1 != nullptr, "Unable to find reference tensor for group 1"); + group1_ref != nullptr, "Unable to find reference tensor for group 1"); NVF_ERROR( - reference2 != nullptr, "Unable to find reference tensor for group 2"); + group2_ref != nullptr, "Unable to find reference tensor for group 2"); - // Step 1: Tile two transpose dimensions on reference1, merge all other + // Step 1: Tile two transpose dimensions on group1_ref, merge all other // dimensions into BIDx, and propagate tiling to the entire DAG. 
- auto inner_most_id1 = scheduler_utils::innerMostAllocDim(reference1); - auto inner_most_id2 = scheduler_utils::innerMostAllocDim(reference2); - int64_t inner_most_pos1 = - domain_map.getInnerLeafDim(reference1, inner_most_id1); - int64_t inner_most_pos2 = - domain_map.getInnerLeafDim(reference1, inner_most_id2); + auto group1_inner_id = scheduler_utils::innerMostAllocDim(group1_ref); + auto group2_inner_id = scheduler_utils::innerMostAllocDim(group2_ref); + int64_t group1_inner_pos = + domain_map.getInnerLeafDim(group1_ref, group1_inner_id); + int64_t group2_inner_pos = + domain_map.getInnerLeafDim(group1_ref, group2_inner_id); NVF_ERROR( - inner_most_pos1 >= 0 && inner_most_pos2 >= 0 && - inner_most_pos1 != inner_most_pos2, + group1_inner_pos >= 0 && group2_inner_pos >= 0 && + group1_inner_pos != group2_inner_pos, "Invalid inner dim positions for TMA tiling"); - TensorView* ref_tv = reference1; - if (reference1->isFusionInput() && tparams->use_tma_load) { - // can't propagate due to tma load - auto smem_consumer = ir_utils::consumerTvsOf(reference1).at(0); + // Transform propagation can't pass through TMA tvs, select + // a tv after tma load or before tma store. + TensorView* ref_tv = group1_ref; + if (group1_ref->isFusionInput() && tparams->use_tma_load) { + auto smem_consumer = ir_utils::consumerTvsOf(group1_ref).at(0); auto regs_consumer = ir_utils::consumerTvsOf(smem_consumer).at(0); ref_tv = regs_consumer; - } else if (reference1->isFusionOutput() && tparams->use_tma_store) { - // can't propagate due to tma store - auto smem_producer = ir_utils::getSoleProducerTv(reference1); + } else if (group1_ref->isFusionOutput() && tparams->use_tma_store) { + auto smem_producer = ir_utils::getSoleProducerTv(group1_ref); auto regs_producer = ir_utils::getSoleProducerTv(smem_producer); ref_tv = regs_producer; } - // make tile, group2 is swizzled, its inner most dim is tile2 - // [..., I1, .., I2, ...] 
- ref_tv->split(inner_most_pos1, tparams->tile_size1); - ref_tv->reorder({{inner_most_pos1 + 1, -1}}); - ref_tv->split(inner_most_pos2, tparams->tile_size2); - ref_tv->reorder({{inner_most_pos2 + 1, -1}}); - // [..., I1/tile1, .., I2/tile2, ..., tile1, tile2] + // Split and reorder to create tiles: + // [..., I1, .., I2, ...] → [..., I1/tile1, .., I2/tile2, ..., tile1, tile2] + ref_tv->split(group1_inner_pos, tparams->tile_size1); + ref_tv->reorder({{group1_inner_pos + 1, -1}}); + ref_tv->split(group2_inner_pos, tparams->tile_size2); + ref_tv->reorder({{group2_inner_pos + 1, -1}}); // Merge all non-tiled dimensions into a single BIDx dim int64_t rhs_i = ref_tv->nDims() - 3; @@ -251,9 +245,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { ref_tv->axis(rhs_i)->parallelize(ParallelType::BIDx); // Propagate tiling to all TVs - TransformPropagator propagator(ref_tv); + TransformPropagator tiling_propagator(ref_tv); MaxLogicalDomainInfoSpanningTree entire_dag(ref_tv); - entire_dag.traverse(&propagator); + entire_dag.traverse(&tiling_propagator); scheduler_utils::parallelizeAllLike( ref_tv, /*selected_tvs=*/{}, @@ -263,8 +257,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // Step 2: Schedule output TMA store. if (tparams->is_output_smem_transpose) { - // Output smem path: swizzle output shared memory with TMA store. - // Reorder so output inner dim is innermost, then apply TMA swizzle. + // Output smem path: apply TMA swizzle to output shared memory. MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0)); for (auto output_smem_cache : tma_store_tvs) { @@ -274,11 +267,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { mma_utils::scheduleTMAStoreForMmaOutput(output, swizzle); } } else if (tparams->use_tma_store) { - // Input smem path with TMA store: Bulk parallel on tile dims. 
+ // Input smem path with TMA store: reorder for contiguous output and + // set Bulk parallel on tile dims. for (auto output_smem_cache : tma_store_tvs) { - // [.., tile1, tile2] + // [.., tile1, tile2] → [.., tile2, tile1] output_smem_cache->reorder({{-1, -2}}); - // [.., tile2, tile1] output_smem_cache->setAllocationDomain( output_smem_cache->getLoopDomain(), true); } @@ -290,22 +283,21 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // Step 3: Schedule input shared memory. if (!tparams->is_output_smem_transpose) { + // Input smem path: TMA load into swizzled shared memory. NVF_ERROR( tparams->use_tma_load, "TMA load must be used when input smem is transposed"); - // TMA load and swizzle for (auto input_smem_cache : tma_load_tvs) { MmaInputSmemSwizzle swizzle_type = mma_utils::tmaSwizzleSharedMemory(input_smem_cache); input_smem_cache->applyMmaSwizzleForTMALoad(swizzle_type); } } else if (tparams->use_tma_load) { - // TMA load without swizzle, just contiguous load with Bulk parallel. - // Needs to move tile_1 to inner most for contiguous access. + // Output smem path: TMA load without swizzle. Reorder so tile1 + // (group 1's inner dim) is innermost for contiguous access. for (auto input_smem_cache : tma_load_tvs) { - // [.., tile1, tile2] + // [.., tile1, tile2] → [.., tile2, tile1] input_smem_cache->reorder({{-1, -2}}); - // [.., tile2, tile1] input_smem_cache->axis(-1)->parallelize(ParallelType::Bulk); input_smem_cache->axis(-2)->parallelize(ParallelType::Bulk); input_smem_cache->setAllocationDomain( @@ -314,50 +306,49 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } // Step 4: Schedule register TVs for per-thread access. - // ref_tv's inner most tile is tile1 + // ref_tv's innermost tile is tile1 (group 1's inner dim). // The merge order is critical for bank conflicts on the non-swizzled smem // side. 
By merging tile2_outer as the outer dim and tile1 as the inner dim, // adjacent threads in a warp access adjacent tile1 positions. Since tile1 is // the contiguous (inner) dimension of the non-swizzled smem layout, this // means adjacent threads read from adjacent memory addresses. // [BIDx, tile1, tile2] - const int64_t elem_per_bank = 2; ref_tv->split(-1, tparams->elements_per_chunk); // [BIDx, tile1, tile2/chunk, chunk] ref_tv->split(-2, tparams->chunks_per_thread); // [BIDx, tile1, tile2/chunk/cpt, cpt, chunk] ref_tv->merge(-3, -4); // [BIDx, tile2/chunk/cpt * tile1, cpt, chunk] - if (elem_per_bank > 1) { - ref_tv->split(-3, elem_per_bank); - // [BIDx, tile2/chunk/cpt * tile1/2, 2, cpt, chunk] + if (tparams->vectorize_factor1 > 1) { + ref_tv->split(-3, tparams->vectorize_factor1); + // [BIDx, tile2/chunk/cpt * tile1/vec, vec, cpt, chunk] ref_tv->axis(-4)->parallelize(ParallelType::TIDx); } else { ref_tv->axis(-3)->parallelize(ParallelType::TIDx); } - // Propagate to all TVs except smem/output TVs managed by TMA + // Propagate register scheduling to all TVs except TMA smem/output TVs. 
std::unordered_set skip_tvs( tma_load_tvs.begin(), tma_load_tvs.end()); if (tparams->use_tma_store) { skip_tvs.insert(output_tvs.begin(), output_tvs.end()); } - auto propagate_tvs = ir_utils::allTvsExcept(fusion, skip_tvs); - std::unordered_set propagate_tvs_set( - propagate_tvs.begin(), propagate_tvs.end()); - SetSelector selector(propagate_tvs_set); - MaxLogicalDomainInfoSpanningTree propagate_dag(ref_tv, &selector); - TransformPropagator tp(ref_tv); - propagate_dag.traverse(&tp); + auto reg_tvs = ir_utils::allTvsExcept(fusion, skip_tvs); + std::unordered_set reg_tvs_set(reg_tvs.begin(), reg_tvs.end()); + SetSelector selector(reg_tvs_set); + MaxLogicalDomainInfoSpanningTree reg_dag(ref_tv, &selector); + TransformPropagator reg_propagator(ref_tv); + reg_dag.traverse(®_propagator); scheduler_utils::parallelizeAllLike( ref_tv, - propagate_tvs, + reg_tvs, {}, /*propagate_padding=*/true, /*parallelize_inputs_on_did=*/true); - // Vectorize smem access at the transpose boundary. - auto vectorize_smem2regs = [&tma_load_tvs](int pos) { + // Vectorize smem reads at the transpose boundary: each consumer of + // a TMA-loaded smem TV reads with vectorized access. + auto vectorize_smem_reads = [&tma_load_tvs](int pos) { for (auto tma_load_tv : tma_load_tvs) { for (auto consumer : ir_utils::consumerTvsOf(tma_load_tv)) { consumer->axis(pos)->parallelize(ParallelType::Vectorize); @@ -365,16 +356,18 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } }; if (tparams->is_output_smem_transpose) { - // Vectorize writes to output smem + // Vectorize writes to swizzled output smem (chunk dim). for (auto output_smem_cache : tma_store_tvs) { output_smem_cache->axis(-1)->parallelize(ParallelType::Vectorize); } - // [BIDx, tile2/chunk/cpt * tile1/2, 2, cpt, chunk] - if (elem_per_bank > 1) { - vectorize_smem2regs(-3); + // Vectorize reads from non-swizzled input smem (vec dim). 
+ // [BIDx, tile2/chunk/cpt * tile1/vec, vec, cpt, chunk] + if (tparams->vectorize_factor1 > 1) { + vectorize_smem_reads(-3); } } else { - vectorize_smem2regs(-1); + // Input smem path: vectorize reads from swizzled input smem (chunk dim). + vectorize_smem_reads(-1); } inlineMost(); diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 3d823b1fbee..4ea01df9bf2 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -9,6 +9,7 @@ #include #include +#include "device_lower/analysis/bank_conflict.h" #include "exceptions.h" #include "ops/all_ops.h" #include "optimization_pass.h" @@ -1897,4 +1898,71 @@ INSTANTIATE_TEST_SUITE_P( << "_size_" << inner_dim; return os.str(); }); + +class TmaTransposeDtypeP : public TransposeTest, + public testing::WithParamInterface { + protected: + void SetUp() override { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + TransposeTest::SetUp(); + EnableOptionsGuard::getCurOptions().set(EnableOption::TmaTranspose); + } +}; + +TEST_P(TmaTransposeDtypeP, OutputTransposeBankconflict) { + auto dtype = GetParam(); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + auto tv0 = makeContigTensor(2, dtype); + auto tv1 = makeContigTensor(2, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + tv0 = maybeCastOp(DataType::Float, tv0); + tv1 = maybeCastOp(DataType::Float, tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = transpose(tv2, 0, 1); + auto tv4 = mul(tv3, tv3); + tv4 = maybeCastOp(dtype, tv4); + fusion.addOutput(tv4); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn({16384, 8192}, options); + auto t1 = at::randn({16384, 8192}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); + + // Check bank conflicts via the compiled 
kernel
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  for (auto& executor : runtime->executors()) {
+    if (auto* ke = dynamic_cast(executor.get())) {
+      auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel());
+      for (auto& [expr, ways] : bank_conflicts) {
+        auto [read_ways, write_ways] = ways;
+        std::cout << "  Bank conflict: " << expr->toString()
+                  << " read=" << read_ways << "-way"
+                  << ", write=" << write_ways << "-way" << std::endl;
+      }
+      if (dtype == DataType::Float) {
+        EXPECT_TRUE(bank_conflicts.empty());
+      } else {
+        // TODO: update to EXPECT_TRUE once bf16 bank conflicts are resolved.
+        EXPECT_FALSE(bank_conflicts.empty());
+      }
+    }
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    TransposeTest,
+    TmaTransposeDtypeP,
+    testing::Values(DataType::Float, DataType::BFloat16),
+    [](const testing::TestParamInfo& info) {
+      std::ostringstream os;
+      os << info.param;
+      return os.str();
+    });
+
 } // namespace nvfuser

From 0f793c5c2bb94bea2e5b0efb5ae420d127bbee6b Mon Sep 17 00:00:00 2001
From: Liqiang Lu 
Date: Tue, 3 Mar 2026 07:41:45 -0800
Subject: [PATCH 6/8] test different combinations of tma load/store

---
 tests/cpp/test_transpose.cpp | 81 ++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp
index 4ea01df9bf2..1da057f55fd 100644
--- a/tests/cpp/test_transpose.cpp
+++ b/tests/cpp/test_transpose.cpp
@@ -19,6 +19,7 @@
 #include "runtime/fusion_executor_cache.h"
 #include "scheduler/all_schedulers.h"
 #include "scheduler/mma_utils.h"
+#include "scheduler/runtime_info.h"
 #include "scheduler/tools/inlining.h"
 #include "scheduler/transpose.h"
 #include "scheduler/utils.h"
@@ -1965,4 +1966,84 @@ INSTANTIATE_TEST_SUITE_P(
       return os.str();
     });
 
+// Test different combinations of TMA transpose parameters:
+// (is_output_smem, use_tma_load, use_tma_store)
+// (false, true, false) - input smem transpose, TMA load only
+// (false, true, true) - input smem 
transpose, TMA load + TMA store
+// (true, true, true) - output smem transpose, TMA load + TMA store
+// (true, false, true) - output smem transpose, TMA store only
+using TmaTransposeParamsTestP_Params =
+    std::tuple; // is_output_smem, tma_load, tma_store
+
+class TmaTransposeParamsTestP
+    : public TransposeTest,
+      public testing::WithParamInterface {
+ protected:
+  void SetUp() override {
+    NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
+    TransposeTest::SetUp();
+    EnableOptionsGuard::getCurOptions().set(EnableOption::TmaTranspose);
+  }
+};
+
+TEST_P(TmaTransposeParamsTestP, TmaTransposeParams) {
+  auto [is_output_smem, use_tma_load, use_tma_store] = GetParam();
+
+  // Build fusion: two inputs -> add -> transpose -> square -> output.
+  // Two inputs ensure the default heuristic picks output smem transpose.
+  auto fusion = std::make_unique();
+  FusionGuard fg(fusion.get());
+  auto tv0 = makeContigTensor(2);
+  auto tv1 = makeContigTensor(2);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  auto tv2 = add(tv0, tv1);
+  auto tv3 = transpose(tv2, 0, 1);
+  auto tv4 = mul(tv3, tv3);
+  fusion->addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({4096, 2048}, options);
+  auto t1 = at::randn({4096, 2048}, options);
+
+  // Compute default TMA heuristics, then override the 3 params.
+  SchedulerRuntimeInfo runtime_info(fusion.get(), {t0, t1});
+  auto scheduler =
+      SchedulerEntry::makeSchedulerInstance(SchedulerType::Transpose);
+  auto heuristic_params =
+      scheduler->computeHeuristics(fusion.get(), runtime_info);
+  auto* tparams = heuristic_params->as();
+  tparams->is_output_smem_transpose = is_output_smem;
+  tparams->use_tma_load = use_tma_load;
+  tparams->use_tma_store = use_tma_store;
+
+  // Schedule, compile, and run. 
+ scheduler->schedule(fusion.get(), tparams); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0, t1}, tparams->lparams); + auto outputs = ke.run({t0, t1}, {}, tparams->lparams); + testValidate(fusion.get(), outputs, {t0, t1}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeTest, + TmaTransposeParamsTestP, + testing::Values( + // (is_output_smem, tma_load, tma_store) + std::make_tuple(false, true, false), // input smem, TMA load only + std::make_tuple(false, true, true), // input smem, TMA load + store + std::make_tuple(true, true, true), // output smem, TMA load + store + std::make_tuple(true, false, true)), // output smem, TMA store only + [](const testing::TestParamInfo& info) { + bool is_output_smem = std::get<0>(info.param); + bool tma_load = std::get<1>(info.param); + bool tma_store = std::get<2>(info.param); + std::ostringstream os; + os << (is_output_smem ? "output_smem_transpose" : "input_smem_transpose") + << "_load_" << (tma_load ? "tma" : "off") << "_store_" + << (tma_store ? "tma" : "off"); + return os.str(); + }); + } // namespace nvfuser From bc772dbf243c33e5f296a80a090413e8fa70c434 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 3 Mar 2026 11:09:17 -0800 Subject: [PATCH 7/8] use vect1 = 1 to avoid bank conflict --- csrc/scheduler/transpose_tma.cpp | 5 ++--- tests/cpp/test_transpose.cpp | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index 42e847d860b..b4927c4ce9b 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -76,9 +76,8 @@ std::unique_ptr getTransposeHeuristics( // Vectorize along tile1 (the non-swizzled dim) to align with the 4-byte // smem bank width. For bf16 (2 bytes), this groups 2 elements per thread; - // for float (4 bytes), vectorize_factor1 = 1 (already bank-aligned). 
- tparams->vectorize_factor1 = - scheduler_utils::safeDiv(4L, swizzled_dtype_size); + // However, it leads to 2-way bank conflict in regs -> smem. Disable for now. + tparams->vectorize_factor1 = 1; // Heuristic for tile_size1 (the non-swizzled, tunable dim). // Target: 64KB of data loaded per SM, 256 threads per CTA. diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 1da057f55fd..f7a3e3d8273 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1946,12 +1946,7 @@ TEST_P(TmaTransposeDtypeP, OutputTransposeBankconflict) { << " read=" << read_ways << "-way" << ", write=" << write_ways << "-way" << std::endl; } - if (dtype == DataType::Float) { - EXPECT_TRUE(bank_conflicts.empty()); - } else { - // TODO: update to EXPECT_TRUE once bf16 bank conflicts are resolved. - EXPECT_FALSE(bank_conflicts.empty()); - } + EXPECT_TRUE(bank_conflicts.empty()); } } } From f722d67d2f91faf623bd2f216b8671757528568d Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 5 Mar 2026 08:24:47 -0800 Subject: [PATCH 8/8] format --- csrc/scheduler/transpose_tma.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index b4927c4ce9b..5f75f5a8aad 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -36,18 +36,16 @@ std::unique_ptr getTransposeHeuristics( int64_t max_input_dtype_size = 1; int64_t n_input = 0; for (auto inp : ir_utils::filterByType(fusion->inputs())) { - max_input_dtype_size = std::max( - max_input_dtype_size, - dataTypeSizeByte(valueOrError(inp->getDataType()))); + max_input_dtype_size = + std::max(max_input_dtype_size, dataTypeSizeByte(inp->getDataType())); n_input++; } int64_t max_output_dtype_size = 1; int64_t n_output = 0; for (auto out : ir_utils::filterByType(fusion->outputs())) { - max_output_dtype_size = std::max( - max_output_dtype_size, - 
dataTypeSizeByte(valueOrError(out->getDataType()))); + max_output_dtype_size = + std::max(max_output_dtype_size, dataTypeSizeByte(out->getDataType())); n_output++; } @@ -82,7 +80,7 @@ std::unique_ptr getTransposeHeuristics( // Heuristic for tile_size1 (the non-swizzled, tunable dim). // Target: 64KB of data loaded per SM, 256 threads per CTA. auto dev_props = at::cuda::getCurrentDeviceProperties(); - constexpr int64_t bytes_per_sm = 64 * 1024; + constexpr int64_t bytes_per_sm = 64L * 1024L; constexpr int64_t threads_per_cta = 256; const int64_t cta_per_sm = dev_props->maxThreadsPerMultiProcessor / threads_per_cta;