@@ -630,329 +630,6 @@ pub fn stream_index_gguf<R: Read + Seek, W: Write>(
630630 Ok ( stats)
631631}
632632
633- // ============================================================================
634- // BF16-DIRECT OPTIMIZATIONS
635- // ============================================================================
636- //
637- // Skip f32 intermediate entirely for BF16 tensors:
638- // Old: alloc Vec<u8> + alloc Vec<f32> + batch dequant + project (424 MB peak)
639- // New: alloc Vec<u16> (reused) + inline BF16→f64 at sample sites (141 MB peak)
640- // CPU: 97% fewer BF16→f64 conversions with octave stride + halftone drop
641- // ============================================================================
642-
643- /// Halftone-dropped golden positions: every other step from GOLDEN_POS.
644- /// 9 positions, still well-distributed across 0..16.
645- const HALFTONE_POS : [ u8 ; 9 ] = {
646- let mut t = [ 0u8 ; 9 ] ;
647- let mut i = 0 ;
648- let mut j = 0 ;
649- while i < BASE_DIM {
650- if i % 2 == 0 {
651- t[ j] = ( ( i * GOLDEN_STEP ) % BASE_DIM ) as u8 ;
652- j += 1 ;
653- }
654- i += 1 ;
655- }
656- t
657- } ;
658-
659- /// Which Base17 bin each halftone sample maps to (even-indexed bins).
660- const HALFTONE_TO_BIN : [ u8 ; 9 ] = [ 0 , 2 , 4 , 6 , 8 , 10 , 12 , 14 , 16 ] ;
661-
662- // ── Core: inline BF16 → f64 (zero allocation) ──
663-
664- /// Convert one BF16 u16 to f64. Zero allocation. 2 instructions.
665- ///
666- /// BF16 = upper 16 bits of IEEE 754 f32.
667- /// Shift left 16 → f32 bit pattern → extend to f64.
668- #[ inline( always) ]
669- fn bf16_to_f64 ( bits : u16 ) -> f64 {
670- f32:: from_bits ( ( bits as u32 ) << 16 ) as f64
671- }
672-
673- // ── BF16-direct projection (full octave, no f32 intermediate) ──
674-
675- /// Project a BF16 row directly to Base17. No f32 Vec allocated.
676- ///
677- /// Same golden-step octave averaging as project_row_to_base17(),
678- /// but reads u16 BF16 values and converts inline to f64 accumulator.
679- ///
680- /// Memory: 17 × f64 accumulators = 136 bytes stack. That's it.
681- pub fn project_row_bf16_direct ( row : & [ u16 ] ) -> Base17 {
682- let d = row. len ( ) ;
683- let n_octaves = ( d + BASE_DIM - 1 ) / BASE_DIM ;
684- let mut sum = [ 0.0f64 ; BASE_DIM ] ;
685- let mut count = [ 0u32 ; BASE_DIM ] ;
686-
687- for octave in 0 ..n_octaves {
688- for bi in 0 ..BASE_DIM {
689- let dim = octave * BASE_DIM + GOLDEN_POS [ bi] as usize ;
690- if dim < d {
691- sum[ bi] += bf16_to_f64 ( row[ dim] ) ;
692- count[ bi] += 1 ;
693- }
694- }
695- }
696-
697- let mut dims = [ 0i16 ; BASE_DIM ] ;
698- for i in 0 ..BASE_DIM {
699- if count[ i] > 0 {
700- let mean = sum[ i] / count[ i] as f64 ;
701- dims[ i] = ( mean * FP_SCALE ) . round ( ) . clamp ( -32768.0 , 32767.0 ) as i16 ;
702- }
703- }
704- Base17 { dims }
705- }
706-
707- // ── Strided octave + halftone drop (the big win) ──
708-
709- /// Project a BF16 row with octave stride and halftone dropping.
710- ///
711- /// For a 5120-element row at stride=16:
712- /// 302 octaves / 16 = 19 sampled octaves
713- /// 19 octaves × 9 halftone positions = 171 BF16→f64 conversions
714- /// vs 5120 conversions in the full path (97% reduction)
715- ///
716- /// Odd bins are interpolated as average of their two neighbors.
717- pub fn project_row_bf16_strided ( row : & [ u16 ] , octave_stride : usize ) -> Base17 {
718- let d = row. len ( ) ;
719- let n_octaves = ( d + BASE_DIM - 1 ) / BASE_DIM ;
720-
721- // Phase 1: accumulate halftone samples into 9 bins
722- let mut half_sum = [ 0.0f64 ; 9 ] ;
723- let mut half_count = [ 0u32 ; 9 ] ;
724-
725- let mut octave = 0 ;
726- while octave < n_octaves {
727- for hi in 0 ..9 {
728- let dim = octave * BASE_DIM + HALFTONE_POS [ hi] as usize ;
729- if dim < d {
730- half_sum[ hi] += bf16_to_f64 ( row[ dim] ) ;
731- half_count[ hi] += 1 ;
732- }
733- }
734- octave += octave_stride;
735- }
736-
737- // Phase 2: fill 17 bins — sampled bins from data, gaps interpolated
738- let mut dims = [ 0i16 ; BASE_DIM ] ;
739-
740- // Even bins: direct from halftone samples
741- for hi in 0 ..9 {
742- let bin = HALFTONE_TO_BIN [ hi] as usize ;
743- if half_count[ hi] > 0 {
744- let mean = half_sum[ hi] / half_count[ hi] as f64 ;
745- dims[ bin] = ( mean * FP_SCALE ) . round ( ) . clamp ( -32768.0 , 32767.0 ) as i16 ;
746- }
747- }
748-
749- // Odd bins: interpolate from neighbors (circular)
750- for odd in ( 1 ..BASE_DIM ) . step_by ( 2 ) {
751- let left = dims[ odd - 1 ] as i32 ;
752- let right = dims[ ( odd + 1 ) % BASE_DIM ] as i32 ;
753- dims[ odd] = ( ( left + right) / 2 ) as i16 ;
754- }
755-
756- Base17 { dims }
757- }
758-
759- // ── Read tensor as raw u16 (skip f32 allocation entirely) ──
760-
761- /// Read a BF16 tensor as raw u16 values. NO f32 conversion.
762- ///
763- /// `buf` is a REUSABLE buffer — caller allocates once, passes to every tensor.
764- /// Grows to max tensor, never shrinks. Saves 283 MB per tensor vs f32 path.
765- pub fn read_tensor_bf16_raw < R : Read + Seek > (
766- reader : & mut R ,
767- gguf : & GgufFile ,
768- tensor : & TensorInfo ,
769- buf : & mut Vec < u16 > ,
770- ) -> Result < usize , String > {
771- let abs_offset = gguf. tensor_data_offset + tensor. offset ;
772- reader. seek ( SeekFrom :: Start ( abs_offset) ) . map_err ( |e| e. to_string ( ) ) ?;
773-
774- let n_elements = tensor. element_count ( ) as usize ;
775-
776- if buf. len ( ) < n_elements {
777- buf. resize ( n_elements, 0 ) ;
778- }
779-
780- // SAFETY: u16 and [u8; 2] have the same layout on little-endian.
781- // GGUF BF16 tensors are stored as little-endian u16 pairs.
782- let byte_slice = unsafe {
783- std:: slice:: from_raw_parts_mut (
784- buf. as_mut_ptr ( ) as * mut u8 ,
785- n_elements * 2 ,
786- )
787- } ;
788- reader. read_exact ( byte_slice) . map_err ( |e| e. to_string ( ) ) ?;
789-
790- Ok ( n_elements)
791- }
792-
793- // ── Helper: tensor_to_rows from dimensions only (no data needed for BF16 path) ──
794-
795- fn tensor_to_rows_dims ( dims : & [ u64 ] , layer_type : & LayerType ) -> ( usize , usize ) {
796- match layer_type {
797- LayerType :: Conv2D if dims. len ( ) == 4 => {
798- ( dims[ 0 ] as usize , ( dims[ 1 ] * dims[ 2 ] * dims[ 3 ] ) as usize )
799- }
800- _ if dims. len ( ) >= 2 => {
801- let rows = dims[ 0 ] as usize ;
802- let cols: usize = dims[ 1 ..] . iter ( ) . map ( |& d| d as usize ) . product ( ) ;
803- ( rows, cols)
804- }
805- _ => {
806- let total: usize = dims. iter ( ) . map ( |& d| d as usize ) . product ( ) ;
807- ( 1 , total)
808- }
809- }
810- }
811-
812- /// Helper: LayerType → array index.
813- fn layer_type_index ( lt : & LayerType ) -> usize {
814- match lt {
815- LayerType :: Attention => 0 ,
816- LayerType :: FeedForward => 1 ,
817- LayerType :: Conv2D => 2 ,
818- LayerType :: Norm => 3 ,
819- LayerType :: Embedding => 4 ,
820- LayerType :: Skip => 5 ,
821- }
822- }
823-
824- // ── Combined BF16-direct streaming indexer ──
825-
826- /// Stream-index a BF16 GGUF file with all optimizations.
827- ///
828- /// - No f32 Vec allocation (saves 283 MB per tensor)
829- /// - Reusable u16 buffer (one alloc for entire shard)
830- /// - Strided octave projection (97% fewer conversions when stride>1)
831- /// - Direct BF16→f64 inline conversion (no batch bf16_to_f32_slice)
832- ///
833- /// `octave_stride`: 1 = full (identical to original), 16 = 4 octaves higher
834- pub fn stream_index_gguf_bf16 < R : Read + Seek , W : Write > (
835- reader : & mut R ,
836- writer : & mut W ,
837- octave_stride : usize ,
838- callback : Option < & dyn Fn ( & str , & LayerType , usize , usize ) > ,
839- ) -> Result < IndexStats , String > {
840- let gguf = gguf:: read_gguf_header ( reader) ?;
841- let mut stats = IndexStats :: default ( ) ;
842- stats. tensors_total = gguf. tensors . len ( ) ;
843-
844- writer. write_all ( b"BGZ7" ) . map_err ( |e| e. to_string ( ) ) ?;
845- writer. write_all ( & ( gguf. tensors . len ( ) as u32 ) . to_le_bytes ( ) ) . map_err ( |e| e. to_string ( ) ) ?;
846-
847- // ONE reusable buffer — grows to largest tensor, never shrinks
848- let mut bf16_buf: Vec < u16 > = Vec :: new ( ) ;
849-
850- for tensor in & gguf. tensors {
851- let layer_type = classify_tensor ( & tensor. name , & tensor. dimensions ) ;
852-
853- if matches ! ( layer_type, LayerType :: Skip | LayerType :: Norm ) {
854- stats. tensors_skipped += 1 ;
855- continue ;
856- }
857-
858- let is_bf16 = matches ! ( tensor. dtype, GgmlType :: BF16 ) ;
859-
860- if is_bf16 {
861- // FAST PATH: BF16 direct — no f32 intermediate
862- let n_elements = read_tensor_bf16_raw ( reader, & gguf, tensor, & mut bf16_buf) ?;
863-
864- let ( n_rows, n_cols) = tensor_to_rows_dims ( & tensor. dimensions , & layer_type) ;
865- let orig_bytes = ( n_rows * n_cols * 4 ) as u64 ; // f32 equivalent
866-
867- let mut rows = Vec :: with_capacity ( n_rows) ;
868- for r in 0 ..n_rows {
869- let start = r * n_cols;
870- let end = ( start + n_cols) . min ( n_elements) ;
871- let row_slice = & bf16_buf[ start..end] ;
872-
873- let b17 = if octave_stride > 1 {
874- project_row_bf16_strided ( row_slice, octave_stride)
875- } else {
876- project_row_bf16_direct ( row_slice)
877- } ;
878- rows. push ( b17) ;
879- }
880-
881- let comp_bytes = ( rows. len ( ) * Base17 :: BYTE_SIZE ) as u64 ;
882-
883- let ct = CompressedTensor {
884- name : tensor. name . clone ( ) ,
885- layer_type : layer_type. clone ( ) ,
886- original_shape : tensor. dimensions . clone ( ) ,
887- n_rows,
888- n_cols,
889- rows,
890- } ;
891- ct. write_to ( writer) ?;
892-
893- let lt_idx = layer_type_index ( & layer_type) ;
894- stats. by_type [ lt_idx] . 0 += 1 ;
895- stats. by_type [ lt_idx] . 1 += orig_bytes;
896- stats. by_type [ lt_idx] . 2 += comp_bytes;
897- stats. original_bytes += orig_bytes;
898- stats. compressed_bytes += comp_bytes;
899- stats. tensors_indexed += 1 ;
900-
901- if n_elements as u64 * 2 > stats. peak_tensor_bytes {
902- stats. peak_tensor_bytes = n_elements as u64 * 2 ;
903- }
904-
905- if let Some ( cb) = callback {
906- cb ( & tensor. name , & layer_type, orig_bytes as usize , comp_bytes as usize ) ;
907- }
908- } else {
909- // FALLBACK: non-BF16 dtype — use original f32 path
910- let data = gguf:: read_tensor_f32 ( reader, & gguf, tensor) ?;
911-
912- let tensor_bytes = data. len ( ) as u64 * 4 ;
913- if tensor_bytes > stats. peak_tensor_bytes {
914- stats. peak_tensor_bytes = tensor_bytes;
915- }
916-
917- let ( n_rows, n_cols) = tensor_to_rows ( & data, & tensor. dimensions , & layer_type) ;
918-
919- let mut rows = Vec :: with_capacity ( n_rows) ;
920- for r in 0 ..n_rows {
921- let start = r * n_cols;
922- let end = ( start + n_cols) . min ( data. len ( ) ) ;
923- rows. push ( project_row_to_base17 ( & data[ start..end] ) ) ;
924- }
925-
926- let orig_bytes = ( n_rows * n_cols * 4 ) as u64 ;
927- let comp_bytes = ( rows. len ( ) * Base17 :: BYTE_SIZE ) as u64 ;
928-
929- let ct = CompressedTensor {
930- name : tensor. name . clone ( ) ,
931- layer_type : layer_type. clone ( ) ,
932- original_shape : tensor. dimensions . clone ( ) ,
933- n_rows,
934- n_cols,
935- rows,
936- } ;
937- ct. write_to ( writer) ?;
938-
939- let lt_idx = layer_type_index ( & layer_type) ;
940- stats. by_type [ lt_idx] . 0 += 1 ;
941- stats. by_type [ lt_idx] . 1 += orig_bytes;
942- stats. by_type [ lt_idx] . 2 += comp_bytes;
943- stats. original_bytes += orig_bytes;
944- stats. compressed_bytes += comp_bytes;
945- stats. tensors_indexed += 1 ;
946-
947- if let Some ( cb) = callback {
948- cb ( & tensor. name , & layer_type, orig_bytes as usize , comp_bytes as usize ) ;
949- }
950- }
951- }
952-
953- Ok ( stats)
954- }
955-
956633// ============================================================================
957634// Tests
958635// ============================================================================
@@ -1372,35 +1049,6 @@ mod tests {
13721049 }
13731050 }
13741051
1375- #[ test]
1376- fn test_halftone_positions_coverage ( ) {
1377- let positions: Vec < u8 > = HALFTONE_POS . to_vec ( ) ;
1378- let mut sorted = positions. clone ( ) ;
1379- sorted. sort ( ) ;
1380- assert_eq ! ( sorted, vec![ 0 , 1 , 3 , 5 , 6 , 8 , 10 , 13 , 15 ] ) ;
1381- }
1382-
1383- #[ test]
1384- fn test_bf16_to_f64_accuracy ( ) {
1385- assert_eq ! ( bf16_to_f64( 0x3F80 ) , 1.0 ) ;
1386- assert_eq ! ( bf16_to_f64( 0x0000 ) , 0.0 ) ;
1387- assert_eq ! ( bf16_to_f64( 0xBF80 ) , -1.0 ) ;
1388- let v = bf16_to_f64 ( 0x4049 ) ;
1389- assert ! ( ( v - 3.140625 ) . abs( ) < 0.01 ) ;
1390- }
1391-
1392- #[ test]
1393- fn test_strided_vs_full_agreement ( ) {
1394- let row: Vec < u16 > = vec ! [ 0x3F80 ; 5120 ] ; // all 1.0 in BF16
1395- let full = project_row_bf16_direct ( & row) ;
1396- let strided = project_row_bf16_strided ( & row, 16 ) ;
1397- for i in 0 ..BASE_DIM {
1398- let diff = ( full. dims [ i] as i32 - strided. dims [ i] as i32 ) . abs ( ) ;
1399- assert ! ( diff <= 1 , "bin {} differs by {}: full={}, strided={}" ,
1400- i, diff, full. dims[ i] , strided. dims[ i] ) ;
1401- }
1402- }
1403-
14041052 #[ test]
14051053 #[ ignore] // Streams ~801 GB from HuggingFace
14061054 fn test_stream_index_llama4_maverick_bf16_all_shards ( ) {
0 commit comments