Transpose scheduler: grow tile_size2 by bits in flight; swizzle smem cached inputs

Summary of what this patch changes:

* hasSmallTransposeDimensions() takes a non-owning `const TransposeParams*`
  instead of a `const std::unique_ptr<TransposeParams>&`. Callers that hold a
  unique_ptr pass `.get()`; scheduleTranspose() passes its raw pointer directly.
* getTransposeHeuristics() doubles `tile_size2` when
  `total_input_bits_per_elem * tile_size1 * tile_size2 * max_blocks_per_sm`
  is below `scheduler_utils::getRequiredBitsInFlight()`.
* scheduleTranspose() records the group-2 inputs' shared-memory caches in
  `smem_cached_input_tvs`. `use_smem_swizzle` is cleared when the params have
  small transpose dimensions, when a cached output is placed in shared memory,
  or when `tile_size2 > tile_size1` with `vectorize_factor2 == 1`. While it is
  set, those caches are excluded from both reference propagations and are
  scheduled separately with an XOR swizzle.
* RNGTest.BroadcastingRNGSmem is run with SchedulerType::PointWise.
  NOTE(review): confirm the test still exercises the shared-memory path its
  name refers to now that it no longer uses the Transpose scheduler.

diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp
index c2511ec5c2f..f31a65d23cd 100644
--- a/csrc/scheduler/transpose.cpp
+++ b/csrc/scheduler/transpose.cpp
@@ -190,8 +190,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator {
   bool should_reject = false;
 };
 
-bool hasSmallTransposeDimensions(
-    const std::unique_ptr<TransposeParams>& params) {
+bool hasSmallTransposeDimensions(const TransposeParams* params) {
   return !params->split_before_tiling.empty() ||
       !params->dims_merged_with_1.empty() ||
       !params->dims_merged_with_2.empty();
@@ -579,7 +578,7 @@ std::string getTransposeRuntimeRejectReason(
     // 1. view op; and
     // 2. small transpose transformation
     // See note [Supporting small transpose dimensions]
-    if (hasSmallTransposeDimensions(params)) {
+    if (hasSmallTransposeDimensions(params.get())) {
       return "Small transpose dimensions and view op cannot be currently be "
              "handled by transpose scheduler. See: "
              "https://github.com/NVIDIA/Fuser/pull/592";
@@ -677,7 +676,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
       inner_most_pos2_in_ref1);
 
   NVF_ERROR(
-      !hasSmallTransposeDimensions(tparams) ||
+      !hasSmallTransposeDimensions(tparams.get()) ||
           scheduler_utils::getViewTVs(fusion).empty(),
       "combination of view op with small transpose dimensions are not "
       "supported by transpose scheduler");
@@ -722,6 +721,33 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
   scan_max_dtype_size(fusion->inputs());
   scan_max_dtype_size(fusion->outputs());
 
+  // Double tile_size2 if the default configuration doesn't provide enough
+  // bytes in flight to saturate memory bandwidth. This is based on Little's
+  // law: bytes_in_flight = bandwidth * latency. We estimate the bits in flight
+  // per SM as: (sum of input tensor element sizes) * elements_per_tile *
+  // blocks_per_sm. If this is less than the required bits in flight (derived
+  // from hardware bandwidth and memory latency), we double tile_size2 to
+  // increase the data in flight. If tile1 is doubled, it will also double
+  // shared memory bank conflict, e.g. from 8-ways to 16 ways when increased
+  // from 32 to 64 assuming vectorization factor is 4, we need 8 or 16 threads
+  // loading per column.
+  const auto dev_prop = at::cuda::getCurrentDeviceProperties();
+  const int64_t max_blocks_per_sm = dev_prop->maxThreadsPerMultiProcessor /
+      TransposeParams::getMaxThreadsPerBlock();
+  const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2;
+  const int64_t required_bits_per_sm =
+      scheduler_utils::getRequiredBitsInFlight();
+  int64_t total_input_bits_per_elem = 0;
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
+    total_input_bits_per_elem +=
+        dataTypeSizeBit(tv->getDataType().value(), index_type);
+  }
+  const int64_t bits_in_flight_per_sm =
+      total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm;
+  if (bits_in_flight_per_sm < required_bits_per_sm) {
+    tparams->tile_size2 *= 2;
+  }
+
   auto max_unroll_factor = ceilDiv(
       // Available unrolling based on size of data type
       kSixteen / max_io_dtype_size,
@@ -834,7 +860,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
             << "reference2: " << reference2->toString() << "\n"
             << "inner_most_id2 position: " << inner_most_pos2_in_ref1
             << " (in reference 1)" << std::endl;
-    if (hasSmallTransposeDimensions(tparams)) {
+    if (hasSmallTransposeDimensions(tparams.get())) {
       debug() << "small transposed dim, needs virtual inner-most dim"
               << std::endl;
     }
@@ -912,6 +938,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
    */
   std::unordered_set<TensorView*> group2_and_cached_inputs(
       grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end());
+  std::vector<TensorView*> smem_cached_input_tvs;
   for (auto tv : grouped_inputs_outputs[1]) {
     if (tv->isFusionInput()) {
       auto existing_cache = ir_utils::consumerTvsOf(tv)[0];
@@ -919,19 +946,31 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
         auto new_cache = tv->cacheAfter();
         new_cache->setMemoryType(MemoryType::Shared);
         group2_and_cached_inputs.emplace(new_cache);
+        smem_cached_input_tvs.push_back(new_cache);
       } else {
         existing_cache->setMemoryType(MemoryType::Shared);
         group2_and_cached_inputs.emplace(existing_cache);
+        smem_cached_input_tvs.push_back(existing_cache);
       }
     }
   }
+
+  bool use_smem_swizzle = !hasSmallTransposeDimensions(tparams);
   // set cached outputs of group 2 to shared memory
   for (const auto& [cached_output, output_idx] : cached_outputs) {
     auto output = fusion->outputs()[output_idx]->as<TensorView>();
     if (group2_and_cached_inputs.count(output) > 0) {
       cached_output->setMemoryType(MemoryType::Shared);
+      // current smem swizzle only works for cached input
+      use_smem_swizzle = false;
     }
   }
+  // For non-square tile, can't create smem swizzle chunks if tile2 is larger
+  // and not vectorized
+  if (tparams->tile_size2 > tparams->tile_size1 &&
+      tparams->vectorize_factor2 == 1) {
+    use_smem_swizzle = false;
+  }
 
   TensorView* reference1 =
       domain_map.findReferenceFor(grouped_inputs_outputs[0]);
@@ -1133,9 +1172,14 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
   // inputs and outputs themselves are disconnected, so we have to borrow the
   // entire DAG and use its spanning tree.
   {
-    auto all_tvs_except1 = ir_utils::allTvsExcept(
-        fusion,
-        {grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()});
+    std::unordered_set<TensorView*> except_tvs;
+    except_tvs.insert(
+        grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end());
+    if (use_smem_swizzle) {
+      except_tvs.insert(
+          smem_cached_input_tvs.begin(), smem_cached_input_tvs.end());
+    }
+    auto all_tvs_except1 = ir_utils::allTvsExcept(fusion, except_tvs);
     SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()});
     MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector);
     TransformPropagator propagator(reference2);
@@ -1226,8 +1270,14 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
   // Propagate transformations, parallelization of the reference1 to the entire
   // DAG except group 2 and its corresponding cached outputs.
   {
-    auto all_tvs_except2 =
-        ir_utils::allTvsExcept(fusion, group2_and_cached_inputs);
+    std::unordered_set<TensorView*> except_tvs;
+    except_tvs.insert(
+        group2_and_cached_inputs.begin(), group2_and_cached_inputs.end());
+    if (use_smem_swizzle) {
+      except_tvs.insert(
+          smem_cached_input_tvs.begin(), smem_cached_input_tvs.end());
+    }
+    auto all_tvs_except2 = ir_utils::allTvsExcept(fusion, except_tvs);
     SetSelector selector({all_tvs_except2.begin(), all_tvs_except2.end()});
     MaxLogicalDomainInfoSpanningTree entire_dag_except_outputs(
         reference1, &selector);
@@ -1292,6 +1342,28 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
     }
   }
 
+  // schedule smem_cached_input_tvs
+  if (use_smem_swizzle) {
+    for (auto tv : smem_cached_input_tvs) {
+      int64_t pos = tv->nDims() - 2;
+      bool is_group2 = group2_and_cached_inputs.count(tv) > 0;
+      int64_t tile2_factor =
+          is_group2 ? tparams->vectorize_factor2 : tparams->vectorize_factor1;
+      int64_t tile1_factor =
+          tparams->tile_size1 * tile2_factor / tparams->tile_size2;
+      // [BIDx, UnSwitch, tile1, tile2]
+      tv->split(pos + 1, tile2_factor);
+      tv->split(pos, tile1_factor);
+      tv->swizzle(SwizzleType::XOR, pos, pos + 2);
+      tv->merge(pos);
+      tv->merge(pos);
+      tv->split(pos, tparams->getThreadsPerBlock());
+      tv->axis(pos)->parallelize(ParallelType::Unroll);
+      tv->axis(pos + 1)->parallelize(ParallelType::TIDx);
+      tv->axis(pos + 2)->parallelize(ParallelType::Vectorize);
+    }
+  }
+
   ////////////////////////////////
   // Step 5: Cleanup and inline //
   ////////////////////////////////
diff --git a/tests/cpp/test_rng.cpp b/tests/cpp/test_rng.cpp
index ef0c7a92133..10130d35016 100644
--- a/tests/cpp/test_rng.cpp
+++ b/tests/cpp/test_rng.cpp
@@ -259,7 +259,7 @@ TEST_F(RNGTest, BroadcastingRNGSmem) {
 
   auto outputs =
       scheduleAndRun(
-          fusion, SchedulerType::Transpose, {input0, input1}, false)
+          fusion, SchedulerType::PointWise, {input0, input1}, false)
           .outputs;
   auto out = outputs[0].as<at::Tensor>();
 