1 change: 1 addition & 0 deletions gemma/activations.h
@@ -25,6 +25,7 @@

#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "ops/ops.h" // CreateInvTimescale
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
5 changes: 2 additions & 3 deletions gemma/attention.cc
@@ -321,9 +321,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,

 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
-static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivationsPtrs& activations,
-                                MatMulEnv& env) {
+void SumHeads(const LayerWeightsPtrs& layer,
+              AttentionActivationsPtrs& activations, MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config; // For HWY_DASSERT
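
A rough scalar sketch of the reduction the comment above describes, i.e. how `att_out` (per-head attention outputs) could be reduced over heads and head_dim into `layer_out`. The names, shapes, and the weight tensor `w_out` below are illustrative assumptions, not gemma.cpp's actual types; the real SumHeads presumably performs this as a batched MatMul via `MatMulEnv`.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t heads = 2, qkv_dim = 4, model_dim = 8;  // toy sizes
  // att_out: [heads][qkv_dim], per-head attention outputs for one token.
  std::vector<float> att_out(heads * qkv_dim, 1.0f);
  // w_out: [heads][qkv_dim][model_dim], hypothetical per-head output weights.
  std::vector<float> w_out(heads * qkv_dim * model_dim, 0.5f);
  std::vector<float> layer_out(model_dim, 0.0f);

  // Sum over both heads and head_dim into the model-dim output.
  for (size_t h = 0; h < heads; ++h) {
    for (size_t k = 0; k < qkv_dim; ++k) {
      const float a = att_out[h * qkv_dim + k];
      for (size_t d = 0; d < model_dim; ++d) {
        layer_out[d] += a * w_out[(h * qkv_dim + k) * model_dim + d];
      }
    }
  }
  std::printf("layer_out[0] = %g\n", layer_out[0]);  // 2 * 4 * 0.5 = 4
  return 0;
}
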
2 changes: 2 additions & 0 deletions gemma/attention.h
@@ -51,6 +51,8 @@ namespace gcpp {
const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
void SumHeads(const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

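
The two added lines declare SumHeads inside the per-target declaration macro, pairing with the attention.cc change that drops `static HWY_INLINE` so the function has external linkage in every SIMD target namespace. A simplified, hypothetical illustration of that pattern (the macro and namespace names below are made up, and Highway's real foreach_target/dynamic-dispatch machinery is omitted):

#include <cstdio>

// One macro emits identical declarations into each per-target namespace, so a
// single added line in the shared header exposes a function to all targets.
#define GCPP_DECL_ATTENTION(NAMESPACE) \
  namespace NAMESPACE {                \
  void SumHeads();                     \
  }

GCPP_DECL_ATTENTION(N_Scalar)
GCPP_DECL_ATTENTION(N_AVX2)

// Per-target definitions; stubs standing in for the SIMD implementations.
namespace N_Scalar { void SumHeads() { std::puts("scalar SumHeads"); } }
namespace N_AVX2 { void SumHeads() { std::puts("AVX2 SumHeads"); } }

int main() {
  // Real code would pick the best available target at runtime via dynamic
  // dispatch rather than naming a namespace directly.
  N_AVX2::SumHeads();
  return 0;
}
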
11 changes: 8 additions & 3 deletions gemma/flash_attention.cc
@@ -425,9 +425,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   float scale = old_d * std::exp(old_max - m);
   old_d = hn::ReduceSum(df, x) + scale;
   old_max = m;
-  float one_over_d = 1.0f / old_d;
-  scale *= one_over_d;
-  x = hn::Mul(x, hn::Set(df, one_over_d));
+  if (old_d > 0.0f) {
+    const float one_over_d = 1.0f / old_d;
+    scale *= one_over_d;
+    x = hn::Mul(x, hn::Set(df, one_over_d));
+  } else {
+    scale = 0.0f;
+    x = hn::Zero(df);
+  }
   return scale;
 }

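
The added guard handles the case where the running softmax denominator is still zero (e.g. every score so far was masked out or underflowed after exponentiation), where multiplying by 1.0f / old_d would produce inf/NaN. Below is a self-contained scalar recast of that online-softmax update, with made-up names and the Highway vector ops replaced by plain loops (a sketch, not a transcription of SingleFlashAttentionRowVector, which receives already-exponentiated values):

#include <cmath>
#include <cstdio>
#include <vector>

// x holds the new raw scores; m is the updated running max (assumed to be
// max(old_max, max of the new scores)). On return, x holds the normalized
// weights for the new values, and the returned scale is the factor by which
// the previously accumulated (already normalized) output must be rescaled.
float OnlineSoftmaxStep(std::vector<float>& x, float m,
                        float& old_max, float& old_d) {
  float sum = 0.0f;
  for (float& v : x) {
    v = std::exp(v - m);  // exponentiate against the new running max
    sum += v;
  }
  float scale = old_d * std::exp(old_max - m);  // rescaled old denominator
  old_d = sum + scale;
  old_max = m;
  if (old_d > 0.0f) {
    const float one_over_d = 1.0f / old_d;
    scale *= one_over_d;
    for (float& v : x) v *= one_over_d;
  } else {
    // Everything was masked/underflowed: emit zeros instead of dividing by 0.
    scale = 0.0f;
    for (float& v : x) v = 0.0f;
  }
  return scale;
}

int main() {
  // Fully masked row: exp(-200) underflows to 0 in float and nothing was
  // accumulated before, so old_d stays 0 and the guard kicks in.
  std::vector<float> x = {-200.0f, -200.0f};
  float old_max = -1e30f, old_d = 0.0f;
  const float scale = OnlineSoftmaxStep(x, /*m=*/0.0f, old_max, old_d);
  std::printf("scale=%g x=[%g, %g]\n", scale, x[0], x[1]);  // all zeros
  return 0;
}
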
6 changes: 4 additions & 2 deletions gemma/gemma.cc
@@ -519,8 +519,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   if (max_prompt_size > seq_len) {
-    HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
-              max_prompt_size);
+    HWY_ABORT(
+        "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
+        "that.",
+        max_prompt_size, seq_len);
   }
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
