@@ -428,9 +428,37 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
428428 return scale;
429429}
430430
431- // Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
432- // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
433- // max_last_pos].
431+ // Implements flash attention for a strip of 4 query vectors.
432+ // It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
433+ // Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
434+ // by NF timesteps in K for efficiency, while timesteps from
435+ // `min_last_pos + 1` to `max_last_pos` are processed one-by-one to handle
436+ // differing `last_pos` values within the strip.
437+ // (*) More precisely, the tiled pass only covers timesteps up to
438+ // `min_last_pos - (min_last_pos + 1 - start_pos) % NF`, since the range
439+ // [start_pos, min_last_pos] is not always a multiple of NF timesteps long
440+ // and only whole tiles can be processed.
441+ //
442+ // @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format.
443+ // @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query
444+ // vectors to be processed in this tile.
445+ // @param k Key matrix [seq_len, qkv_dim] from KV cache.
446+ // @param start_pos The first token position in the KV cache to attend to.
447+ // @param last_pos An array of 4 indices giving the last token position
448+ // (inclusive) that each of the 4 queries may attend to.
449+ // @param min_last_pos The minimum value in `last_pos`. Timesteps up to this
450+ // position can be processed efficiently in batches.
451+ // @param max_last_pos The maximum value in `last_pos`. Timesteps between
452+ // `min_last_pos + 1` and this position are processed individually to
453+ // respect each query's `last_pos` limit.
454+ // @param v Value matrix [seq_len, qkv_dim] from KV cache.
455+ // @param layer_idx The index of the current transformer layer.
456+ // @param activations Attention configurations and buffers.
457+ // @param att_out Output buffer for attention results.
458+ // @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output
459+ // vectors.
460+ // @param ctx Threading context.
461+ // @param worker Worker thread index.
434462Tile4FlashState TileFlashAttention4 (
435463 const MatPtrT<BF16>& q, const uint32_t * HWY_RESTRICT q_offsets,
436464 const MatPtrT<KV_t>& k, const size_t start_pos,
0 commit comments