@@ -57,7 +57,7 @@ namespace gcpp {
5757namespace HWY_NAMESPACE {
5858
5959static constexpr size_t kNFx8HTileSize = 8 ;
60-
60+ static constexpr float kNegInf = -std::numeric_limits< float >::max() / 64 . 0f ;
6161// Transposes q into q_t.
6262// Both are 4D tensors stuffed into a 2-D MatPtrT.
6363// q has shape [batch, qbatch][head, qkv_dim].
@@ -467,7 +467,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
467467 const DF4 df4;
468468 using VF4 = hn::Vec<DF4>;
469469 static_assert (kNumQueries >= 1 && kNumQueries <= 4 );
470- VF4 new_max = hn::Set (df4, -std::numeric_limits< float >:: max () / 2 . 0f );
470+ VF4 new_max = hn::Set (df4, kNegInf );
471471 VF max_0, max_1, max_2, max_3 = hn::Zero (df);
472472 max_0 = hn::Max (x_0_p0, x_0_p1);
473473 if constexpr (kNumQueries >= 2 ) {
@@ -490,38 +490,36 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
490490 VF4 one_over_cap = hn::Set (df4, one_over_att_cap);
491491 new_max = hn::Mul (cap, hn::Tanh (df4, hn::Mul (new_max, one_over_cap)));
492492 }
493- VF4 old_max_vf = hn::Set (df4, -std::numeric_limits< float >:: max () / 2 . 0f );
493+ VF4 old_max_vf = hn::Set (df4, kNegInf );
494494 old_max_vf = hn::LoadU (df4, old_max);
495495 new_max = hn::Max (new_max, old_max_vf);
496+ auto changed_max = hn::Gt (new_max, hn::Set (df4, kNegInf ));
496497 // TODO figure out what was wrong with broadcasts and change to that.
497- HWY_ALIGN float tmp_max[4 ];
498- hn::Store (new_max, df4, tmp_max);
498+ hn::StoreU (new_max, df4, old_max);
499499 if constexpr (kNumQueries >= 1 ) {
500- const VF new_max_0 = hn::Set (df, tmp_max [0 ]);
500+ const VF new_max_0 = hn::Set (df, old_max [0 ]);
501501 x_0_p0 = hn::Exp (df, hn::Sub (x_0_p0, new_max_0));
502502 x_0_p1 = hn::Exp (df, hn::Sub (x_0_p1, new_max_0));
503503 }
504504 if constexpr (kNumQueries >= 2 ) {
505- const VF new_max_0 = hn::Set (df, tmp_max [1 ]);
505+ const VF new_max_0 = hn::Set (df, old_max [1 ]);
506506 x_1_p0 = hn::Exp (df, hn::Sub (x_1_p0, new_max_0));
507507 x_1_p1 = hn::Exp (df, hn::Sub (x_1_p1, new_max_0));
508508 }
509509 if constexpr (kNumQueries >= 3 ) {
510- const VF new_max_0 = hn::Set (df, tmp_max [2 ]);
510+ const VF new_max_0 = hn::Set (df, old_max [2 ]);
511511 x_2_p0 = hn::Exp (df, hn::Sub (x_2_p0, new_max_0));
512512 x_2_p1 = hn::Exp (df, hn::Sub (x_2_p1, new_max_0));
513513 }
514514 if constexpr (kNumQueries >= 4 ) {
515- const VF new_max_0 = hn::Set (df, tmp_max [3 ]);
515+ const VF new_max_0 = hn::Set (df, old_max [3 ]);
516516 x_3_p0 = hn::Exp (df, hn::Sub (x_3_p0, new_max_0));
517517 x_3_p1 = hn::Exp (df, hn::Sub (x_3_p1, new_max_0));
518518 }
519519 VF4 old_d_vf = hn::Set (df4, 0 .0f );
520520 old_d_vf = hn::LoadU (df4, old_d);
521521 VF4 scale = hn::Mul (old_d_vf, hn::Exp (df4, hn::Sub (old_max_vf, new_max)));
522522
523- hn::StoreU (new_max, df4, old_max);
524-
525523 VF4 x_sum = hn::Zero (df4);
526524 if constexpr (kNumQueries == 1 ) {
527525 x_sum = hn::Set (df4, hn::ReduceSum (df, x_0_p0) + hn::ReduceSum (df, x_0_p1));
@@ -539,12 +537,12 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
539537 const VF4 zero4 = hn::Zero (df4);
540538 const VF4 one_over_d =
541539 hn::MaskedDivOr (zero4, non_zero_mask, hn::Set (df4, 1 .0f ), old_d_vf);
542- float tmp_one_over_d[4 ];
540+ HWY_ALIGN float tmp_one_over_d[4 ];
543541 hn::Store (one_over_d, df4, tmp_one_over_d);
544- hn::Store (old_d_vf, df4, old_d);
542+ hn::BlendedStore (old_d_vf, changed_max , df4, old_d);
545543 scale = hn::Mul (scale, one_over_d);
546- hn::Store (scale, df4, scales);
547- if (hn::ExtractLane (old_d_vf, 0 ) > 0 .0f ) {
544+ hn::BlendedStore (scale, changed_max , df4, scales);
545+ if (hn::ExtractLane (old_d_vf, 0 ) > 0 .0f && scales[ 0 ] != 1 . 0f ) {
548546 const VF one_over_d_0 = hn::Set (df, tmp_one_over_d[0 ]);
549547 x_0_p0 = hn::Mul (x_0_p0, one_over_d_0);
550548 x_0_p1 = hn::Mul (x_0_p1, one_over_d_0);
@@ -553,7 +551,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
553551 x_0_p1 = zero;
554552 }
555553 if constexpr (kNumQueries >= 2 ) {
556- if (hn::ExtractLane (old_d_vf, 1 ) > 0 .0f ) {
554+ if (hn::ExtractLane (old_d_vf, 1 ) > 0 .0f && scales[ 1 ] != 1 . 0f ) {
557555 const VF one_over_d_1 = hn::Set (df, tmp_one_over_d[1 ]);
558556 x_1_p0 = hn::Mul (x_1_p0, one_over_d_1);
559557 x_1_p1 = hn::Mul (x_1_p1, one_over_d_1);
@@ -563,7 +561,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
563561 }
564562 }
565563 if constexpr (kNumQueries >= 3 ) {
566- if (hn::ExtractLane (old_d_vf, 2 ) > 0 .0f ) {
564+ if (hn::ExtractLane (old_d_vf, 2 ) > 0 .0f && scales[ 2 ] != 1 . 0f ) {
567565 const VF one_over_d_2 = hn::Set (df, tmp_one_over_d[2 ]);
568566 x_2_p0 = hn::Mul (x_2_p0, one_over_d_2);
569567 x_2_p1 = hn::Mul (x_2_p1, one_over_d_2);
@@ -573,7 +571,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
573571 }
574572 }
575573 if constexpr (kNumQueries >= 4 ) {
576- if (hn::ExtractLane (old_d_vf, 3 ) > 0 .0f ) {
574+ if (hn::ExtractLane (old_d_vf, 3 ) > 0 .0f && scales[ 3 ] != 1 . 0f ) {
577575 const VF one_over_d_3 = hn::Set (df, tmp_one_over_d[3 ]);
578576 x_3_p0 = hn::Mul (x_3_p0, one_over_d_3);
579577 x_3_p1 = hn::Mul (x_3_p1, one_over_d_3);
0 commit comments