Commit 2899fb8

unamedkr and claude committed
fix(gemma4): numeric comparison with MLX-LM — divergence after layer 0
Layer-by-layer comparison with MLX-LM (google/gemma-4-E2B-it BF16):

Embedding (BOS token 2):
  MLX:  -1.6406, -1.5312, 0.1885, -1.4844
  Ours: -1.6290, -1.5228, 0.1948, -1.4874
  Diff: < 0.012 (Q5_0 vs BF16 quantization noise) ✅

Attn norm output (layer 0):
  MLX:  -10.5625, -8.3125, 1.375, -12.1875
  Ours: -10.4733, -8.3217, 1.4276, -12.2401
  Diff: < 0.1 ✅

Q projection (layer 0):
  MLX:  -4.375, 21.25, -0.797, 5.125
  Ours: -4.306, 21.226, -0.711, 5.157
  Diff: < 0.1 ✅

K projection (layer 0):
  MLX:  2.547, 3.141, -0.029, 1.133
  Ours: 2.298, 3.182, 0.165, 1.169
  Diff: < 0.25 (slightly larger, but within Q8_0 tolerance)

FINAL LOGITS (last position):
  MLX logits[100] (<|channel>): 22.88 (TOP-1)
  Ours logits[100]: -16.90  ← WRONG
  MLX logits[0:3]:  -22.38, 7.09, -3.48
  Ours logits[0:3]: -23.73, -2.68, 5.50

CONCLUSION: Embedding → attn_norm → Q/K projection are all correct. The divergence happens INSIDE or AFTER the attention computation in layer 0, then compounds through 35 layers to produce completely wrong final logits (~40 logit difference on critical tokens).

Next: compare attention output and FFN output at layer 0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6ea3215 commit 2899fb8

File tree

1 file changed (+12 -0 lines)

quant.h

Lines changed: 12 additions & 0 deletions
@@ -14239,6 +14239,10 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
     } else {
         tq_matmul(s->q, s->xb, layer->wq, n_heads * head_dim, dim);
     }
+    if (pos == 0 && l == 0 && getenv("TQ_DEBUG")) {
+        fprintf(stderr, "[DEBUG] layer0 Q[0:4] = %.4f %.4f %.4f %.4f K[0:4] = ",
+                s->q[0], s->q[1], s->q[2], s->q[3]);
+    }
 }
 if (kv_shared_skip && kv_shared_ref_layer >= 0) {
     /* KV sharing: skip K/V projection for shared layers.
@@ -14285,6 +14289,9 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
     } else {
         tq_matmul(s->k, s->xb, layer->wk, kv_dim, dim);
     }
+    if (pos == 0 && l == 0 && getenv("TQ_DEBUG")) {
+        fprintf(stderr, "%.4f %.4f %.4f %.4f\n", s->k[0], s->k[1], s->k[2], s->k[3]);
+    }
 if (has_fused_qkv_layer) {
     /* skip — handled by the fused branch */
 } else {
@@ -15447,6 +15454,11 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
     /* Pre-attention/DeltaNet RMSNorm */
     tq_rmsnorm(s->xb, s->x, layer->attn_norm, dim, c->rms_norm_eps);

+    if (pos == 0 && l == 0 && getenv("TQ_DEBUG")) {
+        fprintf(stderr, "[DEBUG] layer0 attn_norm_out[0:4] = %.4f %.4f %.4f %.4f\n",
+                s->xb[0], s->xb[1], s->xb[2], s->xb[3]);
+    }
+
 if (layer->delta_a_log) {
     /* DeltaNet layer */
     deltanet_forward(model, s, l);
