
Commit 49d420a

stollem authored and copybara-github committed
Add some comments.
PiperOrigin-RevId: 834173319
1 parent b8f6be7 commit 49d420a

4 files changed: +57 -3 lines changed


gemma/activations.h

Lines changed: 15 additions & 0 deletions
@@ -165,17 +165,32 @@ struct AttentionActivationsPtrs {
   }
 
   const ModelConfig& config;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<BF16> q_bf;
+  // Transposed query matrix for faster Q*K^T.
   MatPtrT<BF16> q_T;
+  // Output of RMSNorm before attention, size batch_size x model_dim.
   MatPtrT<float> pre_att_rms_out;
+  // Attention scores computed from Q*K^T, size batch_size x (q_heads *
+  // seq_len).
   MatPtrT<float> att;
+  // Attention output computed from att * V, size batch_size x (q_heads *
+  // qkv_dim).
   MatPtrT<float> att_out;
+  // Accumulation of attention outputs over heads, size batch_size x
+  // model_dim.
   MatPtrT<BF16> att_sums;
+  // Inverse timescales for RoPE computation.
   MatPtrT<float> inv_timescale;
+  // Inverse timescales for global RoPE computation.
   MatPtrT<float> inv_timescale_global;
+  // Divisor for faster division by sequence length.
   hwy::Divisor div_seq_len;
+  // Divisor for faster division by number of heads.
   hwy::Divisor div_heads;
+  // Query scaling factor for attention computation.
   float query_scale;
 };
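Note (illustrative): the `inv_timescale` fields above hold one precomputed inverse timescale per rotated dimension pair for RoPE. A minimal sketch of how such a table is commonly built follows; the helper name `RopeInvTimescales` and the base of 10000 are assumptions for illustration, not code from gemma/activations.h.

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: inv_timescale[i] = base^(-2*i / qkv_dim) for
// i in [0, qkv_dim / 2), i.e. one inverse timescale per rotated pair.
std::vector<float> RopeInvTimescales(size_t qkv_dim, double base = 10000.0) {
  std::vector<float> inv(qkv_dim / 2);
  for (size_t i = 0; i < inv.size(); ++i) {
    const double exponent = -2.0 * static_cast<double>(i) / qkv_dim;
    inv[i] = static_cast<float>(std::pow(base, exponent));
  }
  return inv;
}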

gemma/flash_attention.cc

Lines changed: 31 additions & 3 deletions
@@ -428,9 +428,37 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   return scale;
 }
 
-// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
-// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
-// max_last_pos].
+// Implements flash attention for a strip of 4 query vectors.
+// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
+// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
+// by NF timesteps in K for efficiency while timesteps between `min_last_pos +
+// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos`
+// values within the strip.
+// (*) Actually, it only iterates through
+// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled
+// computation can, for obvious reasons, only process an integer number of
+// tiles.
+//
+// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format.
+// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query
+//   vectors to be processed in this tile.
+// @param k Key matrix [seq_len, qkv_dim] from KV cache.
+// @param start_pos The first token position in the KV cache to attend to.
+// @param last_pos An array of 4 indices giving the last token position
+//   (inclusive) that each of the 4 queries may attend to.
+// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this
+//   position can be processed efficiently in batches.
+// @param max_last_pos The maximum value in `last_pos`. Timesteps between
+//   `min_last_pos + 1` and this position are processed individually to
+//   respect each query's `last_pos` limit.
+// @param v Value matrix [seq_len, qkv_dim] from KV cache.
+// @param layer_idx The index of the current transformer layer.
+// @param activations Attention configurations and buffers.
+// @param att_out Output buffer for attention results.
+// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output
+//   vectors.
+// @param ctx Threading context.
+// @param worker Worker thread index.
 Tile4FlashState TileFlashAttention4(
     const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const MatPtrT<KV_t>& k, const size_t start_pos,
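Note (illustrative): the footnote about whole tiles can be checked with simple arithmetic. The sketch below uses made-up values for NF and the positions (they are not taken from the library) and shows that the last position covered by full 4xNF tiles equals min_last_pos - (min_last_pos + 1 - start_pos) % NF, with the remaining positions handled one-by-one.

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed example values, not library constants.
  const size_t NF = 8;  // K timesteps per tile
  const size_t start_pos = 3;
  const size_t min_last_pos = 40, max_last_pos = 45;

  // Whole tiles cover an exact multiple of NF timesteps (assumes at least
  // one full tile fits in the range).
  const size_t tiled_len = (min_last_pos + 1 - start_pos) / NF * NF;  // 32
  const size_t last_tiled_pos = start_pos + tiled_len - 1;            // 34
  // Equals min_last_pos - (min_last_pos + 1 - start_pos) % NF = 40 - 6 = 34.
  std::printf("tiled: [%zu, %zu], scalar tail: [%zu, %zu]\n", start_pos,
              last_tiled_pos, last_tiled_pos + 1, max_last_pos);
  return 0;
}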

gemma/flash_structs.h

Lines changed: 8 additions & 0 deletions
@@ -7,8 +7,16 @@
 
 namespace gcpp {
 
+// State for computing softmax in a streaming ("online") manner,
+// avoiding large intermediate values by subtracting the running maximum.
+// For a sequence x_1, ..., x_n:
+// m_i = max(m_{i-1}, x_i)
+// d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
+// softmax_i = exp(x_i - m_i) / d_i
 struct OnlineSoftmaxState {
+  // Maximum logit value encountered so far.
   float max = -std::numeric_limits<float>::max() / 2.0f;
+  // Sum of exponentials scaled by exp(-max).
   float d = 0.0f;
 };
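Note (illustrative): the recurrence documented above can be exercised in isolation. This standalone sketch mirrors the two fields of OnlineSoftmaxState and checks that one streaming pass reproduces the ordinary softmax; it is an illustration of the math, not the flash-attention code itself.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

struct State {  // same roles as OnlineSoftmaxState's max and d
  float max = -std::numeric_limits<float>::max() / 2.0f;
  float d = 0.0f;
};

int main() {
  const std::vector<float> logits = {1.0f, 3.0f, 2.0f, 0.5f};
  State s;
  for (float x : logits) {
    const float new_max = std::max(s.max, x);
    // d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
    s.d = s.d * std::exp(s.max - new_max) + std::exp(x - new_max);
    s.max = new_max;
  }
  // softmax(x_j) = exp(x_j - m_n) / d_n; the values sum to 1.
  for (float x : logits) std::printf("%f\n", std::exp(x - s.max) / s.d);
  return 0;
}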

ops/ops-inl.h

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+// Computes C = A * B + add via MatMulStatic.
+// This function uses CallUpcasted to dispatch to the correct MatMulStatic
+// instantiation based on the runtime type of B.
 template <typename TA, typename TC>
 MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
                      const float* HWY_RESTRICT add, MatMulEnv& env,
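Note (illustrative): "C = A * B + add" here means a matrix product with an optional bias (assumed to be added to every output row when non-null), and the dispatch picks a statically typed kernel from a type-erased view of B. The toy below shows that pattern with invented types (Mat, ElemType, MatMulStaticSketch); it is not the MatPtr/CallUpcasted/MatMulStatic API from gemma.cpp.

#include <cstddef>

// Toy type-erased matrix view; the element type is known only at runtime.
enum class ElemType { kF32, kF64 };
struct Mat {
  ElemType type;
  const void* data;  // row-major, rows x cols
  size_t rows, cols;
};

// Statically typed kernel: C = A * B + add, where add (if non-null) is a
// per-column bias added to every output row.
template <typename TB>
void MatMulStaticSketch(const float* A, const TB* B, const float* add,
                        float* C, size_t m, size_t k, size_t n) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      float sum = add ? add[j] : 0.0f;
      for (size_t p = 0; p < k; ++p) {
        sum += A[i * k + p] * static_cast<float>(B[p * n + j]);
      }
      C[i * n + j] = sum;
    }
  }
}

// Runtime dispatch on B's element type, analogous in spirit to the comment
// above about selecting a MatMulStatic instantiation via CallUpcasted.
void CallMatMulSketch(const float* A, const Mat& B, const float* add, float* C,
                      size_t m, size_t k, size_t n) {
  switch (B.type) {
    case ElemType::kF32:
      MatMulStaticSketch(A, static_cast<const float*>(B.data), add, C, m, k, n);
      break;
    case ElemType::kF64:
      MatMulStaticSketch(A, static_cast<const double*>(B.data), add, C, m, k, n);
      break;
  }
}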
