@@ -463,6 +463,112 @@ pub fn dequantize_i2_to_f32(packed: &[u8], params: &QuantParams, n: usize) -> Ve
463463 out
464464}
465465
466+ // ── Q4_0 (GGUF-compatible block quantization) ─────────────────────
467+
468+ /// Q4_0 block size (number of f32 elements per block).
469+ pub const Q4_0_BLOCK_SIZE : usize = 32 ;
470+
471+ /// Number of bytes used to pack one Q4_0 block (32 nibbles = 16 bytes).
472+ pub const Q4_0_BYTES_PER_BLOCK : usize = Q4_0_BLOCK_SIZE / 2 ;
473+
474+ /// Quantize f32 to Q4_0 (GGUF-compatible) per-block 4-bit quantization.
475+ ///
476+ /// Each block of [`Q4_0_BLOCK_SIZE`] (32) f32 elements is encoded as
477+ /// [`Q4_0_BYTES_PER_BLOCK`] (16) packed bytes plus one f32 scale `d`.
478+ ///
479+ /// Encoding (matches `llama.cpp` / GGUF reference):
480+ /// - `d = max(|x_i|) / -8.0` (negated max-abs, signed)
481+ /// - `q_i = clamp(round(x_i / d) + 8, 0, 15)` (unsigned nibble in `0..=15`)
482+ /// - Within a block, element `j` (`0 <= j < 16`) and element `j + 16`
483+ /// share byte `j`: low nibble holds element `j`, high nibble holds
484+ /// element `j + 16`. (This is the GGUF interleaved layout, NOT the
485+ /// simple "two consecutive elements per byte" layout used by
486+ /// [`quantize_f32_to_i4`].)
487+ ///
488+ /// Returns `(packed_bytes, scales)` where `scales.len() == data.len() / 32`
489+ /// and `packed_bytes.len() == scales.len() * 16`.
490+ ///
491+ /// # Panics
492+ ///
493+ /// Panics if `data.len()` is not a multiple of [`Q4_0_BLOCK_SIZE`].
494+ pub fn quantize_f32_to_q4_0 ( data : & [ f32 ] ) -> ( Vec < u8 > , Vec < f32 > ) {
495+ assert ! (
496+ data. len( ) % Q4_0_BLOCK_SIZE == 0 ,
497+ "Q4_0 requires data.len() to be a multiple of {}" ,
498+ Q4_0_BLOCK_SIZE
499+ ) ;
500+
501+ let n_blocks = data. len ( ) / Q4_0_BLOCK_SIZE ;
502+ let mut packed = vec ! [ 0u8 ; n_blocks * Q4_0_BYTES_PER_BLOCK ] ;
503+ let mut scales = Vec :: with_capacity ( n_blocks) ;
504+
505+ for b in 0 ..n_blocks {
506+ let block = & data[ b * Q4_0_BLOCK_SIZE ..( b + 1 ) * Q4_0_BLOCK_SIZE ] ;
507+
508+ // Find signed max-abs (preserving sign of the extreme element).
509+ let mut amax = 0.0f32 ;
510+ let mut max_signed = 0.0f32 ;
511+ for & x in block {
512+ let ax = x. abs ( ) ;
513+ if ax > amax {
514+ amax = ax;
515+ max_signed = x;
516+ }
517+ }
518+
519+ // d = max_signed / -8 ; if all-zero block, d = 0 and all q = 8.
520+ let d = if amax > 0.0 { max_signed / -8.0 } else { 0.0 } ;
521+ let id = if d != 0.0 { 1.0 / d } else { 0.0 } ;
522+ scales. push ( d) ;
523+
524+ let byte_off = b * Q4_0_BYTES_PER_BLOCK ;
525+ for j in 0 ..Q4_0_BYTES_PER_BLOCK {
526+ let lo = ( ( block[ j] * id) . round ( ) + 8.5 ) . floor ( ) . clamp ( 0.0 , 15.0 ) as u8 ;
527+ let hi = ( ( block[ j + Q4_0_BYTES_PER_BLOCK ] * id) . round ( ) + 8.5 )
528+ . floor ( )
529+ . clamp ( 0.0 , 15.0 ) as u8 ;
530+ packed[ byte_off + j] = ( lo & 0x0F ) | ( ( hi & 0x0F ) << 4 ) ;
531+ }
532+ }
533+
534+ ( packed, scales)
535+ }
536+
537+ /// Dequantize Q4_0 (GGUF-compatible) packed bytes back to f32.
538+ ///
539+ /// Inverse of [`quantize_f32_to_q4_0`]. `packed.len()` must equal
540+ /// `scales.len() * 16` and the result has length `scales.len() * 32`.
541+ ///
542+ /// # Panics
543+ ///
544+ /// Panics if `packed.len() != scales.len() * Q4_0_BYTES_PER_BLOCK`.
545+ pub fn dequantize_q4_0_to_f32 ( packed : & [ u8 ] , scales : & [ f32 ] ) -> Vec < f32 > {
546+ assert_eq ! (
547+ packed. len( ) ,
548+ scales. len( ) * Q4_0_BYTES_PER_BLOCK ,
549+ "Q4_0 packed length must equal scales.len() * {}" ,
550+ Q4_0_BYTES_PER_BLOCK
551+ ) ;
552+
553+ let n_blocks = scales. len ( ) ;
554+ let mut out = vec ! [ 0.0f32 ; n_blocks * Q4_0_BLOCK_SIZE ] ;
555+
556+ for b in 0 ..n_blocks {
557+ let d = scales[ b] ;
558+ let byte_off = b * Q4_0_BYTES_PER_BLOCK ;
559+ let elem_off = b * Q4_0_BLOCK_SIZE ;
560+ for j in 0 ..Q4_0_BYTES_PER_BLOCK {
561+ let byte = packed[ byte_off + j] ;
562+ let lo = ( byte & 0x0F ) as i32 - 8 ;
563+ let hi = ( ( byte >> 4 ) & 0x0F ) as i32 - 8 ;
564+ out[ elem_off + j] = lo as f32 * d;
565+ out[ elem_off + j + Q4_0_BYTES_PER_BLOCK ] = hi as f32 * d;
566+ }
567+ }
568+
569+ out
570+ }
571+
466572#[ cfg( test) ]
467573mod tests {
468574 use super :: * ;
@@ -567,4 +673,109 @@ mod tests {
567673 // = 0b01_00_11_01 = 0x4D
568674 assert_eq ! ( packed[ 0 ] , 0b01_00_11_01 ) ;
569675 }
676+
677+ #[ test]
678+ fn test_i4_boundary_values ( ) {
679+ // With abs_max=7 -> scale=1.0; the i4 grid maps directly:
680+ // input -7 -> q=-7 -> dequant -7.0
681+ // input 0 -> q= 0 -> dequant 0.0
682+ // input 7 -> q= 7 -> dequant 7.0
683+ let data = vec ! [ -7.0f32 , -3.0 , 0.0 , 3.0 , 7.0 ] ;
684+ let ( packed, params) = quantize_f32_to_i4 ( & data) ;
685+ assert ! ( ( params. scale - 1.0 ) . abs( ) < 1e-6 ) ;
686+ let recovered = dequantize_i4_to_f32 ( & packed, & params, data. len ( ) ) ;
687+ assert_eq ! ( recovered, vec![ -7.0 , -3.0 , 0.0 , 3.0 , 7.0 ] ) ;
688+
689+ // Negative-end clamp: with abs_max=8 -> scale=8/7, the value -8
690+ // maps to q=-7 (since -8 / (8/7) = -7), dequantizing to -8.0
691+ // exactly. The 8 grid cell hits q=7 -> dequant 8.0 exactly.
692+ let data2 = vec ! [ -8.0f32 , 0.0 , 8.0 ] ;
693+ let ( packed2, params2) = quantize_f32_to_i4 ( & data2) ;
694+ let s = params2. scale ;
695+ assert ! ( ( s - 8.0 / 7.0 ) . abs( ) < 1e-6 ) ;
696+ let rec2 = dequantize_i4_to_f32 ( & packed2, & params2, data2. len ( ) ) ;
697+ assert ! ( ( rec2[ 0 ] - -8.0 ) . abs( ) < 1e-4 ) ;
698+ assert_eq ! ( rec2[ 1 ] , 0.0 ) ;
699+ assert ! ( ( rec2[ 2 ] - 8.0 ) . abs( ) < 1e-4 ) ;
700+ }
701+
702+ #[ test]
703+ fn test_q4_0_roundtrip_single_block ( ) {
704+ let mut data = Vec :: with_capacity ( Q4_0_BLOCK_SIZE ) ;
705+ for i in 0 ..Q4_0_BLOCK_SIZE {
706+ data. push ( ( i as f32 ) - 16.0 ) ; // values in [-16, 15]
707+ }
708+ let ( packed, scales) = quantize_f32_to_q4_0 ( & data) ;
709+ assert_eq ! ( packed. len( ) , Q4_0_BYTES_PER_BLOCK ) ;
710+ assert_eq ! ( scales. len( ) , 1 ) ;
711+ let recovered = dequantize_q4_0_to_f32 ( & packed, & scales) ;
712+ assert_eq ! ( recovered. len( ) , data. len( ) ) ;
713+ // Max abs in block is 16. With 4-bit signed grid (16 levels),
714+ // expected error <= |d| ≈ 16/8 = 2.0.
715+ let max_abs = 16.0f32 ;
716+ let tol = max_abs / 8.0 + 1e-4 ;
717+ for ( i, ( orig, rec) ) in data. iter ( ) . zip ( recovered. iter ( ) ) . enumerate ( ) {
718+ assert ! ( ( orig - rec) . abs( ) <= tol, "q4_0 roundtrip[{i}]: {orig} vs {rec} (tol {tol})" ) ;
719+ }
720+ }
721+
722+ #[ test]
723+ fn test_q4_0_roundtrip_multi_block ( ) {
724+ // 3 blocks (96 elements), monotonically varying values.
725+ let n = 3 * Q4_0_BLOCK_SIZE ;
726+ let data: Vec < f32 > = ( 0 ..n) . map ( |i| ( ( i as f32 ) - 48.0 ) * 0.25 ) . collect ( ) ;
727+ let ( packed, scales) = quantize_f32_to_q4_0 ( & data) ;
728+ assert_eq ! ( scales. len( ) , 3 ) ;
729+ assert_eq ! ( packed. len( ) , 3 * Q4_0_BYTES_PER_BLOCK ) ;
730+ let recovered = dequantize_q4_0_to_f32 ( & packed, & scales) ;
731+ assert_eq ! ( recovered. len( ) , n) ;
732+ for ( b, & d) in scales. iter ( ) . enumerate ( ) {
733+ let tol = d. abs ( ) + 1e-4 ;
734+ for j in 0 ..Q4_0_BLOCK_SIZE {
735+ let i = b * Q4_0_BLOCK_SIZE + j;
736+ assert ! (
737+ ( data[ i] - recovered[ i] ) . abs( ) <= tol,
738+ "q4_0 multi[{i}] block={b}: {} vs {} tol={tol}" ,
739+ data[ i] ,
740+ recovered[ i]
741+ ) ;
742+ }
743+ }
744+ }
745+
746+ #[ test]
747+ fn test_q4_0_zero_block ( ) {
748+ let data = vec ! [ 0.0f32 ; Q4_0_BLOCK_SIZE ] ;
749+ let ( packed, scales) = quantize_f32_to_q4_0 ( & data) ;
750+ assert_eq ! ( scales[ 0 ] , 0.0 ) ;
751+ let recovered = dequantize_q4_0_to_f32 ( & packed, & scales) ;
752+ for v in recovered {
753+ assert_eq ! ( v, 0.0 ) ;
754+ }
755+ }
756+
757+ #[ test]
758+ fn test_q4_0_packing_layout_interleaved ( ) {
759+ // Verify GGUF interleaved layout: byte j carries element j (low)
760+ // and element j + 16 (high), within one 32-element block.
761+ let mut data = vec ! [ 0.0f32 ; Q4_0_BLOCK_SIZE ] ;
762+ // Set element 0 to extreme negative so q=15 (since d<0, x/d>0),
763+ // and leave element 16 at 0 so its q=8.
764+ data[ 0 ] = -1.0 ;
765+ data[ 16 ] = 0.0 ;
766+ let ( packed, scales) = quantize_f32_to_q4_0 ( & data) ;
767+ // d = (-1.0) / -8 = 0.125 ; q[0] = round(-1/0.125)+8 = -8+8 = 0
768+ // q[16] = 0 + 8 = 8
769+ assert ! ( scales[ 0 ] > 0.0 ) ;
770+ // byte 0: low nibble = q[0] = 0, high nibble = q[16] = 8
771+ assert_eq ! ( packed[ 0 ] & 0x0F , 0 ) ;
772+ assert_eq ! ( ( packed[ 0 ] >> 4 ) & 0x0F , 8 ) ;
773+ }
774+
775+ #[ test]
776+ #[ should_panic]
777+ fn test_q4_0_requires_block_aligned ( ) {
778+ let data = vec ! [ 1.0f32 ; Q4_0_BLOCK_SIZE - 1 ] ;
779+ let _ = quantize_f32_to_q4_0 ( & data) ;
780+ }
570781}
0 commit comments