Commit baa69df

Balazs Racz authored and copybara-github committed

Passes the entire runtime_config into the activations constructor.
PiperOrigin-RevId: 845153671
1 parent 44dfd69 commit baa69df
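
In one place: the AttentionActivations constructor now receives the whole runtime configuration instead of only the attention implementation. Condensed from the gemma/activations.h diff below, shown here as bare declarations for brevity:

    // Before
    AttentionActivations(const ModelConfig& config, const LayerConfig& layer_config,
                         size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
                         const Allocator& allocator,
                         std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs);

    // After
    AttentionActivations(const ModelConfig& config, const LayerConfig& layer_config,
                         size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
                         const Allocator& allocator,
                         std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs);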

File tree (4 files changed, +28 -33 lines):

BUILD.bazel
gemma/activations.h
gemma/attention_test.cc
gemma/flash_attention_test.cc

BUILD.bazel

Lines changed: 0 additions & 2 deletions
@@ -141,7 +141,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",
@@ -643,7 +642,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul_env",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",

gemma/activations.h

Lines changed: 5 additions & 6 deletions
@@ -48,7 +48,7 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
       const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
@@ -129,7 +129,7 @@ struct AttentionActivations {
     // `inv_timescale*` are not batched.
   }
 
-  MatStorageT<float> q;  // query
+  MatStorageT<float> q;  // query
   MatStorageT<BF16> q_bf;
   MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.
 
@@ -138,8 +138,8 @@ struct AttentionActivations {
   MatStorageT<float> vit_C;
 
   MatStorageT<float> pre_att_rms_out;
-  MatStorageT<float> att;      // attention vector
-  MatStorageT<float> att_out;  // attention output
+  MatStorageT<float> att;      // attention vector
+  MatStorageT<float> att_out;  // attention output
   MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
   MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
   // Accumulation of attention outputs over heads
@@ -279,8 +279,7 @@ struct Activations {
         s_w_linear_w(config.num_layers, max_workers),
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          runtime_config.attention_impl, ctx.allocator,
-                          row_ptrs),
+                          runtime_config, ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
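For callers outside this diff, constructing the activations now looks roughly like the following. This is a minimal sketch mirroring the call sites in the tests below; `config`, `layer_config`, `batch_size`, `seq_len`, and `ctx` stand for objects the caller already has, and it assumes RuntimeConfig's attention_impl member can be set directly, as the designated initializer in gemma/attention_test.cc suggests:

    // Sketch only, not a snippet from the repository.
    RuntimeConfig runtime_config;
    runtime_config.attention_impl = AttentionImpl::kFlash;  // pick an implementation
    std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
    AttentionActivations attention_storage(config, layer_config, batch_size,
                                           seq_len, runtime_config, ctx.allocator,
                                           row_ptrs);
    AttentionActivationsPtrs attention(config, seq_len, attention_storage);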

gemma/attention_test.cc

Lines changed: 19 additions & 23 deletions
@@ -83,8 +83,8 @@ struct TestModelState {
                           state.mat_owners, 43);
     AllocateAndFillRandom(layer.gating_einsum_w, state.ctx.allocator,
                           state.mat_owners, 44);
-    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator,
-                          state.mat_owners, 45);
+    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator, state.mat_owners,
+                          45);
     layer.Fixup(state.mat_owners, state.ctx);
   }
 
@@ -101,9 +101,10 @@ struct TestAttentionState {
       : num_tokens(num_tokens),
         qbatch_size(qbatch_size),
         batch_size(qbatch_size * num_tokens),
+        runtime_config{.attention_impl = attention_impl},
         tokens(num_tokens),
         attention_storage_(model_state.config, model_state.layer_config,
-                           batch_size, num_tokens, attention_impl,
+                           batch_size, num_tokens, runtime_config,
                            state.ctx.allocator, row_ptrs_),
         attention(model_state.config, num_tokens, attention_storage_) {
     for (size_t i = 0; i < qbatch_size; ++i) {
@@ -276,8 +277,8 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
      -66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
     {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
       -46, -22, -19.375, -16.125, -148, 20.875},
-     {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5,
-      33, 10.9375, -52.5, 23.25, 75}},
+     {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
+      10.9375, -52.5, 23.25, 75}},
     {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
       -34.75, 18, -52, 100, -186, -75.5, 50.75},
      {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
@@ -366,10 +367,9 @@ const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = {
       -4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766,
       -7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187,
       2.60931611}},
-    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216,
-      -16.3434925, -4.75156116, -1.99114823, 3.99918842, -5.95400572,
-      10.8700314, 1.07596064, 0.30389142, 8.39548779, -5.11913681, 5.45641088,
-      -5.63240337},
+    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, -16.3434925,
+      -4.75156116, -1.99114823, 3.99918842, -5.95400572, 10.8700314, 1.07596064,
+      0.30389142, 8.39548779, -5.11913681, 5.45641088, -5.63240337},
      {-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355,
       -1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147,
       -5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052,
@@ -381,14 +381,12 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
     {{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423,
       -1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332,
       -4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026},
-     {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325,
-      3.52626801, -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271,
-      1.42195463, 0.301399827, -4.40214968, -2.12298298, 9.27825642,
-      -0.690600872}},
+     {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, 3.52626801,
+      -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, 1.42195463,
+      0.301399827, -4.40214968, -2.12298298, 9.27825642, -0.690600872}},
     {{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635,
-      -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249,
-      9.5985508, -10.6630878, -11.9006901, 0.851743698, 0.581826329,
-      5.21927929},
+      -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, 9.5985508,
+      -10.6630878, -11.9006901, 0.851743698, 0.581826329, 5.21927929},
      {-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228,
       3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248,
       -7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251,
@@ -411,13 +409,11 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
       -7.11484337, 2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822,
       16.0471916},
      {6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875,
-      4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208,
-      2.44881392, 1.99794042, -9.19855404, -4.02383137, -3.63013959,
-      -5.65853405}},
-    {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828,
-      11.8203125, 1.81672478, -1.42535269, -5.26496315, -5.31612349,
-      -4.19499826, 7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617,
-      3.5296216},
+      4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, 2.44881392,
+      1.99794042, -9.19855404, -4.02383137, -3.63013959, -5.65853405}},
+    {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, 11.8203125,
+      1.81672478, -1.42535269, -5.26496315, -5.31612349, -4.19499826,
+      7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, 3.5296216},
      {7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912,
       0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329,
       3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}},
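
One subtlety in the TestAttentionState change above: the new runtime_config member is built with a C++20 designated initializer and then passed by reference to attention_storage_ within the same member-initializer list. That is only well-defined because members are initialized in declaration order, so runtime_config must be declared before attention_storage_ in the struct. A standalone sketch of the pattern with hypothetical types (not taken from this repository):

    #include <cassert>

    enum class Impl { kOld, kFlash };
    struct Config { Impl impl = Impl::kOld; };  // aggregate, so {.impl = ...} works

    struct Storage {
      explicit Storage(const Config& cfg) : impl(cfg.impl) {}
      Impl impl;
    };

    struct Harness {
      explicit Harness(Impl impl)
          // `config` is declared before `storage`, so it is fully initialized
          // by the time Storage's constructor reads it.
          : config{.impl = impl}, storage(config) {}
      Config config;
      Storage storage;
    };

    int main() {
      Harness h(Impl::kFlash);
      assert(h.storage.impl == Impl::kFlash);
      return 0;
    }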

gemma/flash_attention_test.cc

Lines changed: 4 additions & 2 deletions
@@ -112,7 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   const LayerConfig& layer_config = config.layer_configs[0];
   const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
   InferenceArgs inference_args;
+  inference_args.attention_impl = "flash";
   RuntimeConfig runtime_config;
+  inference_args.CopyTo(runtime_config);
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
   Activations activations(runtime_config, config,
@@ -127,8 +129,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, AttentionImpl::kFlash,
-                                         ctx.allocator, row_ptrs);
+                                         kOuter, runtime_config, ctx.allocator,
+                                         row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
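
Instead of hard-coding AttentionImpl::kFlash at the constructor call, the test now routes the choice through the same args-to-config path: the string-valued attention_impl on InferenceArgs is copied into the RuntimeConfig (presumably translated to the AttentionImpl enum by CopyTo), and that single RuntimeConfig then feeds both Activations and AttentionActivations. Roughly, with the surrounding setup (`config`, `layer_config`, `ctx`, `row_ptrs`, etc.) omitted:

    InferenceArgs inference_args;
    inference_args.attention_impl = "flash";  // CLI-style, string-valued flag
    RuntimeConfig runtime_config;
    inference_args.CopyTo(runtime_config);    // fills runtime_config.attention_impl
    // The same runtime_config now selects the implementation everywhere it is used:
    AttentionActivations attention_storage(config, layer_config, batch_size,
                                           kOuter, runtime_config, ctx.allocator,
                                           row_ptrs);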
