Skip to content

Commit 812ba68

Browse files
committed
feat(openchat): Mistral-7B inference engine (GQA + RoPE + RMSNorm + SiLU)
OpenChat 3.5 / Mistral-7B architecture, fully distinct from GPT-2: - GQA: 32 query heads share 8 KV heads (4:1 ratio, 75% KV cache savings) - RoPE: rotary positional embedding (no learned positions) - RMSNorm: simpler norm without mean subtraction (both in models::layers) - SiLU: gated MLP (gate * up → down) with F32x16 element-wise SIMD - GGUF weight loading via hpc::gguf (Q4_K_M + Q4_0 dequantization added) - CausalEdge64 emission from attention patterns - OpenChat chat template (GPT4 Correct User/Assistant markers) - /v1/chat/completions API types All ops through crate::simd::F32x16 via models::layers. No weights stored — loaded at runtime from user-provided GGUF. 15 tests passing. 77 total across new modules. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent 93e86a4 commit 812ba68

7 files changed

Lines changed: 1098 additions & 0 deletions

File tree

src/hpc/gguf.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,12 @@ pub fn read_tensor_f32<R: Read + Seek>(
225225
GgmlType::Q8_0 => {
226226
dequantize_q8_0(reader, n_elements)
227227
}
228+
GgmlType::Q4_0 => {
229+
dequantize_q4_0(reader, n_elements)
230+
}
231+
GgmlType::Q4_K => {
232+
dequantize_q4_k(reader, n_elements)
233+
}
228234
other => Err(format!("Unsupported dtype for dequantization: {:?}", other)),
229235
}
230236
}
@@ -317,6 +323,90 @@ fn dequantize_q8_0<R: Read>(r: &mut R, n_elements: usize) -> Result<Vec<f32>, St
317323
Ok(result)
318324
}
319325

326+
/// Dequantize Q4_0: each block = 2 bytes scale (f16) + 16 bytes (32 nibbles).
327+
fn dequantize_q4_0<R: Read>(r: &mut R, n_elements: usize) -> Result<Vec<f32>, String> {
328+
let block_size = 32;
329+
let n_blocks = (n_elements + block_size - 1) / block_size;
330+
let mut result = Vec::with_capacity(n_elements);
331+
332+
for _ in 0..n_blocks {
333+
let mut scale_buf = [0u8; 2];
334+
r.read_exact(&mut scale_buf).map_err(|e| e.to_string())?;
335+
let scale = f16_to_f32(u16::from_le_bytes(scale_buf));
336+
337+
let mut nibbles = [0u8; 16];
338+
r.read_exact(&mut nibbles).map_err(|e| e.to_string())?;
339+
340+
for &byte in &nibbles {
341+
let lo = (byte & 0x0F) as i8 - 8;
342+
let hi = ((byte >> 4) & 0x0F) as i8 - 8;
343+
result.push(lo as f32 * scale);
344+
result.push(hi as f32 * scale);
345+
}
346+
}
347+
348+
result.truncate(n_elements);
349+
Ok(result)
350+
}
351+
352+
/// Dequantize Q4_K: super-blocks of 256 elements.
353+
///
354+
/// Q4_K block layout (144 bytes for 256 elements):
355+
/// - 2 bytes: d (f16 scale)
356+
/// - 2 bytes: dmin (f16 min)
357+
/// - 12 bytes: scales (6-bit per sub-block, packed)
358+
/// - 128 bytes: 256 4-bit quantized values (nibbles)
359+
fn dequantize_q4_k<R: Read>(r: &mut R, n_elements: usize) -> Result<Vec<f32>, String> {
360+
let block_size = 256;
361+
let n_blocks = (n_elements + block_size - 1) / block_size;
362+
let mut result = Vec::with_capacity(n_elements);
363+
364+
for _ in 0..n_blocks {
365+
// Read d and dmin (f16)
366+
let mut d_buf = [0u8; 2];
367+
let mut dmin_buf = [0u8; 2];
368+
r.read_exact(&mut d_buf).map_err(|e| e.to_string())?;
369+
r.read_exact(&mut dmin_buf).map_err(|e| e.to_string())?;
370+
let d = f16_to_f32(u16::from_le_bytes(d_buf));
371+
let dmin = f16_to_f32(u16::from_le_bytes(dmin_buf));
372+
373+
// Read scales (12 bytes = 8 sub-block scales + 8 sub-block mins, 6-bit packed)
374+
let mut scales_raw = [0u8; 12];
375+
r.read_exact(&mut scales_raw).map_err(|e| e.to_string())?;
376+
377+
// Decode 8 scale/min pairs from 12 bytes (6 bits each)
378+
let mut sc = [0u8; 8];
379+
let mut mn = [0u8; 8];
380+
for i in 0..4 {
381+
sc[i] = scales_raw[i] & 0x3F;
382+
mn[i] = scales_raw[i + 4] & 0x3F;
383+
sc[i + 4] = ((scales_raw[i + 8] & 0x0F) << 2) | (scales_raw[i] >> 6);
384+
mn[i + 4] = ((scales_raw[i + 8] >> 4) << 2) | (scales_raw[i + 4] >> 6);
385+
}
386+
387+
// Read 128 bytes of nibbles (256 4-bit values)
388+
let mut nibbles = [0u8; 128];
389+
r.read_exact(&mut nibbles).map_err(|e| e.to_string())?;
390+
391+
// Dequantize: each sub-block of 32 elements
392+
for j in 0..8 {
393+
let sub_d = d * sc[j] as f32;
394+
let sub_m = dmin * mn[j] as f32;
395+
let nib_offset = j * 16;
396+
for k in 0..16 {
397+
let byte = nibbles[nib_offset + k];
398+
let lo = (byte & 0x0F) as f32;
399+
let hi = ((byte >> 4) & 0x0F) as f32;
400+
result.push(lo * sub_d - sub_m);
401+
result.push(hi * sub_d - sub_m);
402+
}
403+
}
404+
}
405+
406+
result.truncate(n_elements);
407+
Ok(result)
408+
}
409+
320410
/// Convert f16 bit pattern to f32.
321411
fn f16_to_f32(bits: u16) -> f32 {
322412
let sign = ((bits >> 15) & 1) as u32;

src/hpc/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ pub mod gpt2;
182182
#[allow(missing_docs)]
183183
pub mod stable_diffusion;
184184

185+
/// OpenChat 3.5 inference — Mistral-7B architecture (GQA + RoPE + RMSNorm + SiLU).
186+
#[allow(missing_docs)]
187+
pub mod openchat;
188+
185189
// jitson: JSON config → scan pipeline (parser, validator, template, precompile, packed)
186190
// Always available — no Cranelift dependency.
187191
#[allow(missing_docs)]

src/hpc/models/layers.rs

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,72 @@ pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
213213
sum
214214
}
215215

216+
/// RMS normalization (Mistral/Llama style): `x = x * weight / sqrt(mean(x²) + eps)`
217+
///
218+
/// No mean subtraction, no bias. Simpler and faster than LayerNorm.
219+
/// Used by OpenChat 3.5, Mistral, Llama 2/3.
220+
pub fn rms_norm(x: &mut [f32], weight: &[f32], eps: f32) {
221+
let n = x.len();
222+
let chunks = n / 16;
223+
224+
// Mean of squares (SIMD)
225+
let mut sq_acc = F32x16::splat(0.0);
226+
for c in 0..chunks {
227+
let off = c * 16;
228+
let v = F32x16::from_slice(&x[off..off + 16]);
229+
sq_acc = v.mul_add(v, sq_acc);
230+
}
231+
let mut mean_sq = sq_acc.reduce_sum();
232+
for i in (chunks * 16)..n {
233+
mean_sq += x[i] * x[i];
234+
}
235+
mean_sq /= n as f32;
236+
237+
let inv_rms = 1.0 / (mean_sq + eps).sqrt();
238+
let inv_rms_vec = F32x16::splat(inv_rms);
239+
240+
// Normalize × weight (SIMD)
241+
for c in 0..chunks {
242+
let off = c * 16;
243+
let v = F32x16::from_slice(&x[off..off + 16]);
244+
let w = F32x16::from_slice(&weight[off..off + 16]);
245+
let result = v * inv_rms_vec * w;
246+
result.copy_to_slice(&mut x[off..off + 16]);
247+
}
248+
for i in (chunks * 16)..n {
249+
x[i] = x[i] * inv_rms * weight[i];
250+
}
251+
}
252+
253+
/// Rotate the leading `head_dim` entries of `q` and `k` in place with
/// Rotary Positional Embedding (RoPE).
///
/// Adjacent dimensions `(2i, 2i+1)` are treated as a 2-D point and rotated by
/// the angle `position * rope_theta^(-2i / head_dim)`. Both `q` and `k`
/// receive the same rotation. If `head_dim` is odd the last dimension is left
/// untouched, as are any entries beyond `head_dim`.
///
/// Panics if `q` or `k` is shorter than the rotated range.
pub fn rope_apply(q: &mut [f32], k: &mut [f32], head_dim: usize, position: usize, rope_theta: f32) {
    // Rotate the pair starting at `even` by the angle whose cosine/sine are given.
    fn rotate_pair(v: &mut [f32], even: usize, cos_a: f32, sin_a: f32) {
        let a = v[even];
        let b = v[even + 1];
        v[even] = a * cos_a - b * sin_a;
        v[even + 1] = a * sin_a + b * cos_a;
    }

    let pair_count = head_dim / 2;
    for pair in 0..pair_count {
        // Per-pair frequency, then the position-dependent rotation angle.
        let freq = rope_theta.powf(-(2.0 * pair as f32) / head_dim as f32);
        let angle = position as f32 * freq;
        let cos_a = angle.cos();
        let sin_a = angle.sin();

        let even = 2 * pair;
        rotate_pair(q, even, cos_a, sin_a);
        rotate_pair(k, even, cos_a, sin_a);
    }
}
281+
216282
#[cfg(test)]
217283
mod tests {
218284
use super::*;
@@ -305,4 +371,62 @@ mod tests {
305371
assert!((output[0] - 3.0).abs() < 1e-5);
306372
assert!((output[1] - 7.0).abs() < 1e-5);
307373
}
374+
375+
#[test]
fn test_rms_norm_unit_weight() {
    // With unit weights the output is simply x / rms, where
    // rms = sqrt((3² + 4²) / 2) = sqrt(12.5) ≈ 3.536.
    let input = [3.0f32, 4.0];
    let mut x = input.to_vec();
    let w = vec![1.0; 2];
    rms_norm(&mut x, &w, 1e-5);

    let rms = 12.5f32.sqrt();
    for (got, orig) in x.iter().zip(input.iter()) {
        assert!((got - orig / rms).abs() < 0.01);
    }
}
384+
385+
#[test]
fn test_rms_norm_scaling() {
    // 35 elements = two full 16-lane SIMD blocks plus a 3-element scalar tail,
    // so both code paths of `rms_norm` are exercised.
    let mut x = vec![1.0; 35];
    let w = vec![2.0; 35];
    rms_norm(&mut x, &w, 1e-5);
    // rms = 1.0, so every element should be 1.0 * 2.0 = 2.0
    for (i, v) in x.iter().enumerate() {
        assert!((v - 2.0).abs() < 0.01, "element {} was {}", i, v);
    }
}
393+
394+
#[test]
fn test_rope_position_zero_identity() {
    // Position 0 gives angle 0 for every pair (cos = 1, sin = 0),
    // so the rotation must leave both vectors unchanged.
    let orig_q = vec![1.0, 2.0, 3.0, 4.0];
    let orig_k = vec![5.0, 6.0, 7.0, 8.0];
    let mut q = orig_q.clone();
    let mut k = orig_k.clone();

    rope_apply(&mut q, &mut k, 4, 0, 10000.0);

    for (after, before) in q.iter().zip(&orig_q) {
        assert!((after - before).abs() < 1e-5);
    }
    for (after, before) in k.iter().zip(&orig_k) {
        assert!((after - before).abs() < 1e-5);
    }
}
407+
408+
#[test]
fn test_rope_changes_with_position() {
    // The same input rotated at two different positions must not coincide.
    let base = vec![1.0, 0.0, 1.0, 0.0];

    let (mut q_near, mut k_near) = (base.clone(), base.clone());
    rope_apply(&mut q_near, &mut k_near, 4, 1, 10000.0);

    let (mut q_far, mut k_far) = (base.clone(), base.clone());
    rope_apply(&mut q_far, &mut k_far, 4, 100, 10000.0);

    // L1 distance between the two rotated query vectors.
    let mut diff: f32 = 0.0;
    for (a, b) in q_near.iter().zip(&q_far) {
        diff += (a - b).abs();
    }
    assert!(diff > 0.01, "different positions should produce different embeddings");
}
420+
421+
#[test]
fn test_rope_preserves_norm() {
    // RoPE applies a rotation to each pair, so the L2 norm must be unchanged.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    let mut q = vec![3.0, 4.0, 1.0, 2.0];
    let mut k = vec![0.0; 4];
    let norm_before = l2_norm(&q);
    rope_apply(&mut q, &mut k, 4, 42, 10000.0);
    let norm_after = l2_norm(&q);

    assert!(
        (norm_before - norm_after).abs() < 0.01,
        "RoPE should preserve norm: {} vs {}",
        norm_before,
        norm_after
    );
}
308432
}

0 commit comments

Comments
 (0)