Skip to content

Commit 376aacb

Browse files
committed
feat(quantized): Q4_0 GGUF-compat quant helpers (sprint A5)
Add quantize_f32_to_q4_0 / dequantize_q4_0_to_f32 implementing the GGUF / llama.cpp per-32-element block scheme: 16 packed bytes plus one f32 scale d = max_signed/-8 per block, with the canonical interleaved nibble layout (element j -> low nibble of byte j; element j+16 -> high nibble of byte j). The existing per-tensor quantize_f32_to_i4 (low-nibble-first, non-interleaved, scale = abs_max/7) is preserved unchanged for backwards compatibility. Burn QuantValue::Q4F / Q4S callers can opt into either scheme. Tests: i4 boundary +/-7 and clamp +/-8; Q4_0 single-block, multi-block, zero-block, interleaved layout, non-aligned panic. https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
1 parent 44c0845 commit 376aacb

1 file changed

Lines changed: 211 additions & 0 deletions

File tree

src/hpc/quantized.rs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,112 @@ pub fn dequantize_i2_to_f32(packed: &[u8], params: &QuantParams, n: usize) -> Ve
463463
out
464464
}
465465

466+
// ── Q4_0 (GGUF-compatible block quantization) ─────────────────────
467+
468+
/// Q4_0 block size (number of f32 elements per block).
pub const Q4_0_BLOCK_SIZE: usize = 32;

/// Number of bytes used to pack one Q4_0 block (32 nibbles = 16 bytes).
pub const Q4_0_BYTES_PER_BLOCK: usize = Q4_0_BLOCK_SIZE / 2;

/// Quantize f32 to Q4_0 (GGUF-compatible) per-block 4-bit quantization.
///
/// Each block of [`Q4_0_BLOCK_SIZE`] (32) f32 elements is encoded as
/// [`Q4_0_BYTES_PER_BLOCK`] (16) packed bytes plus one f32 scale `d`.
///
/// Encoding (same arithmetic as the `llama.cpp` reference quantizer):
/// - `d = x_max / -8.0`, where `x_max` is the element with the largest
///   magnitude in the block, sign preserved (first one wins on ties).
///   An all-zero block gets `d = 0.0`.
/// - `id = 1.0 / d` (or `0.0` when `d == 0.0`).
/// - `q_i = clamp(floor(x_i * id + 8.5), 0, 15)` (unsigned nibble). This is
///   round-half-up, so a negative tie such as `x_i * id == -0.5` encodes
///   as `8`, exactly like the reference's `(int8_t)(x + 8.5f)`.
/// - Within a block, element `j` (`0 <= j < 16`) and element `j + 16`
///   share byte `j`: low nibble holds element `j`, high nibble holds
///   element `j + 16`. (This is the GGUF interleaved layout, NOT the
///   simple "two consecutive elements per byte" layout used by
///   [`quantize_f32_to_i4`].)
///
/// Returns `(packed_bytes, scales)` where `scales.len() == data.len() / 32`
/// and `packed_bytes.len() == scales.len() * 16`. Empty input yields two
/// empty vectors.
///
/// Note: the scale is returned as `f32`. ggml's on-disk `block_q4_0` stores
/// it as f16, so callers serializing to GGUF must convert it themselves.
///
/// # Panics
///
/// Panics if `data.len()` is not a multiple of [`Q4_0_BLOCK_SIZE`].
pub fn quantize_f32_to_q4_0(data: &[f32]) -> (Vec<u8>, Vec<f32>) {
    assert!(
        data.len() % Q4_0_BLOCK_SIZE == 0,
        "Q4_0 requires data.len() to be a multiple of {}",
        Q4_0_BLOCK_SIZE
    );

    let n_blocks = data.len() / Q4_0_BLOCK_SIZE;
    let mut packed = Vec::with_capacity(n_blocks * Q4_0_BYTES_PER_BLOCK);
    let mut scales = Vec::with_capacity(n_blocks);

    for block in data.chunks_exact(Q4_0_BLOCK_SIZE) {
        // Find the element with the largest magnitude, keeping its sign.
        // The strict `>` means the first such element wins on ties and NaN
        // elements are never selected.
        let mut amax = 0.0f32;
        let mut max_signed = 0.0f32;
        for &x in block {
            let ax = x.abs();
            if ax > amax {
                amax = ax;
                max_signed = x;
            }
        }

        // d = max_signed / -8; an all-zero block gets d = 0 so every
        // nibble becomes 8 (which dequantizes back to exactly 0).
        let d = if amax > 0.0 { max_signed / -8.0 } else { 0.0 };
        let id = if d != 0.0 { 1.0 / d } else { 0.0 };
        scales.push(d);

        // Interleaved layout: byte j = element j (low) | element j + 16 (high).
        let (first_half, second_half) = block.split_at(Q4_0_BYTES_PER_BLOCK);
        packed.extend(
            first_half
                .iter()
                .zip(second_half)
                .map(|(&x_lo, &x_hi)| q4_0_encode_nibble(x_lo * id) | (q4_0_encode_nibble(x_hi * id) << 4)),
        );
    }

    (packed, scales)
}

/// Map a pre-scaled value (`x * id`) to its unsigned Q4_0 nibble in `0..=15`.
///
/// Uses `floor(v + 8.5)` rather than `round(v) + 8`: `f32::round` rounds
/// half away from zero, which would send negative ties one step lower than
/// the reference quantizer does.
#[inline]
fn q4_0_encode_nibble(scaled: f32) -> u8 {
    // The clamp bounds out-of-range values; a NaN survives the clamp and is
    // turned into 0 by the saturating `as u8` cast.
    (scaled + 8.5).floor().clamp(0.0, 15.0) as u8
}
536+
537+
/// Dequantize Q4_0 (GGUF-compatible) packed bytes back to f32.
538+
///
539+
/// Inverse of [`quantize_f32_to_q4_0`]. `packed.len()` must equal
540+
/// `scales.len() * 16` and the result has length `scales.len() * 32`.
541+
///
542+
/// # Panics
543+
///
544+
/// Panics if `packed.len() != scales.len() * Q4_0_BYTES_PER_BLOCK`.
545+
pub fn dequantize_q4_0_to_f32(packed: &[u8], scales: &[f32]) -> Vec<f32> {
546+
assert_eq!(
547+
packed.len(),
548+
scales.len() * Q4_0_BYTES_PER_BLOCK,
549+
"Q4_0 packed length must equal scales.len() * {}",
550+
Q4_0_BYTES_PER_BLOCK
551+
);
552+
553+
let n_blocks = scales.len();
554+
let mut out = vec![0.0f32; n_blocks * Q4_0_BLOCK_SIZE];
555+
556+
for b in 0..n_blocks {
557+
let d = scales[b];
558+
let byte_off = b * Q4_0_BYTES_PER_BLOCK;
559+
let elem_off = b * Q4_0_BLOCK_SIZE;
560+
for j in 0..Q4_0_BYTES_PER_BLOCK {
561+
let byte = packed[byte_off + j];
562+
let lo = (byte & 0x0F) as i32 - 8;
563+
let hi = ((byte >> 4) & 0x0F) as i32 - 8;
564+
out[elem_off + j] = lo as f32 * d;
565+
out[elem_off + j + Q4_0_BYTES_PER_BLOCK] = hi as f32 * d;
566+
}
567+
}
568+
569+
out
570+
}
571+
466572
#[cfg(test)]
467573
mod tests {
468574
use super::*;
@@ -567,4 +673,109 @@ mod tests {
567673
// = 0b01_00_11_01 = 0x4D
568674
assert_eq!(packed[0], 0b01_00_11_01);
569675
}
676+
677+
#[test]
fn test_i4_boundary_values() {
    // Case 1: abs_max = 7 is expected to give scale = 1.0, so every input
    // sits exactly on the integer grid and must round-trip bit-for-bit.
    let data = vec![-7.0f32, -3.0, 0.0, 3.0, 7.0];
    let (packed, params) = quantize_f32_to_i4(&data);
    let scale_err = (params.scale - 1.0).abs();
    assert!(scale_err < 1e-6);
    let recovered = dequantize_i4_to_f32(&packed, &params, data.len());
    assert_eq!(recovered, vec![-7.0, -3.0, 0.0, 3.0, 7.0]);

    // Case 2: abs_max = 8 is expected to give scale = 8/7. Then
    // -8 / (8/7) = -7 and 8 / (8/7) = 7, so both extremes land on the
    // outermost symmetric grid cells and dequantize back to +/-8 (up to
    // float error), while 0 stays exactly 0.
    let wide = vec![-8.0f32, 0.0, 8.0];
    let (wide_packed, wide_params) = quantize_f32_to_i4(&wide);
    assert!((wide_params.scale - 8.0 / 7.0).abs() < 1e-6);
    let wide_recovered = dequantize_i4_to_f32(&wide_packed, &wide_params, wide.len());
    assert!((wide_recovered[0] + 8.0).abs() < 1e-4);
    assert_eq!(wide_recovered[1], 0.0);
    assert!((wide_recovered[2] - 8.0).abs() < 1e-4);
}
701+
702+
#[test]
fn test_q4_0_roundtrip_single_block() {
    // One full block holding the ramp -16, -15, ..., 15.
    let data: Vec<f32> = (0..Q4_0_BLOCK_SIZE).map(|i| (i as f32) - 16.0).collect();

    let (packed, scales) = quantize_f32_to_q4_0(&data);
    assert_eq!(packed.len(), Q4_0_BYTES_PER_BLOCK);
    assert_eq!(scales.len(), 1);

    let recovered = dequantize_q4_0_to_f32(&packed, &scales);
    assert_eq!(recovered.len(), data.len());

    // The largest magnitude in the block is 16, so |d| = 16 / 8 = 2.0.
    // Allow a reconstruction error of one quantization step plus float slack.
    let tol = 16.0f32 / 8.0 + 1e-4;
    for i in 0..data.len() {
        let (orig, rec) = (data[i], recovered[i]);
        assert!((orig - rec).abs() <= tol, "q4_0 roundtrip[{i}]: {orig} vs {rec} (tol {tol})");
    }
}
721+
722+
#[test]
fn test_q4_0_roundtrip_multi_block() {
    // Three blocks (96 elements) of a linear ramp from -12.0 upward in
    // steps of 0.25, so each block has a different scale.
    let n = 3 * Q4_0_BLOCK_SIZE;
    let mut data = Vec::with_capacity(n);
    for i in 0..n {
        data.push(((i as f32) - 48.0) * 0.25);
    }

    let (packed, scales) = quantize_f32_to_q4_0(&data);
    assert_eq!(scales.len(), 3);
    assert_eq!(packed.len(), 3 * Q4_0_BYTES_PER_BLOCK);

    let recovered = dequantize_q4_0_to_f32(&packed, &scales);
    assert_eq!(recovered.len(), n);

    // Every element must be reconstructed to within one quantization step
    // (|d|) of the block it belongs to, plus a little float slack.
    for (i, (orig, rec)) in data.iter().zip(&recovered).enumerate() {
        let b = i / Q4_0_BLOCK_SIZE;
        let tol = scales[b].abs() + 1e-4;
        assert!(
            (orig - rec).abs() <= tol,
            "q4_0 multi[{i}] block={b}: {} vs {} tol={tol}",
            orig,
            rec
        );
    }
}
745+
746+
#[test]
fn test_q4_0_zero_block() {
    // An all-zero block must produce a zero scale and decode back to
    // exact zeros.
    let zeros = vec![0.0f32; Q4_0_BLOCK_SIZE];
    let (packed, scales) = quantize_f32_to_q4_0(&zeros);
    assert_eq!(scales[0], 0.0);

    dequantize_q4_0_to_f32(&packed, &scales)
        .into_iter()
        .for_each(|v| assert_eq!(v, 0.0));
}
756+
757+
#[test]
fn test_q4_0_packing_layout_interleaved() {
    // Verify the GGUF interleaved layout: byte j carries element j in its
    // low nibble and element j + 16 in its high nibble.
    //
    // Element 16 is given a value that encodes differently from the zeros
    // around it. With a "two consecutive elements per byte" layout the high
    // nibble of byte 0 would come from element 1 (q = 8) instead, so this
    // test fails under the wrong layout.
    let mut data = vec![0.0f32; Q4_0_BLOCK_SIZE];
    data[0] = -1.0; // largest magnitude -> d = (-1.0) / -8 = 0.125 (positive)
    data[16] = 0.5;
    let (packed, scales) = quantize_f32_to_q4_0(&data);

    assert!(scales[0] > 0.0);
    assert_eq!(scales[0], 0.125);

    // q[0]  = -1.0 / 0.125 + 8 = -8 + 8 = 0
    // q[16] =  0.5 / 0.125 + 8 =  4 + 8 = 12
    assert_eq!(packed[0] & 0x0F, 0);
    assert_eq!((packed[0] >> 4) & 0x0F, 12);

    // Every other byte pairs two zero elements: q = 8 in both nibbles.
    assert!(packed[1..].iter().all(|&b| b == 0x88));
}
774+
775+
#[test]
// Pin the expected message so the test only passes when the alignment
// assertion fires, not on some unrelated panic (e.g. an out-of-bounds index).
#[should_panic(expected = "Q4_0 requires data.len() to be a multiple of")]
fn test_q4_0_requires_block_aligned() {
    // 31 elements: one short of a full block.
    let data = vec![1.0f32; Q4_0_BLOCK_SIZE - 1];
    let _ = quantize_f32_to_q4_0(&data);
}
570781
}

0 commit comments

Comments
 (0)