Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 82 additions & 74 deletions tensorflow/core/kernels/conv_ops_impl.h

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class BlockLevelEmitterBackend : public GpuCodegenBackend {
const BackendConfig& config) override;

// Determines whether the given HLO instruction is supported by this backend.
bool IsSupported(const HloInstruction& instr);
bool IsSupported(const HloInstruction& instr) override;

// We don't want to use the Triton emitter as a reference because it can
// produce wrong results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,9 @@ SmallVector<Value> ReductionFusion::EvaluateEpilogue(
EmitterState& state, int group_id, ValueRange symbol_values) const {
ImplicitLocOpBuilder& b = state.builder;
const auto& epilogue = state.computations.epilogues()[group_id];
if (epilogue.roots.empty()) return outputs;
if (epilogue.roots.empty()) {
return outputs;
}

auto epilogue_input_indices = state.thread_and_block_ids;
epilogue_input_indices.append(symbol_values.begin(), symbol_values.end());
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ SmallVector<ValueRange> Unpack(ValueRange range, ArrayRef<int64_t> sizes) {
SmallVector<Value, 4> PadWithZeros(ValueRange values, int64_t size,
ImplicitLocOpBuilder& b) {
SmallVector<Value, 4> padded_values(values.begin(), values.end());
if (values.size() >= size) return padded_values;
if (values.size() >= size) {
return padded_values;
}
auto zero = arith::ConstantIndexOp::create(b, 0);
for (int i = values.size(); i < size; ++i) {
padded_values.push_back(zero);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,8 @@ struct RewriteFp8TruncFPattern : public Fp8OpRewritePattern<arith::TruncFOp> {
}
if (v.getType() != f32_ty) {
return arith::TruncFOp::create(b, f32_ty, v);
} else {
return v;
}
return v;
});

mlir::StringAttr cvtIntr = b.getStringAttr(
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/backends/gpu/codegen/fusion_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class FusionInterface {
// Interface for fusions that are implemented using cuda kernels.
class KernelFusionInterface : public FusionInterface {
public:
virtual ~KernelFusionInterface() = default;
~KernelFusionInterface() override = default;

// Returns the fusion's launch dimensions.
virtual LaunchDimensions launch_dimensions() const = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ absl::Status TestBijection(const IndexingMap& map,
intervals.push_back({0, size - 1});
}
auto status = VerifyBijection(map, intervals);
if (status.ok()) return status;
if (status.ok()) {
return status;
}
return absl::FailedPreconditionError(
absl::StrCat(status.message(), " in map ", ToString(map)));
}
Expand Down Expand Up @@ -162,7 +164,9 @@ int main(int argc, char* argv[]) {
tsl::Flag(
"bijection_inputs",
[](std::string name_and_ids) {
if (name_and_ids.empty()) return false;
if (name_and_ids.empty()) {
return false;
}
flags.bijection_inputs.push_back(
xla::gpu::ParseHeroAndIds(name_and_ids));
return true;
Expand All @@ -174,7 +178,9 @@ int main(int argc, char* argv[]) {
tsl::Flag(
"bijection_outputs",
[](std::string name) {
if (name.empty()) return false;
if (name.empty()) {
return false;
}
flags.bijection_outputs.push_back(name);
return true;
},
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/backends/gpu/codegen/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ xla_cc_test(
deps = [
":compilation_pipeline",
"//xla/stream_executor/cuda:cuda_compute_capability",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
Expand All @@ -48,12 +49,11 @@ TEST(CompilationPipelineTest, UnswitchLoopsAfterLICM) {
}
ASSERT_THAT(pass_names, Contains("LoopInvariantCodeMotion"));
ASSERT_THAT(pass_names, Contains("TritonXLAUnswitchLoopsPass"));
int licm_index = std::distance(pass_names.begin(),
std::find(pass_names.begin(), pass_names.end(),
"LoopInvariantCodeMotion"));
int unswitch_index = std::distance(
pass_names.begin(), std::find(pass_names.begin(), pass_names.end(),
"TritonXLAUnswitchLoopsPass"));
int licm_index = std::distance(
pass_names.begin(), absl::c_find(pass_names, "LoopInvariantCodeMotion"));
int unswitch_index =
std::distance(pass_names.begin(),
absl::c_find(pass_names, "TritonXLAUnswitchLoopsPass"));
// There is no hard requirement to run LICM **immediately** before the loop
// unswitcher but you should consider if the newly added pass might interact
// with the loop unswitcher.
Expand Down
31 changes: 14 additions & 17 deletions third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,20 @@ absl::StatusOr<std::optional<Type>> GetForceOperandsType(
// If there is a single allowed operand type, we force the operands to use
// this type.
return allowed_operands_types.front();

} else {
// If there are several allowed operand types, we just check that the
// operands have the same type, and that this type is one of the allowed
// ones. Raise an error otherwise.
if (lhs_type != rhs_type ||
!absl::c_linear_search(allowed_operands_types, lhs_type)) {
std::string allowed_operands_types_str = absl::StrJoin(
allowed_operands_types, ", ", [&](std::string* out, Type type) {
absl::StrAppend(out, MlirToString(type));
});
return absl::FailedPreconditionError(absl::StrCat(
"Expected dot operands to both have the same type, and for this type "
"to be one of the following types: ",
allowed_operands_types_str, " but got ", MlirToString(lhs_type),
" and ", MlirToString(rhs_type)));
}
} // If there are several allowed operand types, we just check that the
// operands have the same type, and that this type is one of the allowed
// ones. Raise an error otherwise.
if (lhs_type != rhs_type ||
!absl::c_linear_search(allowed_operands_types, lhs_type)) {
std::string allowed_operands_types_str = absl::StrJoin(
allowed_operands_types, ", ", [&](std::string* out, Type type) {
absl::StrAppend(out, MlirToString(type));
});
return absl::FailedPreconditionError(absl::StrCat(
"Expected dot operands to both have the same type, and for this type "
"to be one of the following types: ",
allowed_operands_types_str, " but got ", MlirToString(lhs_type),
" and ", MlirToString(rhs_type)));
}

return std::nullopt;
Expand Down
17 changes: 6 additions & 11 deletions third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,8 @@ Value Cast(mlir::ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) {
if (src_fp_element_ty.getFPMantissaWidth() >
dst_fp_element_ty.getFPMantissaWidth()) {
return ma::TruncFOp::create(b, dst_ty, value);
} else {
return ma::ExtFOp::create(b, dst_ty, value);
}
return ma::ExtFOp::create(b, dst_ty, value);
}
// int => int
if (mlir::isa<mlir::IntegerType>(src_element_ty) &&
Expand Down Expand Up @@ -321,16 +320,14 @@ Value Cast(mlir::ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) {
auto cst_int = [&](int64_t x) -> Value {
if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape());
} else {
return CreateConst(b, dst_element_ty, x);
}
return CreateConst(b, dst_element_ty, x);
};
auto cst_float = [&](int64_t x) -> Value {
if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape());
} else {
return CreateConst(b, src_fp_element_ty, x);
}
return CreateConst(b, src_fp_element_ty, x);
};
auto fptosi = ma::FPToSIOp::create(b, dst_ty, value);
int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth());
Expand Down Expand Up @@ -358,9 +355,8 @@ Value Cast(mlir::ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) {
Value Subtract(mlir::ImplicitLocOpBuilder& b, ValueRange values) {
if (mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(values[0]))) {
return ma::SubIOp::create(b, values[0], values[1]);
} else {
return ma::SubFOp::create(b, values[0], values[1]);
}
return ma::SubFOp::create(b, values[0], values[1]);
}

Value Compare(mlir::ImplicitLocOpBuilder& b, ValueRange values,
Expand Down Expand Up @@ -560,10 +556,9 @@ absl::StatusOr<mlir::TypedValue<mlir::RankedTensorType>> EmitConstant(
if (constant.shape().element_type() == U64) {
return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant, U64),
shape);
} else {
return CreateConst(b, ty, ScalarConstantValue<int64_t>(constant, S64),
shape);
}
return CreateConst(b, ty, ScalarConstantValue<int64_t>(constant, S64),
shape);
}
return CreateConst(b, ty, ScalarConstantValue<double>(constant, F64), shape);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ static mlir::ParseResult parseI64ArrayAttr(mlir::AsmParser& parser,
mlir::DenseI64ArrayAttr& array) {
array = mlir::dyn_cast_or_null<mlir::DenseI64ArrayAttr>(
mlir::DenseI64ArrayAttr::parse(parser, mlir::Type{}));
if (!array) return mlir::failure();
if (!array) {
return mlir::failure();
}
return mlir::success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ CodegenDecision IsInstructionSupportsDataTypes(
const auto operand_type = operand->shape().element_type();
switch (instr.opcode()) {
case HloOpcode::kConvert:
if (operand_type == S4) continue;
if (operand_type == S4) {
continue;
}
[[fallthrough]];
default:
if (!IsTritonSupportedDataType(operand_type, gpu_version)) {
Expand Down Expand Up @@ -206,8 +208,9 @@ CodegenDecision CanTritonHandleElementwise(
}
if (instr.opcode() == HloOpcode::kConstant) {
return CodegenDecision::Allow();
} else if (!IsTritonSupportedElementwiseUpToFloatNormalization(
instr.opcode(), instr.operand(0)->shape().element_type())) {
}
if (!IsTritonSupportedElementwiseUpToFloatNormalization(
instr.opcode(), instr.operand(0)->shape().element_type())) {
return CodegenDecision::Forbid("Unsupported elementwise operation.");
}
return CodegenDecision::Allow();
Expand Down Expand Up @@ -360,7 +363,8 @@ CodegenDecision IsTritonSupportedDynamicSlice(
for (int i = 0; i < input->shape().dimensions().size(); ++i) {
if (i == majormost_dim_id) {
continue;
} else if (input->shape().dimensions(i) != instr.slice_sizes(i)) {
}
if (input->shape().dimensions(i) != instr.slice_sizes(i)) {
return CodegenDecision::Forbid(
"Unsupported dynamic slice on non-major-most dimension.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,9 @@ ENTRY e {
ApplyFloatNormalization(dot.Module().get(), GetComputeCapability()));
EXPECT_TRUE(RunAndCompareNoHloPasses(
std::move(dot.Module()), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4}));
} else {
EXPECT_THAT(TritonFusionAnalysis::Execute(dot.TritonComputation()),
absl_testing::StatusIs(absl::StatusCode::kFailedPrecondition));
}
EXPECT_THAT(TritonFusionAnalysis::Execute(dot.TritonComputation()),
absl_testing::StatusIs(absl::StatusCode::kFailedPrecondition));
}

INSTANTIATE_TEST_SUITE_P(
Expand Down
13 changes: 6 additions & 7 deletions third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) {
if (cc.IsCuda()) {
return cc.cuda_compute_capability()->IsAtLeast(
se::CudaComputeCapability::kAmpere);
} else if (cc.IsRocm()) {
}
if (cc.IsRocm()) {
return cc.rocm_compute_capability()->has_bf16_dtype_support();
}
CHECK(false);
Expand Down Expand Up @@ -248,10 +249,9 @@ std::string ComputeCapabilityToString(
const stream_executor::GpuComputeCapability& cc) {
if (auto* cuda_cc = cc.cuda_compute_capability()) {
return absl::StrReplaceAll(cuda_cc->ToString(), {{".", ""}});
} else {
CHECK(cc.IsRocm());
return "rocm";
}
CHECK(cc.IsRocm());
return "rocm";
}

std::string TritonSupportTestTypeAndDeviceToString(
Expand Down Expand Up @@ -329,10 +329,9 @@ absl::Status ConvertEntryToTritonFusion(HloModule* module,
gpu::GpuBackendConfig gpu_config;
if (use_nested_gemm_fusions) {
gpu_config.mutable_fusion_backend_config()->set_kind(
std::string(kTritonNestedGemmFusionKind));
kTritonNestedGemmFusionKind);
} else {
gpu_config.mutable_fusion_backend_config()->set_kind(
std::string(kTritonFusionKind));
gpu_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind);
}
TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ void StripParameterAddressSpaces(RewriterBase& rewriter,
SmallVector<Type> generic_func_params(
llvm::map_range(func_ty.getParams(), [](Type type) -> Type {
auto ptr_ty = dyn_cast<LLVM::LLVMPointerType>(type);
if (!ptr_ty) return type;
if (!ptr_ty) {
return type;
}
if (ptr_ty.getAddressSpace() != NVVM::NVVMMemorySpace::Global) {
return type;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,18 @@ class Tf32DotPattern : public OpRewritePattern<mt::DotOp> {
mlir::LogicalResult matchAndRewrite(
mt::DotOp op, PatternRewriter &rewriter) const override {
constexpr auto tf32_args_rounded = "tf32_arguments_rounded";
if (op.getInputPrecision() != mt::InputPrecision::TF32) return failure();
if (!op.getA().getType().getElementType().isF32()) return failure();
if (!op.getB().getType().getElementType().isF32()) return failure();
if (op->hasAttr(tf32_args_rounded)) return failure();
if (op.getInputPrecision() != mt::InputPrecision::TF32) {
return failure();
}
if (!op.getA().getType().getElementType().isF32()) {
return failure();
}
if (!op.getB().getType().getElementType().isF32()) {
return failure();
}
if (op->hasAttr(tf32_args_rounded)) {
return failure();
}

auto f32ToTF32 = [&](Value value) -> Value {
return ElementwiseInlineAsmOp::create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ LogicalResult LowerBlockBarrierOp(BlockBarrierOp block_barrier,
// Signal all ranks on the same block id.
mlir::triton::xla::AtomicWriteOp::create(
builder,
/*result_types=*/mlir::TypeRange{},
/*resultTypes=*/mlir::TypeRange{},
/*ptr=*/signal_addresses,
/*signal_value=*/signal_value,
/*mask=*/mlir::Value{},
Expand Down Expand Up @@ -173,7 +173,7 @@ LogicalResult LowerBlockBarrierOp(BlockBarrierOp block_barrier,
// Wait for all ranks on the same block id to signal.
mlir::triton::xla::AtomicSpinWaitOp::create(
builder,
/*result_types=*/mlir::TypeRange{},
/*resultTypes=*/mlir::TypeRange{},
/*ptr=*/wait_addresses,
/*expected=*/signal_value,
/*mask=*/mlir::Value{},
Expand Down
8 changes: 6 additions & 2 deletions third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ std::string GpuClique::DebugString() const {
num_communicators());
int32_t cnt = 0;
ForEachComm([&](RankId rank, Communicator* comm) {
if (cnt++) absl::StrAppend(&out, ", ");
if (cnt++) {
absl::StrAppend(&out, ", ");
}
absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank.value(), comm);
});
return out;
Expand All @@ -63,7 +65,9 @@ absl::Status GpuClique::HealthCheck() const {
ForEachComm([&health_check](RankId rank, Communicator* comm) {
if (auto s = comm->HealthCheck(); !s.ok()) {
LOG(ERROR) << "GPU communicator error (rank " << rank << "): " << s;
if (health_check.ok()) health_check = std::move(s); // return first error
if (health_check.ok()) {
health_check = std::move(s); // return first error
}
}
});
return health_check;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ bool IsP2PStreamKind(AsyncStreamKind stream_kind) {
// Resolves which collective stream id an operation should run on.
//
// Synchronous collectives always map to stream 0. Asynchronous collectives
// use the explicit `stream_id` when the caller supplied a non-zero one;
// otherwise the id is derived from `stream_kind`, offset by one so it can
// never collide with the synchronous stream 0.
CollectiveStreamId GetCollectiveStreamId(bool is_async,
                                         CollectiveStreamId stream_id,
                                         AsyncStreamKind stream_kind) {
  // Synchronous operations share the default stream.
  if (!is_async) {
    return CollectiveStreamId(0);
  }
  // An explicit non-zero id from the caller wins.
  if (stream_id.value() != 0) {
    return stream_id;
  }
  // TODO: Remove this fallback once AsyncStreamId is used everywhere.
  return CollectiveStreamId(static_cast<int64_t>(stream_kind) + 1);
}

Expand Down
Loading