diff --git a/csrc/device_lower/analysis/tma.cpp b/csrc/device_lower/analysis/tma.cpp index 3dedd9413f1..b1d45fa2a63 100644 --- a/csrc/device_lower/analysis/tma.cpp +++ b/csrc/device_lower/analysis/tma.cpp @@ -53,13 +53,15 @@ std::unordered_set getBatchableTmaLoads( // We have some tests where TMA load is used in an untraditional way. // e.g. parallelized with threads, serial load, which requires multiple // mbarriers or reuse of the same mbarrier. - if (std::any_of( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - [](const IterDomain* id) { - return id->isThreadDim() || - id->getParallelType() == ParallelType::Serial; - })) { + auto non_trivial_ids = + tv->getLoopDomain() | std::views::filter([](const IterDomain* id) { + return !id->extent()->isConstScalar() || + id->extent()->evaluate().as() > 1; + }); + if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) { + return id->isThreadDim() || + id->getParallelType() == ParallelType::Serial; + })) { return {}; } non_cb_tma_load_exprs.push_back(expr); diff --git a/csrc/options.cpp b/csrc/options.cpp index 14dddd89eec..775f76f26ea 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -40,7 +40,7 @@ auto parseEnvOptions( available_options.end(), std::back_inserter(option_values), [](const auto& kv) { return kv.first; }); - std::sort(option_values.begin(), option_values.end()); + std::ranges::sort(option_values); NVF_CHECK( false, "Parsing ", @@ -174,6 +174,7 @@ const std::unordered_map& getEnableOptions() { {"tma_pointwise", EnableOption::TmaPointwise}, {"tma_inner_persistent", EnableOption::TmaInnerPersistent}, {"tma_reduction", EnableOption::TmaReduction}, + {"tma_transpose", EnableOption::TmaTranspose}, {"ws_normalization", EnableOption::WarpSpecializedNormalization}, {"host_ir_lowering", EnableOption::HostIrLowering}, {"host_ir_jit", EnableOption::HostIrJit}, diff --git a/csrc/options.h b/csrc/options.h index 6dad3909997..9a9063c2f3b 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -11,6 
+11,7 @@ #include #include +#include #include #include #include @@ -22,7 +23,7 @@ namespace nvfuser { //! //! These can be set through the `NVFUSER_DUMP` environment variable //! -enum class DebugDumpOption { +enum class DebugDumpOption : std::uint8_t { CutlassCompile, //!< Dump compile commands and compile times for //!< CutlassExecutor FunctionTrace, //!< Dump the function trace of selected internal function. The @@ -97,7 +98,7 @@ enum class DebugDumpOption { //! //! These can be set through the `NVFUSER_ENABLE` environment variable //! -enum class EnableOption { +enum class EnableOption : std::uint8_t { CutlassScheduler, //! Enable the CUTLASS scheduler and executor FuseMatmul, //! Enable automatic fusion of matmul and linear ops FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel @@ -118,6 +119,7 @@ enum class EnableOption { TmaPointwise, //! Enable TMA pointwise kernel TmaInnerPersistent, //! Enable TMA inner persistent kernel TmaReduction, //! Enable TMA reduction kernel + TmaTranspose, //! Enable TMA transpose kernel WarpSpecializedNormalization, //! Enable warp specialized persistent kernel HostIrLowering, //! Enable FusionKernelRuntime lowering to host IR HostIrJit, //! Enable Host IR JIT compilation with LLVM @@ -134,7 +136,7 @@ enum class EnableOption { //! //! These can be set through the `NVFUSER_DISABLE` environment variable //! -enum class DisableOption { +enum class DisableOption : std::uint8_t { CompileToSass, //! Disable direct compilation to sass so the ptx can be //! examined ContigIndexing, //! Disable contiguous indexing @@ -176,7 +178,7 @@ enum class DisableOption { //! //! These can be set through the `NVFUSER_PROF` environment variable //! -enum class ProfilerOption { +enum class ProfilerOption : std::uint8_t { Enable, //! Enables the profiler. EnableNocupti, //! Enables the profiler, but disables CUPTI specific //! 
profiling inorder to measure true host time without @@ -197,10 +199,11 @@ class Options { public: Options() : options_(getOptionsFromEnv()) {} - Options(const Options& other) { - std::lock_guard lock_other(other.mutex_); - options_ = other.options_; - } + Options(const Options& other) + : options_([&other]() { + std::lock_guard lock_other(other.mutex_); + return other.options_; + }()) {} Options& operator=(const Options& other) { std::lock_guard lock_other(other.mutex_); diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 8dacde8d3b1..5487d59493a 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -407,9 +407,11 @@ std::unique_ptr TransposeScheduler::computeHeuristics( std::unique_ptr tparams = nullptr; - // Try TMA path first - tparams = - transpose::tma::getTransposeHeuristics(fusion, runtime_info, data_cache); + // Try TMA path first if enabled + if (isOptionEnabled(EnableOption::TmaTranspose)) { + tparams = transpose::tma::getTransposeHeuristics( + fusion, runtime_info, data_cache); + } // Fallback to non-TMA scheduler if TMA is not applicable if (tparams == nullptr) { @@ -431,7 +433,7 @@ void TransposeScheduler::schedule( "Incorrect parameters sent to TransposeScheduler::schedule", params); - if (tparams->use_tma_load) { + if (tparams->use_tma_load || tparams->use_tma_store) { transpose::tma::scheduleTranspose(fusion, tparams); } else { transpose::non_tma::scheduleTranspose(fusion, tparams); diff --git a/csrc/scheduler/transpose_heuristic.h b/csrc/scheduler/transpose_heuristic.h index 96db01a6410..838f5df2d25 100644 --- a/csrc/scheduler/transpose_heuristic.h +++ b/csrc/scheduler/transpose_heuristic.h @@ -22,7 +22,7 @@ namespace nvfuser { // are equivelent! 
class TransposeParams : public HeuristicParams { public: - TransposeParams() : HeuristicParams(SchedulerType::Transpose) {}; + TransposeParams() : HeuristicParams(SchedulerType::Transpose){}; static constexpr int64_t getMaxThreadsPerBlock() { return 128; } @@ -39,6 +39,20 @@ class TransposeParams : public HeuristicParams { // Whether to use TMA for loading inputs bool use_tma_load = false; + bool use_tma_store = false; + + // Which side of shared memory holds the transposed (swizzled) layout. + // false: input smem is swizzled, transpose happens on smem->register read. + // true: output smem is swizzled, transpose happens on register->smem write. + // This is independent of use_tma_load/use_tma_store — TMA can be used for + // either side regardless of where the transpose swizzle lives. + bool is_output_smem_transpose = false; + + // In 128-bytes swizzled tma load, inner most dim is split into 8 chunks each + // with 16 bytes. Each thread may handle multiple chunks along the inner most + // dim. + int64_t chunks_per_thread = 1; + int64_t elements_per_chunk = 1; // Vectorization factor for tensors in the first group int64_t vectorize_factor1 = 1; @@ -65,6 +79,10 @@ class TransposeParams : public HeuristicParams { } bool attr_equal = other->cparams == cparams && other->use_tma_load == use_tma_load && + other->use_tma_store == use_tma_store && + other->is_output_smem_transpose == is_output_smem_transpose && + other->chunks_per_thread == chunks_per_thread && + other->elements_per_chunk == elements_per_chunk && other->split_before_tiling == split_before_tiling && other->dims_merged_with_1 == dims_merged_with_1 && other->dims_merged_with_2 == dims_merged_with_2 && @@ -99,6 +117,14 @@ class TransposeParams : public HeuristicParams { if (unroll_factor2 > 1) { ss << "Unroll group 2, Factor: " << unroll_factor2 << "\n"; } + if (use_tma_load || use_tma_store) { + ss << "TMA: load=" << (use_tma_load ? "true" : "false") + << " store=" << (use_tma_store ? 
"true" : "false") + << " is_output_smem_transpose=" + << (is_output_smem_transpose ? "true" : "false") + << " chunks_per_thread=" << chunks_per_thread + << " elements_per_chunk=" << elements_per_chunk << "\n"; + } if (!split_before_tiling.empty() || !dims_merged_with_1.empty() || !dims_merged_with_2.empty()) { ss << "Virtual inner-most dim:\n"; @@ -146,6 +172,10 @@ class TransposeParams : public HeuristicParams { size_t hash() const override { return c10::get_hash( use_tma_load, + use_tma_store, + is_output_smem_transpose, + chunks_per_thread, + elements_per_chunk, split_before_tiling, dims_merged_with_1, dims_merged_with_2, diff --git a/csrc/scheduler/transpose_tma.cpp b/csrc/scheduler/transpose_tma.cpp index d4f2a64ca40..5f75f5a8aad 100644 --- a/csrc/scheduler/transpose_tma.cpp +++ b/csrc/scheduler/transpose_tma.cpp @@ -8,20 +8,366 @@ #include "scheduler/transpose_tma.h" +#include +#include "ir/utils.h" +#include "scheduler/mma_utils.h" +#include "scheduler/runtime_info.h" +#include "scheduler/tools/domain_map.h" +#include "scheduler/tools/inlining.h" +#include "scheduler/utils.h" +#include "transform_replay.h" +#include "type.h" + namespace nvfuser { namespace transpose { namespace tma { +constexpr int64_t kBytesPerChunk = 16; +constexpr int64_t kTmaSwizzleBytes = 128; + std::unique_ptr getTransposeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicDataCache* data_cache) { - // TMA transpose scheduling is not yet implemented. 
- return nullptr; + auto tparams = std::make_unique(); + tparams->tag = "TMA Transpose heuristics"; + tparams->cparams.index_type = runtime_info.getIndexType(); + + int64_t max_input_dtype_size = 1; + int64_t n_input = 0; + for (auto inp : ir_utils::filterByType(fusion->inputs())) { + max_input_dtype_size = + std::max(max_input_dtype_size, dataTypeSizeByte(inp->getDataType())); + n_input++; + } + + int64_t max_output_dtype_size = 1; + int64_t n_output = 0; + for (auto out : ir_utils::filterByType(fusion->outputs())) { + max_output_dtype_size = + std::max(max_output_dtype_size, dataTypeSizeByte(out->getDataType())); + n_output++; + } + + // Choose between input smem transpose and output smem transpose. + // Input smem: swizzle applied to input shared memory, transpose happens + // during smem->register reads. + // Output smem: swizzle applied to output shared memory, transpose happens + // during register->smem writes. + // + // Pick the side with fewer tensors to minimize smem usage and swizzle cost. + tparams->is_output_smem_transpose = n_input > n_output; + tparams->use_tma_load = true; + tparams->use_tma_store = tparams->is_output_smem_transpose; + + // Inputs and outputs are grouped by their innermost dim into two groups. + // Group 2 is the swizzled side. tile_size2 is constrained by the TMA + // swizzle size (128 bytes). + int64_t swizzled_dtype_size = tparams->is_output_smem_transpose + ? max_output_dtype_size + : max_input_dtype_size; + int64_t constrained_tile = kTmaSwizzleBytes / swizzled_dtype_size; + tparams->tile_size2 = constrained_tile; + + // 16 bytes per chunk: 4 float32 or 8 bfloat16 elements. + tparams->elements_per_chunk = kBytesPerChunk / swizzled_dtype_size; + + // Vectorize along tile1 (the non-swizzled dim) to align with the 4-byte + // smem bank width. For bf16 (2 bytes), this groups 2 elements per thread; + // However, it leads to 2-way bank conflict in regs -> smem. Disable for now. 
+ tparams->vectorize_factor1 = 1; + + // Heuristic for tile_size1 (the non-swizzled, tunable dim). + // Target: 64KB of data loaded per SM, 256 threads per CTA. + auto dev_props = at::cuda::getCurrentDeviceProperties(); + constexpr int64_t bytes_per_sm = 64L * 1024L; + constexpr int64_t threads_per_cta = 256; + const int64_t cta_per_sm = + dev_props->maxThreadsPerMultiProcessor / threads_per_cta; + const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm; + const int64_t bytes_per_tile = bytes_per_cta / n_input; + int64_t estimated_tile_size1 = bytes_per_tile / kTmaSwizzleBytes; + + // Ensure each thread processes at least min_chunks_per_thread chunks. + // tile1 * tile2 = elements_per_chunk * chunks_per_thread * + // vectorize_factor1 * threads_per_cta + constexpr int64_t min_chunks_per_thread = 4; + auto get_chunks_per_thread = [&]() { + int64_t elements_per_thread = + estimated_tile_size1 * tparams->tile_size2 / threads_per_cta; + return elements_per_thread / tparams->elements_per_chunk / + tparams->vectorize_factor1; + }; + while (get_chunks_per_thread() < min_chunks_per_thread) { + estimated_tile_size1 *= 2; + } + tparams->tile_size1 = estimated_tile_size1; + tparams->chunks_per_thread = get_chunks_per_thread(); + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== TMA Transpose Stats ========\n" + << "inputs: " << ir_utils::toString(fusion->inputs()) << "\n" + << "outputs: " << ir_utils::toString(fusion->outputs()) << "\n" + << "is_output_smem_transpose: " << tparams->is_output_smem_transpose + << "\n" + << "use_tma_load: " << tparams->use_tma_load << "\n" + << "use_tma_store: " << tparams->use_tma_store << "\n" + << "tile_size1: " << tparams->tile_size1 << "\n" + << "tile_size2: " << tparams->tile_size2 << "\n" + << "chunks_per_thread: " << tparams->chunks_per_thread << "\n" + << "elements_per_chunk: " << tparams->elements_per_chunk << "\n" + << "\n"; + } + return tparams; } void scheduleTranspose(Fusion* fusion, const 
TransposeParams* tparams) { - NVF_THROW("TMA transpose scheduling is not yet implemented."); + FusionGuard fg(fusion); + + scheduler_utils::clearMemorySpace(fusion); + + NVF_ERROR( + !ir_utils::hasAnyReductionOps(fusion), + "This scheduler only handles pointwise ops."); + + auto cached_inputs = scheduler_utils::cacheInputs(fusion, true); + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); + + scheduler_utils::prepareForMemoryTypePromotion(fusion); + + // Set up TMA load: input -> smem_cache (TMA) -> reg_cache + std::vector tma_load_tvs; + if (tparams->use_tma_load) { + for (auto [cached_input, input_idx] : cached_inputs) { + auto load_op = dynamic_cast(cached_input->definition()); + if (load_op == nullptr) { + continue; + } + load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile); + cached_input->setMemoryType(MemoryType::Shared); + cached_input->cacheAfter(); + tma_load_tvs.push_back(cached_input); + } + } + + // Collect global output TVs (needed for TMA store scheduling). + std::vector output_tvs; + output_tvs.reserve(cached_outputs.size()); + for (auto [cached_output, output_idx] : cached_outputs) { + output_tvs.push_back(fusion->outputs()[output_idx]->as()); + } + + // Set up TMA store: reg_cache -> smem_cache (TMA) -> output + std::vector tma_store_tvs; + if (tparams->use_tma_store) { + for (auto [cached_output, output_idx] : cached_outputs) { + auto output = fusion->outputs()[output_idx]->as(); + output->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); + cached_output->setMemoryType(MemoryType::Shared); + cached_output->cacheBefore(); + tma_store_tvs.push_back(cached_output); + } + } + + // Group tensors by innermost dim. Group 1 = non-swizzled, Group 2 = swizzled. 
+ scheduler_tools::TransposeDomainMap domain_map(fusion); + auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); + NVF_ERROR(grouped_inputs_outputs.size() >= 2); + + // When not using output smem transpose but inputs > outputs, swap groups + // so group 2 remains the swizzled side. + if (!tparams->is_output_smem_transpose && + cached_inputs.size() > cached_outputs.size()) { + std::swap(grouped_inputs_outputs[0], grouped_inputs_outputs[1]); + } + + TensorView* group1_ref = + domain_map.findReferenceFor(grouped_inputs_outputs[0]); + TensorView* group2_ref = + domain_map.findReferenceFor(grouped_inputs_outputs[1]); + NVF_ERROR( + group1_ref != nullptr, "Unable to find reference tensor for group 1"); + NVF_ERROR( + group2_ref != nullptr, "Unable to find reference tensor for group 2"); + + // Step 1: Tile two transpose dimensions on group1_ref, merge all other + // dimensions into BIDx, and propagate tiling to the entire DAG. + auto group1_inner_id = scheduler_utils::innerMostAllocDim(group1_ref); + auto group2_inner_id = scheduler_utils::innerMostAllocDim(group2_ref); + int64_t group1_inner_pos = + domain_map.getInnerLeafDim(group1_ref, group1_inner_id); + int64_t group2_inner_pos = + domain_map.getInnerLeafDim(group1_ref, group2_inner_id); + NVF_ERROR( + group1_inner_pos >= 0 && group2_inner_pos >= 0 && + group1_inner_pos != group2_inner_pos, + "Invalid inner dim positions for TMA tiling"); + + // Transform propagation can't pass through TMA tvs, select + // a tv after tma load or before tma store. 
+ TensorView* ref_tv = group1_ref; + if (group1_ref->isFusionInput() && tparams->use_tma_load) { + auto smem_consumer = ir_utils::consumerTvsOf(group1_ref).at(0); + auto regs_consumer = ir_utils::consumerTvsOf(smem_consumer).at(0); + ref_tv = regs_consumer; + } else if (group1_ref->isFusionOutput() && tparams->use_tma_store) { + auto smem_producer = ir_utils::getSoleProducerTv(group1_ref); + auto regs_producer = ir_utils::getSoleProducerTv(smem_producer); + ref_tv = regs_producer; + } + + // Split and reorder to create tiles: + // [..., I1, .., I2, ...] → [..., I1/tile1, .., I2/tile2, ..., tile1, tile2] + ref_tv->split(group1_inner_pos, tparams->tile_size1); + ref_tv->reorder({{group1_inner_pos + 1, -1}}); + ref_tv->split(group2_inner_pos, tparams->tile_size2); + ref_tv->reorder({{group2_inner_pos + 1, -1}}); + + // Merge all non-tiled dimensions into a single BIDx dim + int64_t rhs_i = ref_tv->nDims() - 3; + for (int64_t lhs_i = ref_tv->nDims() - 4; lhs_i >= 0; lhs_i--) { + if (ref_tv->axis(lhs_i)->isReduction() || + ref_tv->axis(lhs_i)->isDeviceDim()) { + continue; + } + if (ref_tv->axis(rhs_i)->isReduction() || + ref_tv->axis(rhs_i)->isDeviceDim()) { + rhs_i = lhs_i; + continue; + } + ref_tv->merge(lhs_i, rhs_i); + rhs_i = lhs_i; + } + ref_tv->axis(rhs_i)->parallelize(ParallelType::BIDx); + + // Propagate tiling to all TVs + TransformPropagator tiling_propagator(ref_tv); + MaxLogicalDomainInfoSpanningTree entire_dag(ref_tv); + entire_dag.traverse(&tiling_propagator); + scheduler_utils::parallelizeAllLike( + ref_tv, + /*selected_tvs=*/{}, + /*selected_parallel_types=*/{}, + /*propagate_padding=*/true, + /*parallelize_inputs_on_did=*/true); + + // Step 2: Schedule output TMA store. + if (tparams->is_output_smem_transpose) { + // Output smem path: apply TMA swizzle to output shared memory. 
+ MmaInputSmemSwizzle swizzle = + mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0)); + for (auto output_smem_cache : tma_store_tvs) { + mma_utils::scheduleTMAStoreForMmaOutput(output_smem_cache, swizzle); + } + for (auto output : output_tvs) { + mma_utils::scheduleTMAStoreForMmaOutput(output, swizzle); + } + } else if (tparams->use_tma_store) { + // Input smem path with TMA store: reorder for contiguous output and + // set Bulk parallel on tile dims. + for (auto output_smem_cache : tma_store_tvs) { + // [.., tile1, tile2] → [.., tile2, tile1] + output_smem_cache->reorder({{-1, -2}}); + output_smem_cache->setAllocationDomain( + output_smem_cache->getLoopDomain(), true); + } + for (auto output : output_tvs) { + output->axis(-1)->parallelize(ParallelType::Bulk); + output->axis(-2)->parallelize(ParallelType::Bulk); + } + } + + // Step 3: Schedule input shared memory. + if (!tparams->is_output_smem_transpose) { + // Input smem path: TMA load into swizzled shared memory. + NVF_ERROR( + tparams->use_tma_load, + "TMA load must be used when input smem is transposed"); + for (auto input_smem_cache : tma_load_tvs) { + MmaInputSmemSwizzle swizzle_type = + mma_utils::tmaSwizzleSharedMemory(input_smem_cache); + input_smem_cache->applyMmaSwizzleForTMALoad(swizzle_type); + } + } else if (tparams->use_tma_load) { + // Output smem path: TMA load without swizzle. Reorder so tile1 + // (group 1's inner dim) is innermost for contiguous access. + for (auto input_smem_cache : tma_load_tvs) { + // [.., tile1, tile2] → [.., tile2, tile1] + input_smem_cache->reorder({{-1, -2}}); + input_smem_cache->axis(-1)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(-2)->parallelize(ParallelType::Bulk); + input_smem_cache->setAllocationDomain( + input_smem_cache->getLoopDomain(), true); + } + } + + // Step 4: Schedule register TVs for per-thread access. + // ref_tv's innermost tile is tile1 (group 1's inner dim). 
+ // The merge order is critical for bank conflicts on the non-swizzled smem + // side. By merging tile2_outer as the outer dim and tile1 as the inner dim, + // adjacent threads in a warp access adjacent tile1 positions. Since tile1 is + // the contiguous (inner) dimension of the non-swizzled smem layout, this + // means adjacent threads read from adjacent memory addresses. + // [BIDx, tile1, tile2] + ref_tv->split(-1, tparams->elements_per_chunk); + // [BIDx, tile1, tile2/chunk, chunk] + ref_tv->split(-2, tparams->chunks_per_thread); + // [BIDx, tile1, tile2/chunk/cpt, cpt, chunk] + ref_tv->merge(-3, -4); + // [BIDx, tile2/chunk/cpt * tile1, cpt, chunk] + if (tparams->vectorize_factor1 > 1) { + ref_tv->split(-3, tparams->vectorize_factor1); + // [BIDx, tile2/chunk/cpt * tile1/vec, vec, cpt, chunk] + ref_tv->axis(-4)->parallelize(ParallelType::TIDx); + } else { + ref_tv->axis(-3)->parallelize(ParallelType::TIDx); + } + + // Propagate register scheduling to all TVs except TMA smem/output TVs. + std::unordered_set skip_tvs( + tma_load_tvs.begin(), tma_load_tvs.end()); + if (tparams->use_tma_store) { + skip_tvs.insert(output_tvs.begin(), output_tvs.end()); + } + auto reg_tvs = ir_utils::allTvsExcept(fusion, skip_tvs); + std::unordered_set reg_tvs_set(reg_tvs.begin(), reg_tvs.end()); + SetSelector selector(reg_tvs_set); + MaxLogicalDomainInfoSpanningTree reg_dag(ref_tv, &selector); + TransformPropagator reg_propagator(ref_tv); + reg_dag.traverse(®_propagator); + scheduler_utils::parallelizeAllLike( + ref_tv, + reg_tvs, + {}, + /*propagate_padding=*/true, + /*parallelize_inputs_on_did=*/true); + + // Vectorize smem reads at the transpose boundary: each consumer of + // a TMA-loaded smem TV reads with vectorized access. 
+ auto vectorize_smem_reads = [&tma_load_tvs](int pos) { + for (auto tma_load_tv : tma_load_tvs) { + for (auto consumer : ir_utils::consumerTvsOf(tma_load_tv)) { + consumer->axis(pos)->parallelize(ParallelType::Vectorize); + } + } + }; + if (tparams->is_output_smem_transpose) { + // Vectorize writes to swizzled output smem (chunk dim). + for (auto output_smem_cache : tma_store_tvs) { + output_smem_cache->axis(-1)->parallelize(ParallelType::Vectorize); + } + // Vectorize reads from non-swizzled input smem (vec dim). + // [BIDx, tile2/chunk/cpt * tile1/vec, vec, cpt, chunk] + if (tparams->vectorize_factor1 > 1) { + vectorize_smem_reads(-3); + } + } else { + // Input smem path: vectorize reads from swizzled input smem (chunk dim). + vectorize_smem_reads(-1); + } + + inlineMost(); } } // namespace tma diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index bbc3d8fd30d..5fb1520c205 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -9,6 +9,7 @@ #include #include +#include "device_lower/analysis/bank_conflict.h" #include "exceptions.h" #include "ops/all_ops.h" #include "optimization_pass.h" @@ -18,6 +19,7 @@ #include "runtime/fusion_executor_cache.h" #include "scheduler/all_schedulers.h" #include "scheduler/mma_utils.h" +#include "scheduler/runtime_info.h" #include "scheduler/tools/inlining.h" #include "scheduler/transpose.h" #include "scheduler/utils.h" @@ -1820,4 +1822,222 @@ TEST_F(TransposeTMA, TransposeOutputSmem) { testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__); } +// dtype, pair of transpose dimensions, inner most dim of input tensor +using TmaTransposeTestParams = + std::tuple, int64_t>; +class TmaTransposeTestP + : public TransposeTest, + public testing::WithParamInterface { + protected: + void SetUp() override { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + TransposeTest::SetUp(); + EnableOptionsGuard::getCurOptions().set(EnableOption::TmaTranspose); + } +}; + +TEST_P(TmaTransposeTestP, InputTranspose) 
{ + auto dtype = std::get<0>(GetParam()); + auto [dim1, dim2] = std::get<1>(GetParam()); + auto inner_dim = std::get<2>(GetParam()); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + auto tv0 = makeContigTensor(3, dtype); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, dim1, dim2); + fusion.addOutput(tv1); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn({128, 256, inner_dim}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); +} + +TEST_P(TmaTransposeTestP, OutputTranspose) { + auto dtype = std::get<0>(GetParam()); + auto [dim1, dim2] = std::get<1>(GetParam()); + auto inner_dim = std::get<2>(GetParam()); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + auto tv0 = makeContigTensor(3, dtype); + auto tv1 = makeContigTensor(3, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = transpose(tv2, dim1, dim2); + fusion.addOutput(tv3); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn({128, 256, inner_dim}, options); + auto t1 = at::randn({128, 256, inner_dim}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeTest, + TmaTransposeTestP, + testing::Combine( + testing::Values(DataType::Float, DataType::BFloat16), + testing::Values( + std::make_pair(int64_t(0), int64_t(2)), + std::make_pair(int64_t(1), int64_t(2))), + testing::Values(1, 2, 32, 64, 128, 256, 512)), + [](const testing::TestParamInfo& info) { + auto dtype = 
std::get<0>(info.param); + auto dims = std::get<1>(info.param); + auto inner_dim = std::get<2>(info.param); + std::ostringstream os; + os << dtype << "_transpose_" << dims.first << "_" << dims.second + << "_size_" << inner_dim; + return os.str(); + }); + +class TmaTransposeDtypeP : public TransposeTest, + public testing::WithParamInterface { + protected: + void SetUp() override { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + TransposeTest::SetUp(); + EnableOptionsGuard::getCurOptions().set(EnableOption::TmaTranspose); + } +}; + +TEST_P(TmaTransposeDtypeP, OutputTransposeBankconflict) { + auto dtype = GetParam(); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + auto tv0 = makeContigTensor(2, dtype); + auto tv1 = makeContigTensor(2, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + tv0 = maybeCastOp(DataType::Float, tv0); + tv1 = maybeCastOp(DataType::Float, tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = transpose(tv2, 0, 1); + auto tv4 = mul(tv3, tv3); + tv4 = maybeCastOp(dtype, tv4); + fusion.addOutput(tv4); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn({16384, 8192}, options); + auto t1 = at::randn({16384, 8192}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); + + // Check bank conflicts via the compiled kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + for (auto& executor : runtime->executors()) { + if (auto* ke = dynamic_cast(executor.get())) { + auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel()); + for (auto& [expr, ways] : bank_conflicts) { + auto [read_ways, write_ways] = ways; + std::cout << " Bank conflict: " << expr->toString() + << " read=" << read_ways << "-way" + << ", write=" << write_ways << "-way" 
<< std::endl; + } + EXPECT_TRUE(bank_conflicts.empty()); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + TransposeTest, + TmaTransposeDtypeP, + testing::Values(DataType::Float, DataType::BFloat16), + [](const testing::TestParamInfo& info) { + std::ostringstream os; + os << info.param; + return os.str(); + }); + +// Test different combinations of TMA transpose parameters: +// (is_output_smem, use_tma_load, use_tma_store) +// (false, true, false) - input smem transpose, TMA load only +// (false, true, true) - input smem transpose, TMA load + TMA store +// (true, true, true) - output smem transpose, TMA load + TMA store +// (true, false, true) - output smem transpose, TMA store only +using TmaTransposeParamsTestP_Params = + std::tuple; // is_output_smem, tma_load, tma_store + +class TmaTransposeParamsTestP + : public TransposeTest, + public testing::WithParamInterface { + protected: + void SetUp() override { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + TransposeTest::SetUp(); + EnableOptionsGuard::getCurOptions().set(EnableOption::TmaTranspose); + } +}; + +TEST_P(TmaTransposeParamsTestP, TmaTransposeParams) { + auto [is_output_smem, use_tma_load, use_tma_store] = GetParam(); + + // Build fusion: two inputs -> add -> transpose -> square -> output. + // Two inputs ensure the default heuristic picks output smem transpose. + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = makeContigTensor(2); + auto tv1 = makeContigTensor(2); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = transpose(tv2, 0, 1); + auto tv4 = mul(tv3, tv3); + fusion->addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({4096, 2048}, options); + auto t1 = at::randn({4096, 2048}, options); + + // Compute default TMA heuristics, then override the 3 params. 
+ SchedulerRuntimeInfo runtime_info(fusion.get(), {t0, t1}); + auto scheduler = + SchedulerEntry::makeSchedulerInstance(SchedulerType::Transpose); + auto heuristic_params = + scheduler->computeHeuristics(fusion.get(), runtime_info); + auto* tparams = heuristic_params->as(); + tparams->is_output_smem_transpose = is_output_smem; + tparams->use_tma_load = use_tma_load; + tparams->use_tma_store = use_tma_store; + + // Schedule, compile, and run. + scheduler->schedule(fusion.get(), tparams); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0, t1}, tparams->lparams); + auto outputs = ke.run({t0, t1}, {}, tparams->lparams); + testValidate(fusion.get(), outputs, {t0, t1}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeTest, + TmaTransposeParamsTestP, + testing::Values( + // (is_output_smem, tma_load, tma_store) + std::make_tuple(false, true, false), // input smem, TMA load only + std::make_tuple(false, true, true), // input smem, TMA load + store + std::make_tuple(true, true, true), // output smem, TMA load + store + std::make_tuple(true, false, true)), // output smem, TMA store only + [](const testing::TestParamInfo& info) { + bool is_output_smem = std::get<0>(info.param); + bool tma_load = std::get<1>(info.param); + bool tma_store = std::get<2>(info.param); + std::ostringstream os; + os << (is_output_smem ? "output_smem_transpose" : "input_smem_transpose") + << "_load_" << (tma_load ? "tma" : "off") << "_store_" + << (tma_store ? "tma" : "off"); + return os.str(); + }); + } // namespace nvfuser