From 127dc46d221c3be83e997dce0ebe150ee5208fbe Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 6 Feb 2026 07:48:53 -0800 Subject: [PATCH 01/13] increase transpose tile size to meet required bytes in flight --- benchmarks/python/test_transpose.py | 85 ++++++++++++++++++++++------- csrc/scheduler/transpose.cpp | 23 ++++++++ 2 files changed, 88 insertions(+), 20 deletions(-) diff --git a/benchmarks/python/test_transpose.py b/benchmarks/python/test_transpose.py index 11b66774708..9e91fc0b2f2 100644 --- a/benchmarks/python/test_transpose.py +++ b/benchmarks/python/test_transpose.py @@ -15,17 +15,26 @@ def transpose_fusion( is_copy_transpose: bool, axes: list, rank: int, + num_inputs: int = 2, ): shape = [-1] * rank contiguity = [True] * rank T0 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False) - T1 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False) + + if num_inputs == 2: + T1 = fd.define_tensor( + shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False + ) if dtype in PROMOTE_DTYPES: T0 = fd.ops.cast(T0, dtype=DataType.Float) - T1 = fd.ops.cast(T1, dtype=DataType.Float) + if num_inputs == 2: + T1 = fd.ops.cast(T1, dtype=DataType.Float) - T4 = fd.ops.add(T0, T1) + if num_inputs == 2: + T4 = fd.ops.add(T0, T1) + else: + T4 = T0 T5 = fd.ops.permute(T4, dims=axes) if dtype in PROMOTE_DTYPES: @@ -46,11 +55,18 @@ def transpose_fusion( # Without contiguous, transpose returns a view with swapped strides. # contiguous() materializes a contiguous copy of the result. # When compiled with thunder, contiguous version will use nvFuser's transpose scheduler, otherwise it will use the pointwise scheduler. -def transpose_fwd_fn(inputs: list): # [input1, input2, dim0, dim1, is_copy_transpose] - relu_transpose_result = torch.nn.functional.relu( - torch.transpose(inputs[0] + inputs[1], inputs[2], inputs[3]) - ) - is_copy_transpose = inputs[4] +def transpose_fwd_fn( + inputs: list, +): # [input1, input2 (optional), dim0, dim1, is_copy_transpose, num_inputs] + num_inputs = inputs[-1] + is_copy_transpose = inputs[-2] + if num_inputs == 2: + data = inputs[0] + inputs[1] + dim0, dim1 = inputs[2], inputs[3] + else: + data = inputs[0] + dim0, dim1 = inputs[1], inputs[2] + relu_transpose_result = torch.nn.functional.relu(torch.transpose(data, dim0, dim1)) if is_copy_transpose: return relu_transpose_result.contiguous() else: @@ -75,6 +91,11 @@ def _generate_transpose_params(): [True, False], ids=["copy_transpose", "view_transpose"], ) +@pytest.mark.parametrize( + "num_inputs", + [1, 2], + ids=["1_input", "2_inputs"], +) @pytest.mark.pointwise def test_transpose_nvf_benchmark( benchmark, @@ -83,11 +104,16 @@ def test_transpose_nvf_benchmark( dtype: torch.dtype, axes: tuple, dims: int, + num_inputs: int, disable_validation: bool, disable_benchmarking: bool, ): input1 = torch.randn(size, device="cuda", dtype=dtype) - input2 = torch.randn(size, device="cuda", dtype=dtype) + inputs = [input1] + if num_inputs == 2: + input2 = torch.randn(size, device="cuda", dtype=dtype) + inputs.append(input2) + permute_axes = list(range(len(size))) permute_axes[axes[0]], permute_axes[axes[1]] = ( permute_axes[axes[1]], @@ -101,16 +127,22 @@ def test_transpose_nvf_benchmark( is_copy_transpose, permute_axes, rank=dims, + num_inputs=num_inputs, ) if not disable_validation: - eager_output = transpose_fwd_fn( - [input1, input2, axes[0], axes[1], is_copy_transpose] - ) - fd.validate([input1, input2], [eager_output]) + if num_inputs == 2: + eager_output = transpose_fwd_fn( + [input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs] + ) + else: + eager_output = transpose_fwd_fn( + [input1, axes[0], axes[1], is_copy_transpose, num_inputs] + ) + fd.validate(inputs, [eager_output]) if not disable_benchmarking: - run_benchmark(benchmark, fd.execute, [input1, input2]) + run_benchmark(benchmark, fd.execute, inputs) @pytest.mark.parametrize("executor", DEFAULT_EXECUTORS) @@ -121,6 +153,11 @@ def test_transpose_nvf_benchmark( [True, False], ids=["copy_transpose", "view_transpose"], ) +@pytest.mark.parametrize( + "num_inputs", + [1, 2], + ids=["1_input", "2_inputs"], +) def test_transpose_baseline_benchmark( benchmark, size: tuple, @@ -128,18 +165,26 @@ def test_transpose_baseline_benchmark( is_copy_transpose: bool, axes: tuple, dims: int, + num_inputs: int, executor: str, ): if executor == "torchcompile": clear_dynamo_cache() input1 = torch.randn(size, device="cuda", dtype=dtype) - input2 = torch.randn(size, device="cuda", dtype=dtype) benchmark_fn = with_executor(executor, transpose_fwd_fn) # Inputs and outputs are same as nvFuser, no need for manual IOByte computation - run_benchmark( - benchmark, - benchmark_fn, - [input1, input2, axes[0], axes[1], is_copy_transpose], - ) + if num_inputs == 2: + input2 = torch.randn(size, device="cuda", dtype=dtype) + run_benchmark( + benchmark, + benchmark_fn, + [input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs], + ) + else: + run_benchmark( + benchmark, + benchmark_fn, + [input1, axes[0], axes[1], is_copy_transpose, num_inputs], + ) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index c2511ec5c2f..7f4e5da0634 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -682,6 +682,29 @@ std::unique_ptr getTransposeHeuristics( "combination of view op with small transpose dimensions are not " "supported by transpose scheduler"); + // Double tile_size2 if the default configuration doesn't provide enough + // bytes in flight to saturate memory bandwidth. This is based on Little's + // law: bytes_in_flight = bandwidth * latency. We estimate the bits in flight + // per SM as: (sum of input tensor element sizes) * elements_per_tile * + // blocks_per_sm. If this is less than the required bits in flight (derived + // from hardware bandwidth and memory latency), we double tile_size2 to + // increase the data in flight. + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t max_blocks_per_sm = dev_prop->maxBlocksPerMultiProcessor; + const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2; + const int64_t required_bits_per_sm = + scheduler_utils::getRequiredBitsInFlight(); + int64_t total_input_bits_per_elem = 0; + for (auto tv : ir_utils::filterByType(fusion->inputs())) { + total_input_bits_per_elem += + dataTypeSizeBit(tv->getDataType().value(), index_type); + } + const int64_t bits_in_flight_per_sm = + total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm; + if (bits_in_flight_per_sm < required_bits_per_sm) { + tparams->tile_size2 *= 2; + } + // Note [vectorization and unroll of input and output] // // The choice of vectorization size, block size and tile sizes needs to be From ff62c9e8d82f7601edfc1a021857152d48d7b2cc Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 6 Feb 2026 08:51:28 -0800 Subject: [PATCH 02/13] tile size and rm bank conflicts --- csrc/scheduler/transpose.cpp | 98 +++++++++++++++++++++++++----------- 1 file changed, 70 insertions(+), 28 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 7f4e5da0634..61082f947b0 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -682,29 +682,6 @@ std::unique_ptr getTransposeHeuristics( "combination of view op with small transpose dimensions are not " "supported by transpose scheduler"); - // Double tile_size2 if the default configuration doesn't provide enough - // bytes in flight to saturate memory bandwidth. This is based on Little's - // law: bytes_in_flight = bandwidth * latency. We estimate the bits in flight - // per SM as: (sum of input tensor element sizes) * elements_per_tile * - // blocks_per_sm. If this is less than the required bits in flight (derived - // from hardware bandwidth and memory latency), we double tile_size2 to - // increase the data in flight. - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - const int64_t max_blocks_per_sm = dev_prop->maxBlocksPerMultiProcessor; - const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2; - const int64_t required_bits_per_sm = - scheduler_utils::getRequiredBitsInFlight(); - int64_t total_input_bits_per_elem = 0; - for (auto tv : ir_utils::filterByType(fusion->inputs())) { - total_input_bits_per_elem += - dataTypeSizeBit(tv->getDataType().value(), index_type); - } - const int64_t bits_in_flight_per_sm = - total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm; - if (bits_in_flight_per_sm < required_bits_per_sm) { - tparams->tile_size2 *= 2; - } - // Note [vectorization and unroll of input and output] // // The choice of vectorization size, block size and tile sizes needs to be @@ -745,6 +722,39 @@ std::unique_ptr getTransposeHeuristics( scan_max_dtype_size(fusion->inputs()); scan_max_dtype_size(fusion->outputs()); + // Double tile_size2 if the default configuration doesn't provide enough + // bytes in flight to saturate memory bandwidth. This is based on Little's + // law: bytes_in_flight = bandwidth * latency. We estimate the bits in flight + // per SM as: (sum of input tensor element sizes) * elements_per_tile * + // blocks_per_sm. If this is less than the required bits in flight (derived + // from hardware bandwidth and memory latency), we double tile_size2 to + // increase the data in flight. If tile1 is doubled, it will also double + // shared memory bank conflict, e.g. from 8-ways to 16 ways when increased + // from 32 to 64 assuming vectorization factor is 4, we need 8 or 16 threads + // loading per column. + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t max_blocks_per_sm = dev_prop->maxThreadsPerMultiProcessor / + TransposeParams::getMaxThreadsPerBlock(); + const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2; + const int64_t required_bits_per_sm = + scheduler_utils::getRequiredBitsInFlight(); + int64_t total_input_bits_per_elem = 0; + for (auto tv : ir_utils::filterByType(fusion->inputs())) { + total_input_bits_per_elem += + dataTypeSizeBit(tv->getDataType().value(), index_type); + } + const int64_t bits_in_flight_per_sm = + total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm; + std::cout << "total_input_bits_per_elem: " << total_input_bits_per_elem + << std::endl; + std::cout << "num_elems_per_tile: " << num_elems_per_tile << std::endl; + std::cout << "max_blocks_per_sm: " << max_blocks_per_sm << std::endl; + std::cout << "bits_in_flight_per_sm: " << bits_in_flight_per_sm << std::endl; + std::cout << "required_bits_per_sm: " << required_bits_per_sm << std::endl; + if (bits_in_flight_per_sm < required_bits_per_sm) { + tparams->tile_size2 *= 2; + } + auto max_unroll_factor = ceilDiv( // Available unrolling based on size of data type kSixteen / max_io_dtype_size, @@ -935,6 +945,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { */ std::unordered_set group2_and_cached_inputs( grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); + std::vector smem_cached_input_tvs; for (auto tv : grouped_inputs_outputs[1]) { if (tv->isFusionInput()) { auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; @@ -942,9 +953,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto new_cache = tv->cacheAfter(); new_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(new_cache); + smem_cached_input_tvs.push_back(existing_cache); } else { existing_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(existing_cache); + smem_cached_input_tvs.push_back(existing_cache); } } } @@ -1156,9 +1169,12 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // inputs and outputs themselves are disconnected, so we have to borrow the // entire DAG and use its spanning tree. { - auto all_tvs_except1 = ir_utils::allTvsExcept( - fusion, - {grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()}); + std::unordered_set except_tvs; + except_tvs.insert( + grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()); + except_tvs.insert( + smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + auto all_tvs_except1 = ir_utils::allTvsExcept(fusion, except_tvs); SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector); TransformPropagator propagator(reference2); @@ -1249,8 +1265,12 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // Propagate transformations, parallelization of the reference1 to the entire // DAG except group 2 and its corresponding cached outputs. { - auto all_tvs_except2 = - ir_utils::allTvsExcept(fusion, group2_and_cached_inputs); + std::unordered_set except_tvs; + except_tvs.insert( + group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()); + except_tvs.insert( + smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + auto all_tvs_except2 = ir_utils::allTvsExcept(fusion, except_tvs); SetSelector selector({all_tvs_except2.begin(), all_tvs_except2.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except_outputs( reference1, &selector); @@ -1315,6 +1335,28 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } } + // schedule smem_cached_input_tvs + for (auto tv : smem_cached_input_tvs) { + std::cout << "scheduling smem_cached_tv: " << tv->toString() << std::endl; + int64_t pos = tv->nDims() - 2; + bool is_group2 = group2_and_cached_inputs.count(tv) > 0; + int64_t tile2_factor = + is_group2 ? tparams->vectorize_factor2 : tparams->vectorize_factor1; + int64_t tile1_factor = + tparams->tile_size1 * tile2_factor / tparams->tile_size2; + // [BIDx, UnSwitch, tile1, tile2] + tv->split(pos + 1, tile2_factor); + tv->split(pos, tile1_factor); + tv->swizzle(SwizzleType::XOR, pos, pos + 2); + tv->merge(pos); + tv->merge(pos); + tv->split(pos, tparams->getThreadsPerBlock()); + tv->axis(pos)->parallelize(ParallelType::Unroll); + tv->axis(pos + 1)->parallelize(ParallelType::TIDx); + tv->axis(pos + 2)->parallelize(ParallelType::Vectorize); + std::cout << "scheduled smem_cached_tv: " << tv->toString() << std::endl; + } + //////////////////////////////// // Step 5: Cleanup and inline // //////////////////////////////// From 902b2734d2962242fbf02f7a03d5454739d67383 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sun, 8 Feb 2026 07:12:14 -0800 Subject: [PATCH 03/13] small transpose --- csrc/scheduler/transpose.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 61082f947b0..477db1fe0b8 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -190,8 +190,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator { bool should_reject = false; }; -bool hasSmallTransposeDimensions( - const std::unique_ptr& params) { +bool hasSmallTransposeDimensions(const TransposeParams* params) { return !params->split_before_tiling.empty() || !params->dims_merged_with_1.empty() || !params->dims_merged_with_2.empty(); @@ -579,7 +578,7 @@ std::string getTransposeRuntimeRejectReason( // 1. view op; and // 2. small transpose transformation // See note [Supporting small transpose dimensions] - if (hasSmallTransposeDimensions(params)) { + if (hasSmallTransposeDimensions(params.get())) { return "Small transpose dimensions and view op cannot be currently be " "handled by transpose scheduler. See: " "https://github.com/NVIDIA/Fuser/pull/592"; @@ -677,7 +676,7 @@ std::unique_ptr getTransposeHeuristics( inner_most_pos2_in_ref1); NVF_ERROR( - !hasSmallTransposeDimensions(tparams) || + !hasSmallTransposeDimensions(tparams.get()) || scheduler_utils::getViewTVs(fusion).empty(), "combination of view op with small transpose dimensions are not " "supported by transpose scheduler"); @@ -867,7 +866,7 @@ std::unique_ptr getTransposeHeuristics( << "reference2: " << reference2->toString() << "\n" << "inner_most_id2 position: " << inner_most_pos2_in_ref1 << " (in reference 1)" << std::endl; - if (hasSmallTransposeDimensions(tparams)) { + if (hasSmallTransposeDimensions(tparams.get())) { debug() << "small transposed dim, needs virtual inner-most dim" << std::endl; } @@ -953,11 +952,15 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto new_cache = tv->cacheAfter(); new_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(new_cache); - smem_cached_input_tvs.push_back(existing_cache); + if (!hasSmallTransposeDimensions(tparams)) { + smem_cached_input_tvs.push_back(new_cache); + } } else { existing_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(existing_cache); - smem_cached_input_tvs.push_back(existing_cache); + if (!hasSmallTransposeDimensions(tparams)) { + smem_cached_input_tvs.push_back(existing_cache); + } } } } From 3c561b584ebfe7ce79e950952846b72527c55aed Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sun, 8 Feb 2026 07:41:28 -0800 Subject: [PATCH 04/13] skip reduction dim --- csrc/scheduler/tools/domain_map.cpp | 12 +++++++++--- tests/cpp/test_transpose.cpp | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index f7cfc277877..5341ce9d381 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -540,8 +540,14 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { const auto& ref1_loop = ref1->getMaybeAllocationDomain(); const auto& ref2_loop = ref2->getMaybeAllocationDomain(); const auto& ca_map = domain_map.getComputeAtMap(); + + // Filter out reduction domains before comparing + auto is_not_reduction = [](IterDomain* id) { return !id->isReduction(); }; + auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction); + auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction); + const bool all_mapped = std::ranges::equal( - ref1_loop, ref2_loop, [&](IterDomain* id1, IterDomain* id2) { + ref1_filtered, ref2_filtered, [&](IterDomain* id1, IterDomain* id2) { return ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE); }); if (all_mapped) { @@ -549,9 +555,9 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { // any_bcast const bool any_bcast = std::ranges::any_of( - ref1_loop, [](IterDomain* id) { return id->isBroadcast(); }) || + ref1_filtered, [](IterDomain* id) { return id->isBroadcast(); }) || std::ranges::any_of( - ref2_loop, [](IterDomain* id) { return id->isBroadcast(); }); + ref2_filtered, [](IterDomain* id) { return id->isBroadcast(); }); NVF_ERROR( any_bcast, "all_mapped implies any_bcast, ca_map:\n", diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 139a2b1e791..1642c3af509 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1371,7 +1371,7 @@ TEST_F(TransposeTest, ReductionIterDomainOnInputsIssue1659) { .get() ->scheduler_type; NVF_CHECK( - heuristic1 == SchedulerType::Transpose, + heuristic1 == SchedulerType::PointWise, "Unexpected heuristic: ", heuristic1); testValidate(fusion_ptr, cg_outputs, {t0, t1}, __LINE__, __FILE__); From a0d20ff817e809d0b67d8b8b1d05bbac629abcfe Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sun, 8 Feb 2026 08:36:15 -0800 Subject: [PATCH 05/13] fix transpose --- csrc/scheduler/transpose.cpp | 70 +++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 477db1fe0b8..f31a65d23cd 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -952,25 +952,31 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto new_cache = tv->cacheAfter(); new_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(new_cache); - if (!hasSmallTransposeDimensions(tparams)) { - smem_cached_input_tvs.push_back(new_cache); - } + smem_cached_input_tvs.push_back(new_cache); } else { existing_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(existing_cache); - if (!hasSmallTransposeDimensions(tparams)) { - smem_cached_input_tvs.push_back(existing_cache); - } + smem_cached_input_tvs.push_back(existing_cache); } } } + + bool use_smem_swizzle = !hasSmallTransposeDimensions(tparams); // set cached outputs of group 2 to shared memory for (const auto& [cached_output, output_idx] : cached_outputs) { auto output = fusion->outputs()[output_idx]->as(); if (group2_and_cached_inputs.count(output) > 0) { cached_output->setMemoryType(MemoryType::Shared); + // current smem swizzle only works for cached input + use_smem_swizzle = false; } } + // For non-square tile, can't create smem swizzle chunks if tile2 is larger + // and not vectorized + if (tparams->tile_size2 > tparams->tile_size1 && + tparams->vectorize_factor2 == 1) { + use_smem_swizzle = false; + } TensorView* reference1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]); @@ -1175,8 +1181,10 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { std::unordered_set except_tvs; except_tvs.insert( grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()); - except_tvs.insert( - smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + if (use_smem_swizzle) { + except_tvs.insert( + smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + } auto all_tvs_except1 = ir_utils::allTvsExcept(fusion, except_tvs); SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector); @@ -1271,8 +1279,10 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { std::unordered_set except_tvs; except_tvs.insert( group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()); - except_tvs.insert( - smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + if (use_smem_swizzle) { + except_tvs.insert( + smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + } auto all_tvs_except2 = ir_utils::allTvsExcept(fusion, except_tvs); SetSelector selector({all_tvs_except2.begin(), all_tvs_except2.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except_outputs( @@ -1339,25 +1349,27 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } // schedule smem_cached_input_tvs - for (auto tv : smem_cached_input_tvs) { - std::cout << "scheduling smem_cached_tv: " << tv->toString() << std::endl; - int64_t pos = tv->nDims() - 2; - bool is_group2 = group2_and_cached_inputs.count(tv) > 0; - int64_t tile2_factor = - is_group2 ? tparams->vectorize_factor2 : tparams->vectorize_factor1; - int64_t tile1_factor = - tparams->tile_size1 * tile2_factor / tparams->tile_size2; - // [BIDx, UnSwitch, tile1, tile2] - tv->split(pos + 1, tile2_factor); - tv->split(pos, tile1_factor); - tv->swizzle(SwizzleType::XOR, pos, pos + 2); - tv->merge(pos); - tv->merge(pos); - tv->split(pos, tparams->getThreadsPerBlock()); - tv->axis(pos)->parallelize(ParallelType::Unroll); - tv->axis(pos + 1)->parallelize(ParallelType::TIDx); - tv->axis(pos + 2)->parallelize(ParallelType::Vectorize); - std::cout << "scheduled smem_cached_tv: " << tv->toString() << std::endl; + if (use_smem_swizzle) { + for (auto tv : smem_cached_input_tvs) { + std::cout << "scheduling smem_cached_tv: " << tv->toString() << std::endl; + int64_t pos = tv->nDims() - 2; + bool is_group2 = group2_and_cached_inputs.count(tv) > 0; + int64_t tile2_factor = + is_group2 ? tparams->vectorize_factor2 : tparams->vectorize_factor1; + int64_t tile1_factor = + tparams->tile_size1 * tile2_factor / tparams->tile_size2; + // [BIDx, UnSwitch, tile1, tile2] + tv->split(pos + 1, tile2_factor); + tv->split(pos, tile1_factor); + tv->swizzle(SwizzleType::XOR, pos, pos + 2); + tv->merge(pos); + tv->merge(pos); + tv->split(pos, tparams->getThreadsPerBlock()); + tv->axis(pos)->parallelize(ParallelType::Unroll); + tv->axis(pos + 1)->parallelize(ParallelType::TIDx); + tv->axis(pos + 2)->parallelize(ParallelType::Vectorize); + std::cout << "scheduled smem_cached_tv: " << tv->toString() << std::endl; + } } //////////////////////////////// From 6e53109137466a161ceb3551fd61c5a63471ae52 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sun, 8 Feb 2026 08:36:56 -0800 Subject: [PATCH 06/13] remove invalid test --- tests/cpp/test_gpu3.cpp | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index ed0139f0a7a..944eee99ef1 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -3740,28 +3740,6 @@ TEST_F(Gpu3Test, FusionDependencyCheck_CUDA) { } } -// Repro for issue #1925 -TEST_F(Gpu3Test, FusionScheduleTransposeRepro1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(4); - auto tv1 = makeConcreteTensor({-1, -1, -1, 1}); - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = at::randn({1, 1, 333, 1}, options); - at::Tensor t1 = at::randn({1, 1, 333, 1}, options); - - auto cg_outputs = - scheduleAndRun(&fusion, SchedulerType::Transpose, {input0, t1}, false) - .outputs; - testValidate(&fusion, cg_outputs, {input0, t1}, __LINE__, __FILE__); -} - TEST_F(Gpu3Test, FusionPredicateUnshare_CUDA) { // https://github.com/csarofeen/pytorch/issues/1926 std::unique_ptr fusion_ptr = std::make_unique(); From a9de12f5805c193ee94077da22ea46f20d285171 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sun, 8 Feb 2026 12:42:03 -0800 Subject: [PATCH 07/13] remove invalid test --- tests/cpp/test_rng.cpp | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/tests/cpp/test_rng.cpp b/tests/cpp/test_rng.cpp index 42cca55513d..ef0c7a92133 100644 --- a/tests/cpp/test_rng.cpp +++ b/tests/cpp/test_rng.cpp @@ -270,42 +270,6 @@ TEST_F(RNGTest, BroadcastingRNGSmem) { } } -TEST_F(RNGTest, BroadcastingRNGSmemNonSquareTile) { - // https://github.com/csarofeen/pytorch/issues/1926 - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeConcreteTensor({5, 1}); - TensorView* tv1 = makeConcreteTensor({5, 5}); - fusion->addInput(tv0); - fusion->addInput(tv1); - auto tv2 = rand_like(tv0); - auto tv3 = add(tv1, tv2); - auto tv4 = add(tv0, tv3); - fusion->addOutput(tv4); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::zeros({5, 1}, options); - at::Tensor t1 = at::zeros({5, 5}, options); - - TransposeParams tparams; - tparams.tile_size1 = 8; - tparams.tile_size2 = 4; - SchedulerEntry::makeSchedulerInstance(SchedulerType::Transpose) - ->schedule(fusion, &tparams); - - KernelExecutor ke; - ke.compile(fusion, {t0, t1}); - auto cg_outputs = ke.run({t0, t1}); - auto out = cg_outputs[0].as(); - - NVF_CHECK((out.select(1, 0) == out.select(1, 1)).all().item()); - NVF_CHECK((out.select(1, 0) == out.select(1, 2)).all().item()); - NVF_CHECK((out.select(1, 0) == out.select(1, 3)).all().item()); - NVF_CHECK((out.select(1, 0) == out.select(1, 4)).all().item()); -} - TEST_F(RNGTest, Uniform) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); From 8e9d9088a32ed6d65c168ae9be73397a486bafbb Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 9 Feb 2026 07:15:28 -0800 Subject: [PATCH 08/13] limit bcast aliasing to io tensors --- csrc/alias_analysis.cpp | 12 +++++++++++- tests/cpp/test_persistent_buffer.cpp | 7 +++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index 81c5cb3cabf..7b1e54095ba 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -288,12 +288,22 @@ void AliasFinder::handle(const SliceOp* slice) { out, in, Layout(std::move(out_allocation), std::move(out_contiguity))); } +// Only consider broadcast aliasing when input is a fusion input and output is +// a fusion output. Intermediate broadcasts will be fused with other ops and +// don't need explicit alias handling. Limiting to fusion boundaries avoids +// unnecessary allocation domain changes on intermediate tensors, which may +// trigger transpose scheduler when pointwise is preferred. For example, when a +// normalization kernel is segmented, we prefer reduction + pointwise instead of +// reduction + transpose. See SmemPersistentNotSupportedIn3DReduction. void AliasFinder::handle(const BroadcastOp* bcast) { auto* in = dynamic_cast(bcast->in()); - if (in == nullptr) { + if (in == nullptr || !in->isFusionInput()) { return; } auto* out = bcast->out()->as(); + if (!out->isFusionOutput()) { + return; + } std::optional out_layout = mapInLayoutToOutRoot(analysis_.preferredLayout(in), in, out); diff --git a/tests/cpp/test_persistent_buffer.cpp b/tests/cpp/test_persistent_buffer.cpp index da9c5c906dd..9a0923b3f11 100644 --- a/tests/cpp/test_persistent_buffer.cpp +++ b/tests/cpp/test_persistent_buffer.cpp @@ -1257,6 +1257,13 @@ TEST_F(PersistentBufferTest, SmemPersistentNotSupportedIn3DReduction) { // persistent is not supported yet for 3D reduction. EXPECT_TRUE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); + // expect reduction and pointwise scheduler + EXPECT_THAT( + executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(), + UnorderedElementsAre( + HeuristicIs(SchedulerType::PointWise), + HeuristicIs(SchedulerType::Reduction))); + testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__); } From 13c94a2772f88d2c37ae19f06dc58b070bf08bb6 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 10 Feb 2026 06:31:15 -0800 Subject: [PATCH 09/13] update alias --- csrc/alias_analysis.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index 7b1e54095ba..1cf3221a43b 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -288,20 +288,22 @@ void AliasFinder::handle(const SliceOp* slice) { out, in, Layout(std::move(out_allocation), std::move(out_contiguity))); } -// Only consider broadcast aliasing when input is a fusion input and output is -// a fusion output. Intermediate broadcasts will be fused with other ops and -// don't need explicit alias handling. Limiting to fusion boundaries avoids -// unnecessary allocation domain changes on intermediate tensors, which may -// trigger transpose scheduler when pointwise is preferred. For example, when a -// normalization kernel is segmented, we prefer reduction + pointwise instead of -// reduction + transpose. See SmemPersistentNotSupportedIn3DReduction. +// Only consider broadcast aliasing when IO tensor is involved. +// Intermediate broadcasts will be fused with other ops and don't need explicit +// alias handling. Limiting to fusion boundaries avoids unnecessary allocation +// domain changes on intermediate tensors, which may trigger transpose scheduler +// when pointwise is preferred. For example, when a normalization kernel is +// segmented, we prefer reduction + pointwise instead of reduction + transpose. +// See SmemPersistentNotSupportedIn3DReduction. void AliasFinder::handle(const BroadcastOp* bcast) { auto* in = dynamic_cast(bcast->in()); - if (in == nullptr || !in->isFusionInput()) { + if (in == nullptr) { return; } auto* out = bcast->out()->as(); - if (!out->isFusionOutput()) { + + // No alias analysis needed if no IO tensors are involved + if (!out->isFusionOutput() && !in->isFusionInput()) { return; } From 14320b34db220352e97be82f8c7d4d2abeb7daec Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 12 Feb 2026 08:13:59 -0800 Subject: [PATCH 10/13] solve conf --- benchmarks/python/test_transpose.py | 209 ++++++++++++++++------------ 1 file changed, 120 insertions(+), 89 deletions(-) diff --git a/benchmarks/python/test_transpose.py b/benchmarks/python/test_transpose.py index 9e91fc0b2f2..b38427f561c 100644 --- a/benchmarks/python/test_transpose.py +++ b/benchmarks/python/test_transpose.py @@ -9,47 +9,79 @@ from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES -def transpose_fusion( +def transpose_fusion_input_smem( fd: FusionDefinition, dtype: DataType, is_copy_transpose: bool, axes: list, rank: int, - num_inputs: int = 2, ): + """Single input: the transposed input is read through shared memory.""" shape = [-1] * rank contiguity = [True] * rank T0 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False) - if num_inputs == 2: - T1 = fd.define_tensor( - shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False - ) - if dtype in PROMOTE_DTYPES: T0 = fd.ops.cast(T0, dtype=DataType.Float) - if num_inputs == 2: - T1 = fd.ops.cast(T1, dtype=DataType.Float) - if num_inputs == 2: - T4 = fd.ops.add(T0, T1) + T1 = fd.ops.permute(T0, dims=axes) + + if dtype in PROMOTE_DTYPES: + T1 = fd.ops.cast(T1, dtype=dtype) + + S2 = fd.define_scalar(0.00000, dtype=DataType.Double) + T3 = fd.ops.gt(T1, S2) + T4 = fd.ops.where(T3, T1, S2) + # add segmenter set to avoid presegment passes setting the output as a + # view of the input without any data movement. It leads to pointwise + # instead of transpose scheduler. + # we can also expose OptimizationPassGuard to python frontend and disable + # presegmentation passes to enforce output to be contiguous and then + # transpose scheduler will be used. + if is_copy_transpose: + T5 = fd.ops.segment_set(T4) + fd.add_output(T5) else: - T4 = T0 - T5 = fd.ops.permute(T4, dims=axes) + fd.add_output(T4) + + +def transpose_fusion_output_smem( + fd: FusionDefinition, + dtype: DataType, + is_copy_transpose: bool, + axes: list, + rank: int, +): + """Two inputs: the transposed output is written through shared memory.""" + shape = [-1] * rank + contiguity = [True] * rank + T0 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False) + T1 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False) if dtype in PROMOTE_DTYPES: - T5 = fd.ops.cast(T5, dtype=dtype) + T0 = fd.ops.cast(T0, dtype=DataType.Float) + T1 = fd.ops.cast(T1, dtype=DataType.Float) - S6 = fd.define_scalar(0.00000, dtype=DataType.Double) - T7 = fd.ops.gt(T5, S6) - T9 = fd.ops.where(T7, T5, S6) - # add segmenter set to avoid presegment passes setting the output as a view of the input without any data movement. It leads to pointwise instead of transpose scheduler. - # we can also expose OptimizationPassGuard to python frontend and disable presegmentation passes to enforce output to be contiguous and then transpose scheduler will be used. + T2 = fd.ops.add(T0, T1) + T3 = fd.ops.permute(T2, dims=axes) + + if dtype in PROMOTE_DTYPES: + T3 = fd.ops.cast(T3, dtype=dtype) + + S4 = fd.define_scalar(0.00000, dtype=DataType.Double) + T5 = fd.ops.gt(T3, S4) + T6 = fd.ops.where(T5, T3, S4) + # add segmenter set to avoid presegment passes setting the output as a + # view of the input without any data movement. It leads to pointwise + # instead of transpose scheduler. + # we can also expose OptimizationPassGuard to python frontend and disable + # presegmentation passes to enforce output to be contiguous and then + # transpose scheduler will be used. if is_copy_transpose: - T10 = fd.ops.segment_set(T9) - fd.add_output(T10) + T7 = fd.ops.segment_set(T6) + fd.add_output(T7) else: - fd.add_output(T9) + fd.add_output(T6) # Without contiguous, transpose returns a view with swapped strides. @@ -57,27 +89,65 @@ def transpose_fusion( # When compiled with thunder, contiguous version will use nvFuser's transpose scheduler, otherwise it will use the pointwise scheduler. def transpose_fwd_fn( inputs: list, -): # [input1, input2 (optional), dim0, dim1, is_copy_transpose, num_inputs] - num_inputs = inputs[-1] - is_copy_transpose = inputs[-2] - if num_inputs == 2: +): # [input1, input2 (optional), axes, is_copy_transpose] + is_copy_transpose = inputs[-1] + axes = inputs[-2] + if len(inputs) == 4: data = inputs[0] + inputs[1] - dim0, dim1 = inputs[2], inputs[3] else: data = inputs[0] - dim0, dim1 = inputs[1], inputs[2] - relu_transpose_result = torch.nn.functional.relu(torch.transpose(data, dim0, dim1)) + relu_transpose_result = torch.nn.functional.relu(data.permute(axes)) if is_copy_transpose: return relu_transpose_result.contiguous() else: return relu_transpose_result +def setup_input_smem(size, dtype, is_copy_transpose, axes, dims): + """Single input: the transposed input is read through shared memory.""" + input1 = torch.randn(size, device="cuda", dtype=dtype) + nvfuser_inputs = [input1] + + with FusionDefinition() as fd: + transpose_fusion_input_smem( + fd, + torch_dtype_to_nvfuser_dtype(dtype), + is_copy_transpose, + axes, + rank=dims, + ) + + eager_inputs = [input1, axes, is_copy_transpose] + return fd, nvfuser_inputs, eager_inputs + + +def setup_output_smem(size, dtype, is_copy_transpose, axes, dims): + """Two inputs: the transposed output is written through shared memory.""" + input1 = torch.randn(size, device="cuda", dtype=dtype) + input2 = torch.randn(size, device="cuda", dtype=dtype) + nvfuser_inputs = [input1, input2] + + with FusionDefinition() as fd: + transpose_fusion_output_smem( + fd, + torch_dtype_to_nvfuser_dtype(dtype), + is_copy_transpose, + axes, + rank=dims, + ) + + eager_inputs = [input1, input2, axes, is_copy_transpose] + return fd, nvfuser_inputs, eager_inputs + + def _generate_transpose_params(): params = [] for dims in (2, 3): sizes = generate_input_sizes(dims=dims) - axes_list = [(0, 1)] if dims == 2 else [(0, 1), (0, 2), (1, 2)] + if dims == 2: + axes_list = [(1, 0)] + else: + axes_list = [(1, 0, 2), (2, 1, 0), (0, 2, 1)] for size in sizes: for axes in axes_list: params.append((size, axes, dims)) @@ -88,15 +158,16 @@ def _generate_transpose_params(): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize( "is_copy_transpose", - [True, False], - ids=["copy_transpose", "view_transpose"], + [ + pytest.param(True, marks=pytest.mark.transpose, id="copy"), + pytest.param(False, marks=pytest.mark.pointwise, id="view"), + ], ) @pytest.mark.parametrize( - "num_inputs", - [1, 2], - ids=["1_input", "2_inputs"], + "setup_fn", + [setup_input_smem, setup_output_smem], + ids=["input_smem", "output_smem"], ) -@pytest.mark.pointwise def test_transpose_nvf_benchmark( benchmark, size: tuple, @@ -104,45 +175,20 @@ def test_transpose_nvf_benchmark( dtype: torch.dtype, axes: tuple, dims: int, - num_inputs: int, + setup_fn, disable_validation: bool, disable_benchmarking: bool, ): - input1 = torch.randn(size, device="cuda", dtype=dtype) - inputs = [input1] - if num_inputs == 2: - input2 = torch.randn(size, device="cuda", dtype=dtype) - inputs.append(input2) - - permute_axes = list(range(len(size))) - permute_axes[axes[0]], permute_axes[axes[1]] = ( - permute_axes[axes[1]], - permute_axes[axes[0]], + fd, nvfuser_inputs, eager_inputs = setup_fn( + size, dtype, is_copy_transpose, axes, dims ) - with FusionDefinition() as fd: - transpose_fusion( - fd, - torch_dtype_to_nvfuser_dtype(dtype), - is_copy_transpose, - permute_axes, - rank=dims, - num_inputs=num_inputs, - ) - if not disable_validation: - if num_inputs == 2: - eager_output = transpose_fwd_fn( - [input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs] - ) - else: - eager_output = transpose_fwd_fn( - [input1, axes[0], axes[1], is_copy_transpose, num_inputs] - ) - fd.validate(inputs, [eager_output]) + eager_output = transpose_fwd_fn(eager_inputs) + fd.validate(nvfuser_inputs, [eager_output]) if not disable_benchmarking: - run_benchmark(benchmark, fd.execute, inputs) + run_benchmark(benchmark, fd.execute, nvfuser_inputs) @pytest.mark.parametrize("executor", DEFAULT_EXECUTORS) @@ -151,12 +197,12 @@ def test_transpose_nvf_benchmark( @pytest.mark.parametrize( "is_copy_transpose", [True, False], - ids=["copy_transpose", "view_transpose"], + ids=["copy", "view"], ) @pytest.mark.parametrize( - "num_inputs", - [1, 2], - ids=["1_input", "2_inputs"], + "setup_fn", + [setup_input_smem, setup_output_smem], + ids=["input_smem", "output_smem"], ) def test_transpose_baseline_benchmark( benchmark, @@ -165,26 +211,11 @@ def test_transpose_baseline_benchmark( is_copy_transpose: bool, axes: tuple, dims: int, - num_inputs: int, + setup_fn, executor: str, ): if executor == "torchcompile": clear_dynamo_cache() - input1 = torch.randn(size, device="cuda", dtype=dtype) - + _, _, eager_inputs = setup_fn(size, dtype, is_copy_transpose, axes, dims) benchmark_fn = with_executor(executor, transpose_fwd_fn) - - # Inputs and outputs are same as nvFuser, no need for manual IOByte computation - if num_inputs == 2: - input2 = torch.randn(size, device="cuda", dtype=dtype) - run_benchmark( - benchmark, - benchmark_fn, - [input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs], - ) - else: - run_benchmark( - benchmark, - benchmark_fn, - [input1, axes[0], axes[1], is_copy_transpose, num_inputs], - ) + run_benchmark(benchmark, benchmark_fn, eager_inputs) From bd2d2b84bfc3fe9b31bf6e58404ab159bcb68794 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 12 Feb 2026 08:16:14 -0800 Subject: [PATCH 11/13] revert --- csrc/alias_analysis.cpp | 12 ------------ tests/cpp/test_persistent_buffer.cpp | 7 ------- 2 files changed, 19 deletions(-) diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index 1cf3221a43b..81c5cb3cabf 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -288,13 +288,6 @@ void AliasFinder::handle(const SliceOp* slice) { out, in, Layout(std::move(out_allocation), std::move(out_contiguity))); } -// Only consider broadcast aliasing when IO tensor is involved. -// Intermediate broadcasts will be fused with other ops and don't need explicit -// alias handling. Limiting to fusion boundaries avoids unnecessary allocation -// domain changes on intermediate tensors, which may trigger transpose scheduler -// when pointwise is preferred. For example, when a normalization kernel is -// segmented, we prefer reduction + pointwise instead of reduction + transpose. -// See SmemPersistentNotSupportedIn3DReduction. void AliasFinder::handle(const BroadcastOp* bcast) { auto* in = dynamic_cast(bcast->in()); if (in == nullptr) { @@ -302,11 +295,6 @@ void AliasFinder::handle(const BroadcastOp* bcast) { } auto* out = bcast->out()->as(); - // No alias analysis needed if no IO tensors are involved - if (!out->isFusionOutput() && !in->isFusionInput()) { - return; - } - std::optional out_layout = mapInLayoutToOutRoot(analysis_.preferredLayout(in), in, out); if (!out_layout.has_value()) { diff --git a/tests/cpp/test_persistent_buffer.cpp b/tests/cpp/test_persistent_buffer.cpp index 9a0923b3f11..da9c5c906dd 100644 --- a/tests/cpp/test_persistent_buffer.cpp +++ b/tests/cpp/test_persistent_buffer.cpp @@ -1257,13 +1257,6 @@ TEST_F(PersistentBufferTest, SmemPersistentNotSupportedIn3DReduction) { // persistent is not supported yet for 3D reduction. EXPECT_TRUE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); - // expect reduction and pointwise scheduler - EXPECT_THAT( - executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(), - UnorderedElementsAre( - HeuristicIs(SchedulerType::PointWise), - HeuristicIs(SchedulerType::Reduction))); - testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__); } From a8f82f1daa038c666cf29e78e18cea0c3f67719c Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 16 Feb 2026 17:14:25 -0800 Subject: [PATCH 12/13] remove bank conflict --- csrc/scheduler/transpose.cpp | 100 +++++++++++++++++++++++++++++++---- 1 file changed, 90 insertions(+), 10 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index c2511ec5c2f..f31a65d23cd 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -190,8 +190,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator { bool should_reject = false; }; -bool hasSmallTransposeDimensions( - const std::unique_ptr& params) { +bool hasSmallTransposeDimensions(const TransposeParams* params) { return !params->split_before_tiling.empty() || !params->dims_merged_with_1.empty() || !params->dims_merged_with_2.empty(); @@ -579,7 +578,7 @@ std::string getTransposeRuntimeRejectReason( // 1. view op; and // 2. small transpose transformation // See note [Supporting small transpose dimensions] - if (hasSmallTransposeDimensions(params)) { + if (hasSmallTransposeDimensions(params.get())) { return "Small transpose dimensions and view op cannot be currently be " "handled by transpose scheduler. See: " "https://github.com/NVIDIA/Fuser/pull/592"; @@ -677,7 +676,7 @@ std::unique_ptr getTransposeHeuristics( inner_most_pos2_in_ref1); NVF_ERROR( - !hasSmallTransposeDimensions(tparams) || + !hasSmallTransposeDimensions(tparams.get()) || scheduler_utils::getViewTVs(fusion).empty(), "combination of view op with small transpose dimensions are not " "supported by transpose scheduler"); @@ -722,6 +721,39 @@ std::unique_ptr getTransposeHeuristics( scan_max_dtype_size(fusion->inputs()); scan_max_dtype_size(fusion->outputs()); + // Double tile_size2 if the default configuration doesn't provide enough + // bytes in flight to saturate memory bandwidth. This is based on Little's + // law: bytes_in_flight = bandwidth * latency. We estimate the bits in flight + // per SM as: (sum of input tensor element sizes) * elements_per_tile * + // blocks_per_sm. If this is less than the required bits in flight (derived + // from hardware bandwidth and memory latency), we double tile_size2 to + // increase the data in flight. If tile1 is doubled, it will also double + // shared memory bank conflict, e.g. from 8-ways to 16 ways when increased + // from 32 to 64 assuming vectorization factor is 4, we need 8 or 16 threads + // loading per column. + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t max_blocks_per_sm = dev_prop->maxThreadsPerMultiProcessor / + TransposeParams::getMaxThreadsPerBlock(); + const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2; + const int64_t required_bits_per_sm = + scheduler_utils::getRequiredBitsInFlight(); + int64_t total_input_bits_per_elem = 0; + for (auto tv : ir_utils::filterByType(fusion->inputs())) { + total_input_bits_per_elem += + dataTypeSizeBit(tv->getDataType().value(), index_type); + } + const int64_t bits_in_flight_per_sm = + total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm; + std::cout << "total_input_bits_per_elem: " << total_input_bits_per_elem + << std::endl; + std::cout << "num_elems_per_tile: " << num_elems_per_tile << std::endl; + std::cout << "max_blocks_per_sm: " << max_blocks_per_sm << std::endl; + std::cout << "bits_in_flight_per_sm: " << bits_in_flight_per_sm << std::endl; + std::cout << "required_bits_per_sm: " << required_bits_per_sm << std::endl; + if (bits_in_flight_per_sm < required_bits_per_sm) { + tparams->tile_size2 *= 2; + } + auto max_unroll_factor = ceilDiv( // Available unrolling based on size of data type kSixteen / max_io_dtype_size, @@ -834,7 +866,7 @@ std::unique_ptr getTransposeHeuristics( << "reference2: " << reference2->toString() << "\n" << "inner_most_id2 position: " << inner_most_pos2_in_ref1 << " (in reference 1)" << std::endl; - if (hasSmallTransposeDimensions(tparams)) { + if (hasSmallTransposeDimensions(tparams.get())) { debug() << "small transposed dim, needs virtual inner-most dim" << std::endl; } @@ -912,6 +944,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { */ std::unordered_set group2_and_cached_inputs( grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); + std::vector smem_cached_input_tvs; for (auto tv : grouped_inputs_outputs[1]) { if (tv->isFusionInput()) { auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; @@ -919,19 +952,31 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto new_cache = tv->cacheAfter(); new_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(new_cache); + smem_cached_input_tvs.push_back(new_cache); } else { existing_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(existing_cache); + smem_cached_input_tvs.push_back(existing_cache); } } } + + bool use_smem_swizzle = !hasSmallTransposeDimensions(tparams); // set cached outputs of group 2 to shared memory for (const auto& [cached_output, output_idx] : cached_outputs) { auto output = fusion->outputs()[output_idx]->as(); if (group2_and_cached_inputs.count(output) > 0) { cached_output->setMemoryType(MemoryType::Shared); + // current smem swizzle only works for cached input + use_smem_swizzle = false; } } + // For non-square tile, can't create smem swizzle chunks if tile2 is larger + // and not vectorized + if (tparams->tile_size2 > tparams->tile_size1 && + tparams->vectorize_factor2 == 1) { + use_smem_swizzle = false; + } TensorView* reference1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]); @@ -1133,9 +1178,14 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // inputs and outputs themselves are disconnected, so we have to borrow the // entire DAG and use its spanning tree. { - auto all_tvs_except1 = ir_utils::allTvsExcept( - fusion, - {grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()}); + std::unordered_set except_tvs; + except_tvs.insert( + grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()); + if (use_smem_swizzle) { + except_tvs.insert( + smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + } + auto all_tvs_except1 = ir_utils::allTvsExcept(fusion, except_tvs); SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector); TransformPropagator propagator(reference2); @@ -1226,8 +1276,14 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // Propagate transformations, parallelization of the reference1 to the entire // DAG except group 2 and its corresponding cached outputs. { - auto all_tvs_except2 = - ir_utils::allTvsExcept(fusion, group2_and_cached_inputs); + std::unordered_set except_tvs; + except_tvs.insert( + group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()); + if (use_smem_swizzle) { + except_tvs.insert( + smem_cached_input_tvs.begin(), smem_cached_input_tvs.end()); + } + auto all_tvs_except2 = ir_utils::allTvsExcept(fusion, except_tvs); SetSelector selector({all_tvs_except2.begin(), all_tvs_except2.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except_outputs( reference1, &selector); @@ -1292,6 +1348,30 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { } } + // schedule smem_cached_input_tvs + if (use_smem_swizzle) { + for (auto tv : smem_cached_input_tvs) { + std::cout << "scheduling smem_cached_tv: " << tv->toString() << std::endl; + int64_t pos = tv->nDims() - 2; + bool is_group2 = group2_and_cached_inputs.count(tv) > 0; + int64_t tile2_factor = + is_group2 ? tparams->vectorize_factor2 : tparams->vectorize_factor1; + int64_t tile1_factor = + tparams->tile_size1 * tile2_factor / tparams->tile_size2; + // [BIDx, UnSwitch, tile1, tile2] + tv->split(pos + 1, tile2_factor); + tv->split(pos, tile1_factor); + tv->swizzle(SwizzleType::XOR, pos, pos + 2); + tv->merge(pos); + tv->merge(pos); + tv->split(pos, tparams->getThreadsPerBlock()); + tv->axis(pos)->parallelize(ParallelType::Unroll); + tv->axis(pos + 1)->parallelize(ParallelType::TIDx); + tv->axis(pos + 2)->parallelize(ParallelType::Vectorize); + std::cout << "scheduled smem_cached_tv: " << tv->toString() << std::endl; + } + } + //////////////////////////////// // Step 5: Cleanup and inline // //////////////////////////////// From ae27c289fdd9ca6aa2977b124d8785ee7dcf0392 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 17 Feb 2026 07:20:09 -0800 Subject: [PATCH 13/13] use pointwise scheduler --- tests/cpp/test_rng.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_rng.cpp b/tests/cpp/test_rng.cpp index ef0c7a92133..10130d35016 100644 --- a/tests/cpp/test_rng.cpp +++ b/tests/cpp/test_rng.cpp @@ -259,7 +259,7 @@ TEST_F(RNGTest, BroadcastingRNGSmem) { auto outputs = scheduleAndRun( - fusion, SchedulerType::Transpose, {input0, input1}, false) + fusion, SchedulerType::PointWise, {input0, input1}, false) .outputs; auto out = outputs[0].as();