Skip to content

Commit 08a0760

Browse files
Krzysztof Rymski and copybara-github
authored and committed
Internal changes
PiperOrigin-RevId: 846663686
1 parent b73a9ed commit 08a0760

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

gemma/flash_attention.cc

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ namespace gcpp {
5757
namespace HWY_NAMESPACE {
5858

5959
static constexpr size_t kNFx8HTileSize = 8;
60-
60+
static constexpr float kNegInf = -std::numeric_limits<float>::max() / 64.0f;
6161
// Transposes q into q_t.
6262
// Both are 4D tensors stuffed into a 2-D MatPtrT.
6363
// q has shape [batch, qbatch][head, qkv_dim].
@@ -467,7 +467,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
467467
const DF4 df4;
468468
using VF4 = hn::Vec<DF4>;
469469
static_assert(kNumQueries >= 1 && kNumQueries <= 4);
470-
VF4 new_max = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
470+
VF4 new_max = hn::Set(df4, kNegInf);
471471
VF max_0, max_1, max_2, max_3 = hn::Zero(df);
472472
max_0 = hn::Max(x_0_p0, x_0_p1);
473473
if constexpr (kNumQueries >= 2) {
@@ -490,38 +490,36 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
490490
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
491491
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
492492
}
493-
VF4 old_max_vf = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
493+
VF4 old_max_vf = hn::Set(df4, kNegInf);
494494
old_max_vf = hn::LoadU(df4, old_max);
495495
new_max = hn::Max(new_max, old_max_vf);
496+
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
496497
// TODO figure out what was wrong with broadcasts and change to that.
497-
HWY_ALIGN float tmp_max[4];
498-
hn::Store(new_max, df4, tmp_max);
498+
hn::StoreU(new_max, df4, old_max);
499499
if constexpr (kNumQueries >= 1) {
500-
const VF new_max_0 = hn::Set(df, tmp_max[0]);
500+
const VF new_max_0 = hn::Set(df, old_max[0]);
501501
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0, new_max_0));
502502
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
503503
}
504504
if constexpr (kNumQueries >= 2) {
505-
const VF new_max_0 = hn::Set(df, tmp_max[1]);
505+
const VF new_max_0 = hn::Set(df, old_max[1]);
506506
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0));
507507
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0));
508508
}
509509
if constexpr (kNumQueries >= 3) {
510-
const VF new_max_0 = hn::Set(df, tmp_max[2]);
510+
const VF new_max_0 = hn::Set(df, old_max[2]);
511511
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0));
512512
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0));
513513
}
514514
if constexpr (kNumQueries >= 4) {
515-
const VF new_max_0 = hn::Set(df, tmp_max[3]);
515+
const VF new_max_0 = hn::Set(df, old_max[3]);
516516
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0));
517517
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0));
518518
}
519519
VF4 old_d_vf = hn::Set(df4, 0.0f);
520520
old_d_vf = hn::LoadU(df4, old_d);
521521
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
522522

523-
hn::StoreU(new_max, df4, old_max);
524-
525523
VF4 x_sum = hn::Zero(df4);
526524
if constexpr (kNumQueries == 1) {
527525
x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
@@ -539,12 +537,12 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
539537
const VF4 zero4 = hn::Zero(df4);
540538
const VF4 one_over_d =
541539
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
542-
float tmp_one_over_d[4];
540+
HWY_ALIGN float tmp_one_over_d[4];
543541
hn::Store(one_over_d, df4, tmp_one_over_d);
544-
hn::Store(old_d_vf, df4, old_d);
542+
hn::BlendedStore(old_d_vf, changed_max, df4, old_d);
545543
scale = hn::Mul(scale, one_over_d);
546-
hn::Store(scale, df4, scales);
547-
if (hn::ExtractLane(old_d_vf, 0) > 0.0f) {
544+
hn::BlendedStore(scale, changed_max, df4, scales);
545+
if (hn::ExtractLane(old_d_vf, 0) > 0.0f && scales[0] != 1.0f) {
548546
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
549547
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
550548
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
@@ -553,7 +551,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
553551
x_0_p1 = zero;
554552
}
555553
if constexpr (kNumQueries >= 2) {
556-
if (hn::ExtractLane(old_d_vf, 1) > 0.0f) {
554+
if (hn::ExtractLane(old_d_vf, 1) > 0.0f && scales[1] != 1.0f) {
557555
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
558556
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
559557
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
@@ -563,7 +561,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
563561
}
564562
}
565563
if constexpr (kNumQueries >= 3) {
566-
if (hn::ExtractLane(old_d_vf, 2) > 0.0f) {
564+
if (hn::ExtractLane(old_d_vf, 2) > 0.0f && scales[2] != 1.0f) {
567565
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
568566
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
569567
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
@@ -573,7 +571,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
573571
}
574572
}
575573
if constexpr (kNumQueries >= 4) {
576-
if (hn::ExtractLane(old_d_vf, 3) > 0.0f) {
574+
if (hn::ExtractLane(old_d_vf, 3) > 0.0f && scales[3] != 1.0f) {
577575
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
578576
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
579577
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);

0 commit comments

Comments (0)