@@ -83,8 +83,8 @@ struct TestModelState {
8383 state.mat_owners , 43 );
8484 AllocateAndFillRandom (layer.gating_einsum_w , state.ctx .allocator ,
8585 state.mat_owners , 44 );
86- AllocateAndFillRandom (layer.linear_w , state.ctx .allocator ,
87- state. mat_owners , 45 );
86+ AllocateAndFillRandom (layer.linear_w , state.ctx .allocator , state. mat_owners ,
87+ 45 );
8888 layer.Fixup (state.mat_owners , state.ctx );
8989 }
9090
@@ -101,9 +101,10 @@ struct TestAttentionState {
101101 : num_tokens(num_tokens),
102102 qbatch_size (qbatch_size),
103103 batch_size(qbatch_size * num_tokens),
104+ runtime_config{.attention_impl = attention_impl},
104105 tokens (num_tokens),
105106 attention_storage_(model_state.config, model_state.layer_config,
106- batch_size, num_tokens, attention_impl ,
107+ batch_size, num_tokens, runtime_config ,
107108 state.ctx.allocator, row_ptrs_),
108109 attention(model_state.config, num_tokens, attention_storage_) {
109110 for (size_t i = 0 ; i < qbatch_size; ++i) {
@@ -276,8 +277,8 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
276277 -66.5 , -0.84765625 , -46.5 , -152 , -2.9375 , -81 }},
277278 {{3.984375 , 83 , -41.75 , 39.5 , -203 , 110 , -76 , 131 , 0.4609375 , -44.5 , -63.75 ,
278279 -46 , -22 , -19.375 , -16.125 , -148 , 20.875 },
279- {-47 , -19.5 , 58 , 81.5 , 21.75 , -30 , -118 , 44.25 , -149 , 22.5 , 188 , -66.5 ,
280- 33 , 10.9375 , -52.5 , 23.25 , 75 }},
280+ {-47 , -19.5 , 58 , 81.5 , 21.75 , -30 , -118 , 44.25 , -149 , 22.5 , 188 , -66.5 , 33 ,
281+ 10.9375 , -52.5 , 23.25 , 75 }},
281282 {{64 , -31 , -89 , -92.5 , -11.1875 , -54.75 , -302 , 3.453125 , -108 , 39.25 ,
282283 -34.75 , 18 , -52 , 100 , -186 , -75.5 , 50.75 },
283284 {7.6875 , -80 , -40 , 32.25 , -30.25 , 90 , -41 , 44.25 , -140 , -2.4375 , 82.5 ,
@@ -366,10 +367,9 @@ const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = {
366367 -4.42512083 , 1.78077614 , -3.25167561 , 0.864362717 , 0.474019766 ,
367368 -7.92327404 , -2.27795148 , -0.436354101 , -3.15722394 , 0.415780187 ,
368369 2.60931611 }},
369- {{-9.43858051 , 0.391518891 , -2.74012518 , 4.9842453 , 7.48263216 ,
370- -16.3434925 , -4.75156116 , -1.99114823 , 3.99918842 , -5.95400572 ,
371- 10.8700314 , 1.07596064 , 0.30389142 , 8.39548779 , -5.11913681 , 5.45641088 ,
372- -5.63240337 },
370+ {{-9.43858051 , 0.391518891 , -2.74012518 , 4.9842453 , 7.48263216 , -16.3434925 ,
371+ -4.75156116 , -1.99114823 , 3.99918842 , -5.95400572 , 10.8700314 , 1.07596064 ,
372+ 0.30389142 , 8.39548779 , -5.11913681 , 5.45641088 , -5.63240337 },
373373 {-1.22347319 , 9.57339382 , -1.31736016 , -5.02770805 , -4.81617355 ,
374374 -1.96618557 , -0.456317186 , 12.6451035 , -1.50221801 , 6.7991147 ,
375375 -5.97842169 , 1.85410941 , -8.44729 , 0.378282309 , 0.0442156792 , 17.6773052 ,
@@ -381,14 +381,12 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
381381 {{2.77553034 , -7.67514181 , -1.60433948 , 4.67795134 , -1.75084186 , 8.57896423 ,
382382 -1.15065813 , -3.75088787 , -4.7442131 , -1.68890858 , -10.0202332 ,
383383 -4.20167446 , 9.36844635 , 13.7364845 , 11.5634 , 2.95288706 , 2.89380026 },
384- {-4.79950905 , -1.66658688 , 4.14471292 , -4.95649052 , -5.4200325 ,
385- 3.52626801 , -10.9432049 , 0.338347554 , -1.53204226 , 0.473476171 , -0.58271 ,
386- 1.42195463 , 0.301399827 , -4.40214968 , -2.12298298 , 9.27825642 ,
387- -0.690600872 }},
384+ {-4.79950905 , -1.66658688 , 4.14471292 , -4.95649052 , -5.4200325 , 3.52626801 ,
385+ -10.9432049 , 0.338347554 , -1.53204226 , 0.473476171 , -0.58271 , 1.42195463 ,
386+ 0.301399827 , -4.40214968 , -2.12298298 , 9.27825642 , -0.690600872 }},
388387 {{-10.6566734 , 4.12785721 , 4.54053593 , -1.39667869 , -1.55028772 , 0.20508635 ,
389- -0.00620913506 , 2.93214 , -0.788117647 , 2.78032446 , -2.68898249 ,
390- 9.5985508 , -10.6630878 , -11.9006901 , 0.851743698 , 0.581826329 ,
391- 5.21927929 },
388+ -0.00620913506 , 2.93214 , -0.788117647 , 2.78032446 , -2.68898249 , 9.5985508 ,
389+ -10.6630878 , -11.9006901 , 0.851743698 , 0.581826329 , 5.21927929 },
392390 {-0.322291255 , 2.63848567 , -2.30808377 , -13.0153809 , 2.74378228 ,
393391 3.21460533 , 0.688529968 , 2.37544608 , 6.06825066 , 4.57566404 , 1.17124248 ,
394392 -7.96587658 , -2.65279341 , 4.75271225 , -4.09937954 , -10.3570251 ,
@@ -411,13 +409,11 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
411409 -7.11484337 , 2.53943753 , -0.652261257 , 9.77392 , 3.53345847 , -9.62052822 ,
412410 16.0471916 },
413411 {6.89768124 , 2.36394405 , -2.08569574 , -0.682706833 , 3.38872 , -6.28313875 ,
414- 4.79594612 , 4.93417454 , -6.40791416 , -10.7355442 , -5.66094208 ,
415- 2.44881392 , 1.99794042 , -9.19855404 , -4.02383137 , -3.63013959 ,
416- -5.65853405 }},
417- {{1.64614546 , -3.93421197 , -0.48935914 , 5.48284435 , -7.69781828 ,
418- 11.8203125 , 1.81672478 , -1.42535269 , -5.26496315 , -5.31612349 ,
419- -4.19499826 , 7.06049395 , 0.18029356 , -0.0519902706 , 10.317358 , 2.19345617 ,
420- 3.5296216 },
412+ 4.79594612 , 4.93417454 , -6.40791416 , -10.7355442 , -5.66094208 , 2.44881392 ,
413+ 1.99794042 , -9.19855404 , -4.02383137 , -3.63013959 , -5.65853405 }},
414+ {{1.64614546 , -3.93421197 , -0.48935914 , 5.48284435 , -7.69781828 , 11.8203125 ,
415+ 1.81672478 , -1.42535269 , -5.26496315 , -5.31612349 , -4.19499826 ,
416+ 7.06049395 , 0.18029356 , -0.0519902706 , 10.317358 , 2.19345617 , 3.5296216 },
421417 {7.52353811 , 3.56836724 , 0.414305687 , 0.340799928 , 2.44263697 , 7.52111912 ,
422418 0.246491909 , -11.1172791 , -3.82061529 , 3.24794388 , 0.751524329 ,
423419 3.14019632 , 6.33881855 , -0.169233799 , 7.82640171 , 1.5389179 , 8.15851307 }},
0 commit comments