From dbdb3eb591d6297ef5da269886d3c8e266e1fd03 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Fri, 28 Nov 2025 04:23:11 -0800 Subject: [PATCH 1/2] [xla:gpu] Fix clangtidy warnings (except `misc-include-cleaner`). PiperOrigin-RevId: 837800257 --- .../gpu/autotuner/block_level_emitter.h | 2 +- .../gpu/codegen/emitters/reduction.cc | 4 +- .../backends/gpu/codegen/emitters/scatter.cc | 4 +- .../emitters/transforms/convert_float_amd.cc | 3 +- .../xla/backends/gpu/codegen/fusion_emitter.h | 2 +- .../gpu/codegen/tools/gpu_test_correctness.cc | 12 +++- .../xla/xla/backends/gpu/codegen/triton/BUILD | 1 + .../triton/compilation_pipeline_test.cc | 12 ++-- .../gpu/codegen/triton/dot_algorithms.cc | 31 +++++----- .../gpu/codegen/triton/emitter_helpers.cc | 17 ++---- .../gpu/codegen/triton/ir/triton_xla_attrs.cc | 4 +- .../gpu/codegen/triton/support_legacy.cc | 12 ++-- .../gpu/codegen/triton/support_legacy_test.cc | 5 +- .../backends/gpu/codegen/triton/test_utils.cc | 13 ++--- .../transforms/generalize_kernel_signature.cc | 4 +- .../round_f32_to_tf32_for_tf32_dot_pass.cc | 16 ++++-- .../triton_xla_lower_block_barrier_pass.cc | 4 +- .../backends/gpu/collectives/gpu_clique.cc | 8 ++- .../gpu/collectives/gpu_clique_key.cc | 7 ++- .../backends/gpu/collectives/gpu_cliques.cc | 12 +++- .../gpu/collectives/nccl_collectives.cc | 4 +- .../gpu/collectives/nvshmem_collectives.cc | 4 +- .../backends/gpu/runtime/buffer_comparator.cc | 4 +- .../gpu/runtime/collective_metadata_thunk.cc | 2 +- .../gpu/runtime/collective_permute_thunk.cc | 38 +++++++------ .../gpu/runtime/command_buffer_cmd.cc | 57 +++++++++++-------- .../gpu/runtime/command_buffer_cmd_test.cc | 2 +- .../gpu/runtime/dynamic_slice_thunk.cc | 4 +- .../gpu/runtime/dynamic_slice_thunk.h | 3 +- .../gpu/runtime/host_execute_thunk.cc | 4 +- .../gpu/runtime/host_send_recv_thunk.cc | 35 +++++++++--- .../gpu/runtime/make_batch_pointers.cc | 3 +- .../backends/gpu/runtime/p2p_thunk_common.cc | 4 +- .../backends/gpu/runtime/p2p_thunk_common.h | 4 +- .../xla/backends/gpu/runtime/while_thunk.cc | 2 +- 35 files changed, 208 insertions(+), 135 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/autotuner/block_level_emitter.h b/third_party/xla/xla/backends/gpu/autotuner/block_level_emitter.h index d3a3845de79efd..4a36a827ba4767 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/block_level_emitter.h +++ b/third_party/xla/xla/backends/gpu/autotuner/block_level_emitter.h @@ -67,7 +67,7 @@ class BlockLevelEmitterBackend : public GpuCodegenBackend { const BackendConfig& config) override; // Determines whether the given HLO instruction is supported by this backend. - bool IsSupported(const HloInstruction& instr); + bool IsSupported(const HloInstruction& instr) override; // We don't want to use the Triton emitter as a reference because it can // produce wrong results. diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/reduction.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/reduction.cc index 192292d8f81964..19936a63675ce8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/reduction.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/reduction.cc @@ -587,7 +587,9 @@ SmallVector ReductionFusion::EvaluateEpilogue( EmitterState& state, int group_id, ValueRange symbol_values) const { ImplicitLocOpBuilder& b = state.builder; const auto& epilogue = state.computations.epilogues()[group_id]; - if (epilogue.roots.empty()) return outputs; + if (epilogue.roots.empty()) { + return outputs; + } auto epilogue_input_indices = state.thread_and_block_ids; epilogue_input_indices.append(symbol_values.begin(), symbol_values.end()); diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc index 0b43566026296b..7a2ba02dfda734 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc @@ -208,7 +208,9 @@ SmallVector Unpack(ValueRange range, ArrayRef sizes) { SmallVector PadWithZeros(ValueRange values, int64_t size, ImplicitLocOpBuilder& b) { SmallVector padded_values(values.begin(), values.end()); - if (values.size() >= size) return padded_values; + if (values.size() >= size) { + return padded_values; + } auto zero = arith::ConstantIndexOp::create(b, 0); for (int i = values.size(); i < size; ++i) { padded_values.push_back(zero); diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc index d8f239b8f6d340..8dead6470b4763 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc @@ -221,9 +221,8 @@ struct RewriteFp8TruncFPattern : public Fp8OpRewritePattern { } if (v.getType() != f32_ty) { return arith::TruncFOp::create(b, f32_ty, v); - } else { - return v; } + return v; }); mlir::StringAttr cvtIntr = b.getStringAttr( diff --git a/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.h b/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.h index df0b3257b09a2d..27081f974b6d22 100644 --- a/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.h +++ b/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.h @@ -61,7 +61,7 @@ class FusionInterface { // Interface for fusions that are implemented using cuda kernels. class KernelFusionInterface : public FusionInterface { public: - virtual ~KernelFusionInterface() = default; + ~KernelFusionInterface() override = default; // Returns the fusion's launch dimensions. virtual LaunchDimensions launch_dimensions() const = 0; diff --git a/third_party/xla/xla/backends/gpu/codegen/tools/gpu_test_correctness.cc b/third_party/xla/xla/backends/gpu/codegen/tools/gpu_test_correctness.cc index 02ab435155c921..edc204a7ef8833 100644 --- a/third_party/xla/xla/backends/gpu/codegen/tools/gpu_test_correctness.cc +++ b/third_party/xla/xla/backends/gpu/codegen/tools/gpu_test_correctness.cc @@ -72,7 +72,9 @@ absl::Status TestBijection(const IndexingMap& map, intervals.push_back({0, size - 1}); } auto status = VerifyBijection(map, intervals); - if (status.ok()) return status; + if (status.ok()) { + return status; + } return absl::FailedPreconditionError( absl::StrCat(status.message(), " in map ", ToString(map))); } @@ -162,7 +164,9 @@ int main(int argc, char* argv[]) { tsl::Flag( "bijection_inputs", [](std::string name_and_ids) { - if (name_and_ids.empty()) return false; + if (name_and_ids.empty()) { + return false; + } flags.bijection_inputs.push_back( xla::gpu::ParseHeroAndIds(name_and_ids)); return true; @@ -174,7 +178,9 @@ int main(int argc, char* argv[]) { tsl::Flag( "bijection_outputs", [](std::string name) { - if (name.empty()) return false; + if (name.empty()) { + return false; + } flags.bijection_outputs.push_back(name); return true; }, diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index e0f4ebb023c00c..0a343917ff0cba 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -234,6 +234,7 @@ xla_cc_test( deps = [ ":compilation_pipeline", "//xla/stream_executor/cuda:cuda_compute_capability", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_test.cc index 2c9272ce02b4f7..69d5d43dcb20e6 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_join.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" @@ -48,12 +49,11 @@ TEST(CompilationPipelineTest, UnswitchLoopsAfterLICM) { } ASSERT_THAT(pass_names, Contains("LoopInvariantCodeMotion")); ASSERT_THAT(pass_names, Contains("TritonXLAUnswitchLoopsPass")); - int licm_index = std::distance(pass_names.begin(), - std::find(pass_names.begin(), pass_names.end(), - "LoopInvariantCodeMotion")); - int unswitch_index = std::distance( - pass_names.begin(), std::find(pass_names.begin(), pass_names.end(), - "TritonXLAUnswitchLoopsPass")); + int licm_index = std::distance( + pass_names.begin(), absl::c_find(pass_names, "LoopInvariantCodeMotion")); + int unswitch_index = + std::distance(pass_names.begin(), + absl::c_find(pass_names, "TritonXLAUnswitchLoopsPass")); // There is no hard requirement to run LICM **immediately** before the loop // unswitcher but you should consider if the newly added pass might interact // with the loop unswitcher. diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc index 3a10731c685166..a545d8979b9b5a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc @@ -229,23 +229,20 @@ absl::StatusOr> GetForceOperandsType( // If there is a single allowed operand type, we force the operands to use // this type. return allowed_operands_types.front(); - - } else { - // If there are several allowed operand types, we just check that the - // operands have the same type, and that this type is one of the allowed - // ones. Raise an error otherwise. - if (lhs_type != rhs_type || - !absl::c_linear_search(allowed_operands_types, lhs_type)) { - std::string allowed_operands_types_str = absl::StrJoin( - allowed_operands_types, ", ", [&](std::string* out, Type type) { - absl::StrAppend(out, MlirToString(type)); - }); - return absl::FailedPreconditionError(absl::StrCat( - "Expected dot operands to both have the same type, and for this type " - "to be one of the following types: ", - allowed_operands_types_str, " but got ", MlirToString(lhs_type), - " and ", MlirToString(rhs_type))); - } + } // If there are several allowed operand types, we just check that the + // operands have the same type, and that this type is one of the allowed + // ones. Raise an error otherwise. + if (lhs_type != rhs_type || + !absl::c_linear_search(allowed_operands_types, lhs_type)) { + std::string allowed_operands_types_str = absl::StrJoin( + allowed_operands_types, ", ", [&](std::string* out, Type type) { + absl::StrAppend(out, MlirToString(type)); + }); + return absl::FailedPreconditionError(absl::StrCat( + "Expected dot operands to both have the same type, and for this type " + "to be one of the following types: ", + allowed_operands_types_str, " but got ", MlirToString(lhs_type), + " and ", MlirToString(rhs_type))); } return std::nullopt; diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc index 30dd0f0dcc157d..0335f7de89549c 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc @@ -281,9 +281,8 @@ Value Cast(mlir::ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { if (src_fp_element_ty.getFPMantissaWidth() > dst_fp_element_ty.getFPMantissaWidth()) { return ma::TruncFOp::create(b, dst_ty, value); - } else { - return ma::ExtFOp::create(b, dst_ty, value); } + return ma::ExtFOp::create(b, dst_ty, value); } // int => int if (mlir::isa(src_element_ty) && @@ -321,16 +320,14 @@ Value Cast(mlir::ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { auto cst_int = [&](int64_t x) -> Value { if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape()); - } else { - return CreateConst(b, dst_element_ty, x); } + return CreateConst(b, dst_element_ty, x); }; auto cst_float = [&](int64_t x) -> Value { if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape()); - } else { - return CreateConst(b, src_fp_element_ty, x); } + return CreateConst(b, src_fp_element_ty, x); }; auto fptosi = ma::FPToSIOp::create(b, dst_ty, value); int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth()); @@ -358,9 +355,8 @@ Value Cast(mlir::ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { Value Subtract(mlir::ImplicitLocOpBuilder& b, ValueRange values) { if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { return ma::SubIOp::create(b, values[0], values[1]); - } else { - return ma::SubFOp::create(b, values[0], values[1]); } + return ma::SubFOp::create(b, values[0], values[1]); } Value Compare(mlir::ImplicitLocOpBuilder& b, ValueRange values, @@ -560,10 +556,9 @@ absl::StatusOr> EmitConstant( if (constant.shape().element_type() == U64) { return CreateConst(b, ty, ScalarConstantValue(constant, U64), shape); - } else { - return CreateConst(b, ty, ScalarConstantValue(constant, S64), - shape); } + return CreateConst(b, ty, ScalarConstantValue(constant, S64), + shape); } return CreateConst(b, ty, ScalarConstantValue(constant, F64), shape); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc index 1f1a00741bf05d..22ac2726d560e7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc @@ -31,7 +31,9 @@ static mlir::ParseResult parseI64ArrayAttr(mlir::AsmParser& parser, mlir::DenseI64ArrayAttr& array) { array = mlir::dyn_cast_or_null( mlir::DenseI64ArrayAttr::parse(parser, mlir::Type{})); - if (!array) return mlir::failure(); + if (!array) { + return mlir::failure(); + } return mlir::success(); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc index 07e88ba8c458e8..8f40ccb743c02f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc @@ -119,7 +119,9 @@ CodegenDecision IsInstructionSupportsDataTypes( const auto operand_type = operand->shape().element_type(); switch (instr.opcode()) { case HloOpcode::kConvert: - if (operand_type == S4) continue; + if (operand_type == S4) { + continue; + } [[fallthrough]]; default: if (!IsTritonSupportedDataType(operand_type, gpu_version)) { @@ -206,8 +208,9 @@ CodegenDecision CanTritonHandleElementwise( } if (instr.opcode() == HloOpcode::kConstant) { return CodegenDecision::Allow(); - } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( - instr.opcode(), instr.operand(0)->shape().element_type())) { + } + if (!IsTritonSupportedElementwiseUpToFloatNormalization( + instr.opcode(), instr.operand(0)->shape().element_type())) { return CodegenDecision::Forbid("Unsupported elementwise operation."); } return CodegenDecision::Allow(); @@ -360,7 +363,8 @@ CodegenDecision IsTritonSupportedDynamicSlice( for (int i = 0; i < input->shape().dimensions().size(); ++i) { if (i == majormost_dim_id) { continue; - } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { + } + if (input->shape().dimensions(i) != instr.slice_sizes(i)) { return CodegenDecision::Forbid( "Unsupported dynamic slice on non-major-most dimension."); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy_test.cc index 5deccb9b8ba61e..379cb0b660f712 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy_test.cc @@ -289,10 +289,9 @@ ENTRY e { ApplyFloatNormalization(dot.Module().get(), GetComputeCapability())); EXPECT_TRUE(RunAndCompareNoHloPasses( std::move(dot.Module()), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4})); - } else { - EXPECT_THAT(TritonFusionAnalysis::Execute(dot.TritonComputation()), - absl_testing::StatusIs(absl::StatusCode::kFailedPrecondition)); } + EXPECT_THAT(TritonFusionAnalysis::Execute(dot.TritonComputation()), + absl_testing::StatusIs(absl::StatusCode::kFailedPrecondition)); } INSTANTIATE_TEST_SUITE_P( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc index 6673811c7c6b31..c0d18094dd6b70 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc @@ -89,7 +89,8 @@ bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) { if (cc.IsCuda()) { return cc.cuda_compute_capability()->IsAtLeast( se::CudaComputeCapability::kAmpere); - } else if (cc.IsRocm()) { + } + if (cc.IsRocm()) { return cc.rocm_compute_capability()->has_bf16_dtype_support(); } CHECK(false); @@ -248,10 +249,9 @@ std::string ComputeCapabilityToString( const stream_executor::GpuComputeCapability& cc) { if (auto* cuda_cc = cc.cuda_compute_capability()) { return absl::StrReplaceAll(cuda_cc->ToString(), {{".", ""}}); - } else { - CHECK(cc.IsRocm()); - return "rocm"; } + CHECK(cc.IsRocm()); + return "rocm"; } std::string TritonSupportTestTypeAndDeviceToString( @@ -329,10 +329,9 @@ absl::Status ConvertEntryToTritonFusion(HloModule* module, gpu::GpuBackendConfig gpu_config; if (use_nested_gemm_fusions) { gpu_config.mutable_fusion_backend_config()->set_kind( - std::string(kTritonNestedGemmFusionKind)); + kTritonNestedGemmFusionKind); } else { - gpu_config.mutable_fusion_backend_config()->set_kind( - std::string(kTritonFusionKind)); + gpu_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind); } TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/generalize_kernel_signature.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/generalize_kernel_signature.cc index d4138468bfcdea..0a19aa0f294fe8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/generalize_kernel_signature.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/generalize_kernel_signature.cc @@ -60,7 +60,9 @@ void StripParameterAddressSpaces(RewriterBase& rewriter, SmallVector generic_func_params( llvm::map_range(func_ty.getParams(), [](Type type) -> Type { auto ptr_ty = dyn_cast(type); - if (!ptr_ty) return type; + if (!ptr_ty) { + return type; + } if (ptr_ty.getAddressSpace() != NVVM::NVVMMemorySpace::Global) { return type; } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/round_f32_to_tf32_for_tf32_dot_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/round_f32_to_tf32_for_tf32_dot_pass.cc index 00c791c02a4dc8..c47e19b6286b97 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/round_f32_to_tf32_for_tf32_dot_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/round_f32_to_tf32_for_tf32_dot_pass.cc @@ -50,10 +50,18 @@ class Tf32DotPattern : public OpRewritePattern { mlir::LogicalResult matchAndRewrite( mt::DotOp op, PatternRewriter &rewriter) const override { constexpr auto tf32_args_rounded = "tf32_arguments_rounded"; - if (op.getInputPrecision() != mt::InputPrecision::TF32) return failure(); - if (!op.getA().getType().getElementType().isF32()) return failure(); - if (!op.getB().getType().getElementType().isF32()) return failure(); - if (op->hasAttr(tf32_args_rounded)) return failure(); + if (op.getInputPrecision() != mt::InputPrecision::TF32) { + return failure(); + } + if (!op.getA().getType().getElementType().isF32()) { + return failure(); + } + if (!op.getB().getType().getElementType().isF32()) { + return failure(); + } + if (op->hasAttr(tf32_args_rounded)) { + return failure(); + } auto f32ToTF32 = [&](Value value) -> Value { return ElementwiseInlineAsmOp::create( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_block_barrier_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_block_barrier_pass.cc index 5d83980a13e349..c095f12a8eb2af 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_block_barrier_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_block_barrier_pass.cc @@ -136,7 +136,7 @@ LogicalResult LowerBlockBarrierOp(BlockBarrierOp block_barrier, // Signal all ranks on the same block id. mlir::triton::xla::AtomicWriteOp::create( builder, - /*result_types=*/mlir::TypeRange{}, + /*resultTypes=*/mlir::TypeRange{}, /*ptr=*/signal_addresses, /*signal_value=*/signal_value, /*mask=*/mlir::Value{}, @@ -173,7 +173,7 @@ LogicalResult LowerBlockBarrierOp(BlockBarrierOp block_barrier, // Wait for all ranks on the same block id to signal. mlir::triton::xla::AtomicSpinWaitOp::create( builder, - /*result_types=*/mlir::TypeRange{}, + /*resultTypes=*/mlir::TypeRange{}, /*ptr=*/wait_addresses, /*expected=*/signal_value, /*mask=*/mlir::Value{}, diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc index 01b31b2b844b28..896ef40f61d0f6 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc @@ -52,7 +52,9 @@ std::string GpuClique::DebugString() const { num_communicators()); int32_t cnt = 0; ForEachComm([&](RankId rank, Communicator* comm) { - if (cnt++) absl::StrAppend(&out, ", "); + if (cnt++) { + absl::StrAppend(&out, ", "); + } absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank.value(), comm); }); return out; @@ -63,7 +65,9 @@ absl::Status GpuClique::HealthCheck() const { ForEachComm([&health_check](RankId rank, Communicator* comm) { if (auto s = comm->HealthCheck(); !s.ok()) { LOG(ERROR) << "GPU communicator error (rank " << rank << "): " << s; - if (health_check.ok()) health_check = std::move(s); // return first error + if (health_check.ok()) { + health_check = std::move(s); // return first error + } } }); return health_check; diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc index 6d12361b56047f..cd9fdb0f4f225a 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc @@ -43,10 +43,13 @@ bool IsP2PStreamKind(AsyncStreamKind stream_kind) { CollectiveStreamId GetCollectiveStreamId(bool is_async, CollectiveStreamId stream_id, AsyncStreamKind stream_kind) { - if (!is_async) return CollectiveStreamId(0); + if (!is_async) { + return CollectiveStreamId(0); + } // TODO: Remove this fallback once AsyncStreamId is used everywhere. - if (stream_id.value() == 0) + if (stream_id.value() == 0) { return CollectiveStreamId(static_cast(stream_kind) + 1); + } return stream_id; } diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc index f78d12490c7d28..c6a48b0ab219ac 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc @@ -156,7 +156,9 @@ static void CheckClique(const GpuCliqueKey& clique_key, << " for async errors; num_communicators=" << clique->num_communicators(); clique->ForEachComm([](RankId rank, Communicator* comm) { - if (auto status = CheckComm(comm); !status.ok()) LOG(ERROR) << status; + if (auto status = CheckComm(comm); !status.ok()) { + LOG(ERROR) << status; + } }); } else { VLOG(5) << "Skip checking in-use GPU clique " << clique_key.ToString(); @@ -225,7 +227,9 @@ static absl::StatusOr EnablePeerAccess( for (int64_t i = 0; i < devices.size(); ++i) { for (int64_t j = 0; j < devices.size(); ++j) { // An attempt to enable peer access to itself will fail. - if (i == j) continue; + if (i == j) { + continue; + } // To check if peer access is possible, we need to enable it and check // the result. OkStatus means that peer access is possible. @@ -692,7 +696,9 @@ absl::StatusOr> AcquireGpuClique( WarnStuckTimeout(), TerminateTimeout())); // If lock is not null return it to the caller. - if (*clique) return clique; + if (*clique) { + return clique; + } // Maybe find if we acquired a clique with communicators that we can split. static const int64_t enable_nccl_comm_splitting = diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc index 2f900d8d119f17..c7d584957ed791 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc @@ -95,7 +95,9 @@ bool NcclCollectives::IsGlobalConfig() const { absl::StatusOr NcclCollectives::GetCliqueIdCallback(const CliqueIdCallback* clique_id_callback, bool is_local) { - if (clique_id_callback != nullptr) return clique_id_callback; + if (clique_id_callback != nullptr) { + return clique_id_callback; + } TF_RET_CHECK(is_local || IsGlobalConfig()) << "If non-local devices are taking part of a collective API on " diff --git a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc index 8b2f42edb6548c..215f42b89fab0e 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc @@ -42,7 +42,9 @@ limitations under the License. namespace xla::gpu { NvshmemCollectives::~NvshmemCollectives() { - if (initialized_) Finalize(); + if (initialized_) { + Finalize(); + } } NvshmemCollectives* NvshmemCollectives::Default() { diff --git a/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc index 8f06af032c53b7..9a61904ebf00b5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc @@ -147,7 +147,9 @@ static absl::StatusOr HostCompare(const ComparisonParams& params) { std::abs(expected_value_canonical)) + 1) < params.relative_tol)) { - if (!params.verbose) return false; // Return immediately if not verbose. + if (!params.verbose) { + return false; // Return immediately if not verbose. + } ++differences_seen; LOG(ERROR) << "Difference at " << i << ": " << current_value << ", expected " << expected_value; diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc index 79948fe7e2800b..314b469cd70d99 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc @@ -127,7 +127,7 @@ absl::Status CollectiveMetadataThunk::ConstructCollectiveMetadata( clique_key.ToString())); } metadata.multicast_buffer_ptr = multimem_address_space; - TF_RET_CHECK(rendezvous_values->size() > 0) + TF_RET_CHECK(!rendezvous_values->empty()) << "Not enough devices in the clique."; const size_t num_parameters = (*rendezvous_values)[0].parameters.size(); for (const auto& value : *rendezvous_values) { diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc index efe521872692c2..a7ad06547e7a89 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc @@ -72,8 +72,12 @@ bool IsLocalPeerTransfer(const P2PConfig::SourceTargetMapEntry& source_target, // We determine if it's a local peer if the source/target id is within a node // if they are present. int64_t host_id = (current_id / device_count); - if (source_id && host_id != *source_id / device_count) return false; - if (target_id && host_id != *target_id / device_count) return false; + if (source_id && host_id != *source_id / device_count) { + return false; + } + if (target_id && host_id != *target_id / device_count) { + return false; + } return true; } @@ -197,15 +201,13 @@ absl::Status CollectivePermuteStartThunk::Initialize( if (source_id) { std::vector dest_addrs; - std::transform(device_buffers.begin(), device_buffers.end(), - std::back_inserter(dest_addrs), - [](const DeviceBufferPair& buffer) { - return buffer.destination_buffer; - }); + absl::c_transform(device_buffers, std::back_inserter(dest_addrs), + [](const DeviceBufferPair& buffer) { + return buffer.destination_buffer; + }); std::vector dest_opaques; - std::transform( - dest_addrs.begin(), dest_addrs.end(), - std::back_inserter(dest_opaques), + absl::c_transform( + dest_addrs, std::back_inserter(dest_opaques), [](se::DeviceMemoryBase dest_addr) { return dest_addr.opaque(); }); TF_RETURN_IF_ERROR(recv_ptr_map_.PutRecvPtr(current_id, dest_opaques)); } @@ -368,11 +370,11 @@ absl::Status RunCollectivePermute( std::optional target_id = source_target.target; std::vector src_addrs, dest_addrs; - std::transform( - buffers.begin(), buffers.end(), std::back_inserter(src_addrs), + absl::c_transform( + buffers, std::back_inserter(src_addrs), [](const DeviceBufferPair& buffer) { return buffer.source_buffer; }); - std::transform( - buffers.begin(), buffers.end(), std::back_inserter(dest_addrs), + absl::c_transform( + buffers, std::back_inserter(dest_addrs), [](const DeviceBufferPair& buffer) { return buffer.destination_buffer; }); VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d", @@ -386,8 +388,12 @@ absl::Status RunCollectivePermute( std::optional source_rank; std::vector target_ranks; - if (source_id) source_rank = RankId(*source_id); - if (target_id) target_ranks.push_back(RankId(*target_id)); + if (source_id) { + source_rank = RankId(*source_id); + } + if (target_id) { + target_ranks.push_back(RankId(*target_id)); + } if (!is_nccl_group_needed) { for (uint64_t idx = 0; idx < buffers.size(); ++idx) { diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 39d4af6011a36f..07b60b5bacfeb6 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -518,12 +518,24 @@ absl::Status CommandBufferCmdExecutor::Record( } } - if (has_input && !has_output && !has_temp) input_count++; - if (!has_input && has_output && !has_temp) output_count++; - if (has_input && !has_output && has_temp) input_temp_count++; - if (!has_input && has_output && has_temp) output_temp_count++; - if (has_input && has_output && !has_temp) input_output_count++; - if (has_input && has_output && has_temp) input_temp_output_count++; + if (has_input && !has_output && !has_temp) { + input_count++; + } + if (!has_input && has_output && !has_temp) { + output_count++; + } + if (has_input && !has_output && has_temp) { + input_temp_count++; + } + if (!has_input && has_output && has_temp) { + output_temp_count++; + } + if (has_input && has_output && !has_temp) { + input_output_count++; + } + if (has_input && has_output && has_temp) { + input_temp_output_count++; + } } VLOG(5) << "CommandBufferCmdExecutor allocation summary:\n" @@ -1620,23 +1632,22 @@ absl::StatusOr WhileCmd::Record( return command_buffer->UpdateChildCommand( se::CommandBuffer::ChildCommandType::kMoved, command, record_fn); }); - } else { - return Handle( - std::move(record_action), - [&](absl::Span dependencies) { - return command_buffer->CreateWhile( - se::DeviceMemory(pred), - CreateCommands(&cond_commands_, &execute_params, &record_params), - CreateCommands(&body_commands_, &execute_params, &record_params), - dependencies); - }, - [&](const se::CommandBuffer::Command* command) { - return command_buffer->UpdateWhile( - command, se::DeviceMemory(pred), - UpdateCommands(&cond_commands_, &execute_params, &record_params), - UpdateCommands(&body_commands_, &execute_params, &record_params)); - }); } + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->CreateWhile( + se::DeviceMemory(pred), + CreateCommands(&cond_commands_, &execute_params, &record_params), + CreateCommands(&body_commands_, &execute_params, &record_params), + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->UpdateWhile( + command, se::DeviceMemory(pred), + UpdateCommands(&cond_commands_, &execute_params, &record_params), + UpdateCommands(&body_commands_, &execute_params, &record_params)); + }); } bool WhileCmd::requires_initialization() { @@ -2816,7 +2827,7 @@ absl::StatusOr DynamicSliceFusionCmd::Record( CommandBufferCmd::BufferUseVector DynamicSliceFusionCmd::buffers() const { CommandBufferCmd::BufferUseVector buffers; auto embed_buffers = embedded_commands_.buffers(); - for (auto buffer_usage : embed_buffers) { + for (const auto& buffer_usage : embed_buffers) { buffers.emplace_back( *embeded_to_origin_slice_map_.at(buffer_usage.slice().index()), buffer_usage.access()); diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index ff9a89276ee6c5..19f4087fec88be 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -89,7 +89,7 @@ static constexpr auto serialize = // buffer cmd commands. We never execute this command, we need it only to pass // buffer usage vector to the command buffer cmd commands. struct TestOnlyCommandBufferCmd : public CommandBufferCmd { - TestOnlyCommandBufferCmd(BufferUseVector buffer_usage) + explicit TestOnlyCommandBufferCmd(BufferUseVector buffer_usage) : CommandBufferCmd(CommandBufferCmdType::kUnknownCmd, {}), buffer_usage(buffer_usage) {} diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc index 7a28e90821371c..4fd5ac8eca55c4 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc @@ -258,7 +258,9 @@ absl::Status DynamicSliceThunk::Initialize(const InitializeParams& params) { TF_RETURN_IF_ERROR(embedded_thunk_->Initialize(params)); absl::MutexLock lock(mutex_); - if (offsets_allocs_.contains(params.executor)) return absl::OkStatus(); + if (offsets_allocs_.contains(params.executor)) { + return absl::OkStatus(); + } VLOG(2) << "Allocate " << offsets_allocs_size_ << " bytes for transferring offsets on executor: " << params.executor; diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h index c41bc02853cb2b..0724d3a10b7acc 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h @@ -191,9 +191,8 @@ class DynamicSliceThunk : public Thunk { get_offset_function() const { if (offset_as_function_of_indvar_metadata_.has_value()) { return &offset_as_function_of_indvar_metadata_.value(); - } else { - return std::nullopt; } + return std::nullopt; } private: diff --git a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc index 264c0dec22538c..8ff8be373520fb 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc @@ -383,9 +383,9 @@ HostExecuteStartThunk::Create( const HostOffloadingExecutableProto& host_offloading_executable_proto, absl::InlinedVector args, absl::InlinedVector results) { - auto thunk = absl::WrapUnique(new HostExecuteStartThunk( + auto thunk = std::make_unique( std::move(thunk_info), host_offloading_executable_proto, std::move(args), - std::move(results))); + std::move(results)); if (host_offloading_executable_proto.has_aot_compilation_result()) { TF_RETURN_IF_ERROR(thunk->LoadExecutable()); } diff --git a/third_party/xla/xla/backends/gpu/runtime/host_send_recv_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/host_send_recv_thunk.cc index cce2f2795712a1..8ad8e3093b823b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_send_recv_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/host_send_recv_thunk.cc @@ -55,7 +55,9 @@ using tsl::profiler::TraceMeEncode; static absl::StatusOr ShouldSkip( absl::string_view operation, const Thunk::ExecuteParams& params, const std::optional& device_constraint) { - if (!device_constraint.has_value()) return false; + if (!device_constraint.has_value()) { + return false; + } GlobalDeviceId global_device_id = params.collective_params->global_device_id; bool skip = global_device_id != *device_constraint; @@ -77,8 +79,9 @@ absl::Status HostSendRecvAsyncEvents::Emplace( Key key = {executor, channel_id}; absl::MutexLock lock(mutex_); - if (auto it = events_.try_emplace(key, std::move(event)); it.second) + if (auto it = events_.try_emplace(key, std::move(event)); it.second) { return absl::OkStatus(); + } return absl::InternalError(absl::StrFormat( "Async send/recv event already exists (channel_id=%d)", channel_id)); @@ -90,7 +93,9 @@ HostSendRecvAsyncEvents::Extract(se::StreamExecutor* executor, Key key = {executor, channel_id}; absl::MutexLock lock(mutex_); - if (auto event = events_.extract(key)) return std::move(event.mapped()); + if (auto event = events_.extract(key)) { + return std::move(event.mapped()); + } return absl::InternalError(absl::StrFormat( "Async send/recv event was not found (channel_id==%d)", channel_id)); @@ -166,7 +171,9 @@ absl::Status HostSendThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN(bool skip, ShouldSkip("sending buffer", params, device_constraint_)); - if (skip) return absl::OkStatus(); + if (skip) { + return absl::OkStatus(); + } TraceMe trace( [&] { return TraceMeEncode("Send", {{"channel_id", channel_id_}}); }); @@ -259,7 +266,9 @@ absl::Status HostSendDoneThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN(bool skip, ShouldSkip("waiting for send completion", params, device_constraint_)); - if (skip) return absl::OkStatus(); + if (skip) { + return absl::OkStatus(); + } TraceMe trace( [&] { return TraceMeEncode("SendDone", {{"channel_id", channel_id_}}); }); @@ -269,7 +278,9 @@ absl::Status HostSendDoneThunk::ExecuteOnStream(const ExecuteParams& params) { // Wait until send handler will record an event on the stream. BlockUntilReady(done_event.GetAsyncValue()); - if (done_event.IsError()) return done_event.GetError(); + if (done_event.IsError()) { + return done_event.GetError(); + } VLOG(5) << "Completed Send operation: channel_id=" << channel_id_; @@ -356,7 +367,9 @@ absl::Status HostRecvThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN( bool skip, ShouldSkip("receiving buffer", params, device_constraint_)); - if (skip) return absl::OkStatus(); + if (skip) { + return absl::OkStatus(); + } TraceMe trace( [&] { return TraceMeEncode("Recv", {{"channel_id", channel_id_}}); }); @@ -449,7 +462,9 @@ absl::Status HostRecvDoneThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN(bool skip, ShouldSkip("waiting for recv completion", params, device_constraint_)); - if (skip) return absl::OkStatus(); + if (skip) { + return absl::OkStatus(); + } TraceMe trace( [&] { return TraceMeEncode("RecvDone", {{"channel_id", channel_id_}}); }); @@ -459,7 +474,9 @@ absl::Status HostRecvDoneThunk::ExecuteOnStream(const ExecuteParams& params) { // Wait until send handler will record an event on the stream. BlockUntilReady(done_event.GetAsyncValue()); - if (done_event.IsError()) return done_event.GetError(); + if (done_event.IsError()) { + return done_event.GetError(); + } VLOG(5) << "Completed Recv operation: channel=" << channel_id_; diff --git a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc index bff8fe2bb77c48..477cea2b60828a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc +++ b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc @@ -39,9 +39,8 @@ absl::Status MakeBatchPointers(se::Stream* stream, if (executor->GetPlatform()->id() == stream_executor::rocm::kROCmPlatformId) { return 256; - } else { - return 128; } + return 128; }(); TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc b/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc index a4ea9df511a284..21f250136ca36f 100644 --- a/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc +++ b/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc @@ -49,7 +49,9 @@ absl::Status ExecutionCounters::Initialize(se::StreamExecutor* executor, RunId run_id) { absl::MutexLock lock(mu_); CounterKey key = {executor, run_id}; - if (counters_.contains(key)) return absl::OkStatus(); + if (counters_.contains(key)) { + return absl::OkStatus(); + } counters_.emplace(key, 0); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.h b/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.h index b45544cc609b46..a3aadbacfc7bd0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.h +++ b/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.h @@ -77,7 +77,9 @@ struct P2PConfig { static SourceTargetMapEntry GetSourceTarget( const IdToSourceTargetMap& id_to_source_target, int64_t id) { auto it = id_to_source_target.find(id); - if (it != id_to_source_target.end()) return it->second; + if (it != id_to_source_target.end()) { + return it->second; + } return SourceTargetMapEntry{}; } diff --git a/third_party/xla/xla/backends/gpu/runtime/while_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/while_thunk.cc index 307a4bb59378dd..341cf4c5caabec 100644 --- a/third_party/xla/xla/backends/gpu/runtime/while_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/while_thunk.cc @@ -61,7 +61,7 @@ static std::list& RunningLoops() { return loops; } -bool WhileThunk::RunningWhileThunkLoop() { return RunningLoops().size() > 0; } +bool WhileThunk::RunningWhileThunkLoop() { return !RunningLoops().empty(); } absl::StatusOr WhileThunk::CurrentLoopIteration(int64_t depth) { if (depth >= RunningLoops().size()) { From 17f201de4f56d2aab3a9168bbbcc8207d3531154 Mon Sep 17 00:00:00 2001 From: Srijan Upadhyay <104912634+CodersAcademy006@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:10:21 +0000 Subject: [PATCH 2/2] Fix CUDA invalid resource handle and memory copy failure in tf.nn.conv3d with large tensors by adding fallback mechanism --- tensorflow/core/kernels/conv_ops_impl.h | 156 +++++++++++++----------- 1 file changed, 82 insertions(+), 74 deletions(-) diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h index 0d3fc798bbe3c2..1111d31f9e6e80 100644 --- a/tensorflow/core/kernels/conv_ops_impl.h +++ b/tensorflow/core/kernels/conv_ops_impl.h @@ -750,23 +750,30 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, absl::InternalError("No GPU stream available.")); - Tensor input = input_param; - - int spatial_dims = input.dims() - 2; - std::vector in_dims(spatial_dims); - - const int64_t in_batch = GetTensorDim(input, data_format, 'N'); - for (int i = 0; i < spatial_dims; ++i) { - in_dims[i] = GetTensorDim(input, data_format, static_cast('0' + i)); + const int64_t max_chunk_size = 1 << 26; // Example: 64MB chunks + if (input_param.NumElements() > max_chunk_size) { + // Split input tensor into smaller chunks + auto input_chunks = SplitLargeTensor(context, input_param, max_chunk_size); + Tensor temp_output; + OP_REQUIRES_OK(context, context->allocate_temp(output->dtype(), output->shape(), &temp_output)); + + for (const auto& chunk : input_chunks) { + LaunchConvOpImpl(context, cudnn_use_autotune, chunk, filter, dilations, strides, padding, explicit_paddings, data_format, &temp_output); + cudaMemcpy(output->flat().data(), temp_output.flat().data(), temp_output.NumElements() * sizeof(T), cudaMemcpyDeviceToDevice); + } + return; } - const int64_t in_depth = GetTensorDim(input, data_format, 'C'); - std::vector filter_dims(spatial_dims); - for (int i = 0; i < spatial_dims; ++i) { + // Existing implementation for smaller tensors + const int64_t in_batch = GetTensorDim(input_param, data_format, 'N'); + const int64_t in_depth = GetTensorDim(input_param, data_format, 'C'); + + std::vector filter_dims(filter.dims()); + for (int i = 0; i < filter.dims(); ++i) { filter_dims[i] = filter.dim_size(i); } - const int64_t filter_depth = filter.dim_size(spatial_dims); - const int64_t out_depth = filter.dim_size(spatial_dims + 1); + const int64_t filter_depth = filter.dim_size(filter.dims() - 2); + const int64_t out_depth = filter.dim_size(filter.dims() - 1); OP_REQUIRES( context, filter.NumElements() > 0, @@ -778,15 +785,15 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, bool one_filter = true; bool one_dilations = true; bool one_stride = true; - for (int i = 0; i < spatial_dims; ++i) { + for (int i = 0; i < filter_dims.size() - 2; ++i) { one_filter = one_filter && (filter_dims[i] == 1); - one_dilations = one_dilations && (dilations[i] == 1); - one_stride = one_stride && (strides[i] == 1); + one_dilations = one_dilations && (GetTensorDim(dilations, data_format, static_cast(i + '0')) == 1); + one_stride = one_stride && (GetTensorDim(strides, data_format, static_cast(i + '0')) == 1); } // check if filter is same spatial shape as input bool filter_same_dims = true; - for (int i = 0; i < spatial_dims; ++i) { - if (filter_dims[i] != in_dims[i]) filter_same_dims = false; + for (int i = 0; i < filter_dims.size() - 2; ++i) { + if (filter_dims[i] != GetTensorDim(input_param, data_format, static_cast(i + '0'))) filter_same_dims = false; } auto* blas = stream->parent()->AsBlas(); @@ -795,13 +802,13 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, if (!is_grouped_convolution && one_filter && one_dilations && one_stride && data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) { // 1x1 filter, so call cublas directly. - const uint64 m = in_batch * std::accumulate(in_dims.begin(), in_dims.end(), + const uint64 m = in_batch * std::accumulate(filter_dims.begin(), filter_dims.end() - 2, 1, std::multiplies<>{}); const uint64 k = in_depth; const uint64 n = out_depth; - auto a_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); + auto a_ptr = AsDeviceMemory(input_param.template flat().data(), + input_param.template flat().size()); auto b_ptr = AsDeviceMemory(filter.template flat().data(), filter.template flat().size()); auto c_ptr = AsDeviceMemory(output->template flat().data(), @@ -818,12 +825,12 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, // The input data and filter have the same spatial dimensions, so call // cublas directly. const uint64 m = in_batch; - const uint64 k = in_depth * std::accumulate(in_dims.begin(), in_dims.end(), + const uint64 k = in_depth * std::accumulate(filter_dims.begin(), filter_dims.end() - 2, 1, std::multiplies<>{}); const uint64 n = out_depth; - auto a_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); + auto a_ptr = AsDeviceMemory(input_param.template flat().data(), + input_param.template flat().size()); auto b_ptr = AsDeviceMemory(filter.template flat().data(), filter.template flat().size()); auto c_ptr = AsDeviceMemory(output->template flat().data(), @@ -838,7 +845,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, } const bool compute_in_nhwc = ComputeInNhwcEnabled( - DataTypeToEnum::value, stream, /*use_4d_tensor=*/(spatial_dims == 2)); + DataTypeToEnum::value, stream, /*use_4d_tensor=*/(filter_dims.size() - 2 == 2)); const TensorFormat compute_data_format = (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC : FORMAT_NCHW; @@ -851,7 +858,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, for (int i = 0; i < output->dims(); ++i) { out_dims[i] = output->dim_size(i); } - std::vector> paddings(spatial_dims, {-1, -1}); + std::vector> paddings(filter_dims.size() - 2, {-1, -1}); // Explicit only on 2D case. if (padding == EXPLICIT) { GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', @@ -861,11 +868,12 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, } // Get padding values, output should be valid, since it was checked before. - std::vector out_dims_check(spatial_dims); - for (int i = 0; i < spatial_dims; ++i) { + std::vector out_dims_check(filter_dims.size() - 2); + for (int i = 0; i < filter_dims.size() - 2; ++i) { OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - in_dims[i], filter_dims[i], dilations[i], - strides[i], padding, &out_dims_check[i], + GetTensorDim(input_param, data_format, static_cast(i + '0')), filter_dims[i], + GetTensorDim(dilations, data_format, static_cast(i + '0')), + GetTensorDim(strides, data_format, static_cast(i + '0')), padding, &out_dims_check[i], &paddings[i].first, &paddings[i].second)); OP_REQUIRES(context, (out_dims_check[i] == GetTensorDim(*output, data_format, @@ -874,8 +882,8 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, } bool assymmetric_padding = false; - std::vector common_padding(spatial_dims); - for (int i = 0; i < spatial_dims; ++i) { + std::vector common_padding(filter_dims.size() - 2); + for (int i = 0; i < filter_dims.size() - 2; ++i) { common_padding[i] = std::min(paddings[i].first, paddings[i].second); assymmetric_padding = assymmetric_padding || (paddings[i].first != paddings[i].second); @@ -885,14 +893,14 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, // cuDNN only supports padding the same amount on either side. So we // manually create a new padded input tensor. Tensor transformed_input; - std::vector new_in_dims(input.dims()); + std::vector new_in_dims(input_param.dims()); new_in_dims[0] = in_batch; - for (int i = 0; i < spatial_dims; ++i) { - int index = GetTensorSpatialDimIndex(input.dims(), data_format, i); + for (int i = 0; i < filter_dims.size() - 2; ++i) { + int index = GetTensorSpatialDimIndex(input_param.dims(), data_format, i); new_in_dims[index] = - in_dims[i] + std::abs(paddings[i].first - paddings[i].second); + GetTensorDim(input_param, data_format, static_cast(i + '0')) + std::abs(paddings[i].first - paddings[i].second); } - new_in_dims[GetTensorDimIndex(data_format, 'C', input.dims())] = in_depth; + new_in_dims[GetTensorDimIndex(data_format, 'C', input_param.dims())] = in_depth; TensorShape transformed_input_shape(new_in_dims); OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, transformed_input_shape, @@ -901,14 +909,14 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, // Padding to add on transformed input. std::vector> transformed_input_padding( paddings); - for (int i = 0; i < spatial_dims; ++i) { + for (int i = 0; i < filter_dims.size() - 2; ++i) { transformed_input_padding[i].first -= common_padding[i]; transformed_input_padding[i].second -= common_padding[i]; } // Check padding size. bool padding_bounds_valid = true; - for (int i = 0; i < spatial_dims; ++i) { + for (int i = 0; i < filter_dims.size() - 2; ++i) { padding_bounds_valid = padding_bounds_valid && FastBoundsCheck(transformed_input_padding[i].first, @@ -920,7 +928,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, absl::InvalidArgumentError("Padding is too large.")); // Pad new input. - if (input.dims() == 4) { + if (input_param.dims() == 4) { std::array pad_left{ static_cast(transformed_input_padding[0].first), static_cast(transformed_input_padding[1].first)}; @@ -929,10 +937,10 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, static_cast(transformed_input_padding[1].second)}; functor::PadInput()( context->eigen_device(), - To32Bit(static_cast(input).tensor()), pad_left, + To32Bit(static_cast(input_param).tensor()), pad_left, pad_right, To32Bit(transformed_input.tensor()), data_format, T{}); - } else if (input.dims() == 5) { + } else if (input_param.dims() == 5) { std::array pad_left{ static_cast(transformed_input_padding[0].first), static_cast(transformed_input_padding[1].first), @@ -943,7 +951,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, static_cast(transformed_input_padding[2].second)}; functor::PadInput()( context->eigen_device(), - To32Bit(static_cast(input).tensor()), pad_left, + To32Bit(static_cast(input_param).tensor()), pad_left, pad_right, To32Bit(transformed_input.tensor()), data_format, T{}); } else { @@ -951,10 +959,10 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, absl::InternalError("Failed to pad input, invalid dimensions.")); } - input = transformed_input; - for (int i = 0; i < spatial_dims; ++i) { + input_param = transformed_input; + for (int i = 0; i < filter_dims.size() - 2; ++i) { in_dims[i] = new_in_dims[GetTensorDimIndex( - data_format, static_cast('0' + i), input.dims())]; + data_format, static_cast('0' + i), input_param.dims())]; } } @@ -971,25 +979,25 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, channels_first_shape, &transformed_input)); - if (input.dims() == 4) { + if (input_param.dims() == 4) { functor::NHWCToNCHW()( context->eigen_device(), - const_cast(input).tensor(), + const_cast(input_param).tensor(), transformed_input.tensor()); - } else if (input.dims() == 5) { + } else if (input_param.dims() == 5) { functor::NHWCToNCHW()( context->eigen_device(), - const_cast(input).tensor(), + const_cast(input_param).tensor(), transformed_input.tensor()); } else { context->SetStatus( absl::InternalError("Failed to reshape input to channels first " "format, invalid dimensions.")); } - input = transformed_input; + input_param = transformed_input; } else { // Depth = 1, reshape. - if (!input.CopyFrom(input, channels_first_shape)) { + if (!input_param.CopyFrom(input_param, channels_first_shape)) { context->SetStatus(absl::InternalError( "Failed to reshape input to channels first format.")); } @@ -1003,7 +1011,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, // Check paddings are not negative. bool non_negative_paddings = true; - for (int i = 0; i < spatial_dims; ++i) { + for (int i = 0; i < filter_dims.size() - 2; ++i) { non_negative_paddings = non_negative_paddings && common_padding[i] >= 0; } OP_REQUIRES(context, non_negative_paddings, @@ -1022,13 +1030,13 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, std::tie(compute_data_layout, filter_layout) = compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW; - se::dnn::BatchDescriptor input_desc(spatial_dims); + se::dnn::BatchDescriptor input_desc(filter_dims.size() - 2); input_desc.set_count(in_batch).set_feature_map_count(in_depth).set_layout( compute_data_layout); - if (spatial_dims == 2) { + if (filter_dims.size() - 2 == 2) { input_desc.set_spatial_dim(stream_executor::dnn::DimIndex::X, in_dims[1]) .set_spatial_dim(stream_executor::dnn::DimIndex::Y, in_dims[0]); - } else if (spatial_dims == 3) { + } else if (filter_dims.size() - 2 == 3) { input_desc.set_spatial_dim(stream_executor::dnn::DimIndex::X, in_dims[2]) .set_spatial_dim(stream_executor::dnn::DimIndex::Y, in_dims[1]) .set_spatial_dim(stream_executor::dnn::DimIndex::Z, in_dims[0]); @@ -1038,11 +1046,11 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, " invalid number of spatial dimensions")); } - se::dnn::BatchDescriptor output_desc(spatial_dims); + se::dnn::BatchDescriptor output_desc(filter_dims.size() - 2); output_desc.set_count(GetTensorDim(*output, data_format, 'N')) .set_feature_map_count(GetTensorDim(*output, data_format, 'C')) .set_layout(compute_data_layout); - if (spatial_dims == 2) { + if (filter_dims.size() - 2 == 2) { output_desc .set_spatial_dim( stream_executor::dnn::DimIndex::X, @@ -1050,7 +1058,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, .set_spatial_dim( stream_executor::dnn::DimIndex::Y, GetTensorDim(*output, data_format, static_cast('0'))); - } else if (spatial_dims == 3) { + } else if (filter_dims.size() - 2 == 3) { output_desc .set_spatial_dim( stream_executor::dnn::DimIndex::X, @@ -1067,15 +1075,15 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, "number of spatial dimensions")); } - se::dnn::FilterDescriptor filter_desc(spatial_dims); + se::dnn::FilterDescriptor filter_desc(filter_dims.size() - 2); filter_desc.set_input_feature_map_count(filter_depth) .set_output_feature_map_count(out_depth) .set_layout(filter_layout); - if (spatial_dims == 2) { + if (filter_dims.size() - 2 == 2) { filter_desc .set_spatial_dim(stream_executor::dnn::DimIndex::X, filter_dims[1]) .set_spatial_dim(stream_executor::dnn::DimIndex::Y, filter_dims[0]); - } else if (spatial_dims == 3) { + } else if (filter_dims.size() - 2 == 3) { filter_desc .set_spatial_dim(stream_executor::dnn::DimIndex::X, filter_dims[2]) .set_spatial_dim(stream_executor::dnn::DimIndex::Y, filter_dims[1]) @@ -1086,15 +1094,15 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, "number of spatial dimensions")); } - se::dnn::ConvolutionDescriptor conv_desc(spatial_dims); - if (spatial_dims == 2) { + se::dnn::ConvolutionDescriptor conv_desc(filter_dims.size() - 2); + if (filter_dims.size() - 2 == 2) { conv_desc.set_dilation_rate(stream_executor::dnn::DimIndex::X, dilations[1]) .set_dilation_rate(stream_executor::dnn::DimIndex::Y, dilations[0]) .set_filter_stride(stream_executor::dnn::DimIndex::X, strides[1]) .set_filter_stride(stream_executor::dnn::DimIndex::Y, strides[0]) .set_zero_padding(stream_executor::dnn::DimIndex::X, common_padding[1]) .set_zero_padding(stream_executor::dnn::DimIndex::Y, common_padding[0]); - } else if (spatial_dims == 3) { + } else if (filter_dims.size() - 2 == 3) { conv_desc.set_dilation_rate(stream_executor::dnn::DimIndex::X, dilations[2]) .set_dilation_rate(stream_executor::dnn::DimIndex::Y, dilations[1]) .set_dilation_rate(stream_executor::dnn::DimIndex::Z, dilations[0]) @@ -1116,7 +1124,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO) << " to " << ToString(dst_format); - std::vector dst_shape_vec(spatial_dims + 2); + std::vector dst_shape_vec(filter_dims.size()); dst_shape_vec[0] = out_depth; if (dst_format == FORMAT_OIHW) { dst_shape_vec[1] = filter_depth; @@ -1138,12 +1146,12 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, // Filter: [(spatial_dims), in, out] (HWIO) // T_filter: [out, in, (spatial_dims)] (OIHW) or // T_filter: [out, (spatial_dims), in] (OHWI) - if (spatial_dims == 2) { + if (filter_dims.size() - 2 == 2) { functor::TransformFilter()( context->eigen_device(), dst_format, To32Bit(filter.tensor()), To32Bit(transformed_filter.tensor())); - } else if (spatial_dims == 3) { + } else if (filter_dims.size() - 2 == 3) { functor::TransformFilter()( context->eigen_device(), dst_format, To32Bit(filter.tensor()), @@ -1167,8 +1175,8 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, transformed_output = *output; } - auto input_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); + auto input_ptr = AsDeviceMemory(input_param.template flat().data(), + input_param.template flat().size()); auto filter_ptr = AsDeviceMemory(transformed_filter.template flat().data(), transformed_filter.template flat().size()); @@ -1178,7 +1186,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, static int64_t ConvolveScratchSize = GetDnnWorkspaceLimitOrDefault(); - if (spatial_dims == 2) { + if (filter_dims.size() - 2 == 2) { filter_dims.push_back(filter_depth); } ConvParameters conv_parameters = { @@ -1192,7 +1200,7 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, dilations, // dilations strides, // strides common_padding, // paddings (symmetrical) - input.dtype(), // tensor datatype + input_param.dtype(), // tensor datatype conv_desc.group_count(), }; @@ -1216,12 +1224,12 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { VLOG(4) << "Convert the output tensor back from NCHW to NHWC."; - if (spatial_dims == 2) { + if (filter_dims.size() - 2 == 2) { functor::NCHWToNHWC()( context->eigen_device(), const_cast(transformed_output).tensor(), output->tensor()); - } else if (spatial_dims == 3) { + } else if (filter_dims.size() - 2 == 3) { functor::NCHWToNHWC()( context->eigen_device(), const_cast(transformed_output).tensor(),