@@ -100,6 +100,7 @@ pub fn classify_tensor(name: &str, dims: &[u64]) -> LayerType {
100100// ============================================================================
101101
102102const BASE_DIM : usize = 17 ;
103+ /// Golden-step = round(17 / φ) = round(17 / 1.618) = 11. gcd(11,17)=1 → visits all residues.
103104const GOLDEN_STEP : usize = 11 ;
104105const FP_SCALE : f64 = 256.0 ;
105106
@@ -144,6 +145,293 @@ pub fn project_row_to_base17(row: &[f32]) -> Base17 {
144145 Base17 { dims }
145146}
146147
148+ // ============================================================================
149+ // BF16-direct optimizations: skip f32 intermediate, strided octave sampling
150+ // ============================================================================
151+
152+ /// Halftone-dropped golden positions: keep every other step (9 of 17).
153+ /// Well-distributed across 0..16; max gap = 3. Odd bins interpolated.
154+ const HALFTONE_POS : [ u8 ; 9 ] = {
155+ let mut t = [ 0u8 ; 9 ] ;
156+ let mut i = 0 ;
157+ let mut j = 0 ;
158+ while i < BASE_DIM {
159+ if i % 2 == 0 {
160+ t[ j] = ( ( i * GOLDEN_STEP ) % BASE_DIM ) as u8 ;
161+ j += 1 ;
162+ }
163+ i += 1 ;
164+ }
165+ t
166+ } ;
167+
168+ /// Which of the 17 Base17 bins each halftone position maps to.
169+ const HALFTONE_TO_BIN : [ u8 ; 9 ] = [ 0 , 2 , 4 , 6 , 8 , 10 , 12 , 14 , 16 ] ;
170+
171+ /// Convert one BF16 u16 to f64. Zero allocation.
172+ #[ inline( always) ]
173+ fn bf16_to_f64 ( bits : u16 ) -> f64 {
174+ f32:: from_bits ( ( bits as u32 ) << 16 ) as f64
175+ }
176+
177+ /// Project a BF16 row directly to Base17. No f32 Vec allocated.
178+ ///
179+ /// Same golden-step octave averaging as project_row_to_base17(),
180+ /// but reads u16 BF16 values and converts inline to f64 accumulator.
181+ /// Memory: 17 × f64 accumulators = 136 bytes stack.
182+ pub fn project_row_bf16_direct ( row : & [ u16 ] ) -> Base17 {
183+ let d = row. len ( ) ;
184+ let n_octaves = ( d + BASE_DIM - 1 ) / BASE_DIM ;
185+ let mut sum = [ 0.0f64 ; BASE_DIM ] ;
186+ let mut count = [ 0u32 ; BASE_DIM ] ;
187+
188+ for octave in 0 ..n_octaves {
189+ for bi in 0 ..BASE_DIM {
190+ let dim = octave * BASE_DIM + GOLDEN_POS [ bi] as usize ;
191+ if dim < d {
192+ sum[ bi] += bf16_to_f64 ( row[ dim] ) ;
193+ count[ bi] += 1 ;
194+ }
195+ }
196+ }
197+
198+ let mut dims = [ 0i16 ; BASE_DIM ] ;
199+ for i in 0 ..BASE_DIM {
200+ if count[ i] > 0 {
201+ let mean = sum[ i] / count[ i] as f64 ;
202+ dims[ i] = ( mean * FP_SCALE ) . round ( ) . clamp ( -32768.0 , 32767.0 ) as i16 ;
203+ }
204+ }
205+ Base17 { dims }
206+ }
207+
208+ /// Project a BF16 row with octave stride and halftone dropping.
209+ ///
210+ /// For a 5120-element row at stride=16:
211+ /// 302 octaves / 16 = 19 sampled × 9 halftone = 171 BF16→f64 conversions
212+ /// vs 5120 in the full path (97% reduction).
213+ /// Odd bins interpolated from neighbors.
214+ pub fn project_row_bf16_strided ( row : & [ u16 ] , octave_stride : usize ) -> Base17 {
215+ let d = row. len ( ) ;
216+ let n_octaves = ( d + BASE_DIM - 1 ) / BASE_DIM ;
217+
218+ let mut half_sum = [ 0.0f64 ; 9 ] ;
219+ let mut half_count = [ 0u32 ; 9 ] ;
220+
221+ let mut octave = 0 ;
222+ while octave < n_octaves {
223+ for hi in 0 ..9 {
224+ let dim = octave * BASE_DIM + HALFTONE_POS [ hi] as usize ;
225+ if dim < d {
226+ half_sum[ hi] += bf16_to_f64 ( row[ dim] ) ;
227+ half_count[ hi] += 1 ;
228+ }
229+ }
230+ octave += octave_stride;
231+ }
232+
233+ let mut dims = [ 0i16 ; BASE_DIM ] ;
234+
235+ // Even bins: direct from halftone samples
236+ for hi in 0 ..9 {
237+ let bin = HALFTONE_TO_BIN [ hi] as usize ;
238+ if half_count[ hi] > 0 {
239+ let mean = half_sum[ hi] / half_count[ hi] as f64 ;
240+ dims[ bin] = ( mean * FP_SCALE ) . round ( ) . clamp ( -32768.0 , 32767.0 ) as i16 ;
241+ }
242+ }
243+
244+ // Odd bins: interpolate from neighbors (circular)
245+ for odd in ( 1 ..BASE_DIM ) . step_by ( 2 ) {
246+ let left = dims[ odd - 1 ] as i32 ;
247+ let right = dims[ ( odd + 1 ) % BASE_DIM ] as i32 ;
248+ dims[ odd] = ( ( left + right) / 2 ) as i16 ;
249+ }
250+
251+ Base17 { dims }
252+ }
253+
254+ /// Read a BF16 tensor as raw u16 values. NO f32 conversion.
255+ /// `buf` is reusable — caller allocates once, passes to every tensor.
256+ pub fn read_tensor_bf16_raw < R : Read + Seek > (
257+ reader : & mut R ,
258+ gguf_file : & gguf:: GgufFile ,
259+ tensor : & gguf:: TensorInfo ,
260+ buf : & mut Vec < u16 > ,
261+ ) -> Result < usize , String > {
262+ let abs_offset = gguf_file. tensor_data_offset + tensor. offset ;
263+ reader. seek ( std:: io:: SeekFrom :: Start ( abs_offset) ) . map_err ( |e| e. to_string ( ) ) ?;
264+
265+ let n_elements = tensor. element_count ( ) as usize ;
266+ if buf. len ( ) < n_elements {
267+ buf. resize ( n_elements, 0 ) ;
268+ }
269+
270+ // SAFETY: u16 and [u8; 2] have the same layout on little-endian (x86/ARM).
271+ let byte_slice = unsafe {
272+ std:: slice:: from_raw_parts_mut ( buf. as_mut_ptr ( ) as * mut u8 , n_elements * 2 )
273+ } ;
274+ reader. read_exact ( byte_slice) . map_err ( |e| e. to_string ( ) ) ?;
275+
276+ Ok ( n_elements)
277+ }
278+
279+ /// Helper: tensor dimensions → (rows, cols) without needing data.
280+ fn tensor_to_rows_dims ( dims : & [ u64 ] , layer_type : & LayerType ) -> ( usize , usize ) {
281+ match layer_type {
282+ LayerType :: Conv2D if dims. len ( ) == 4 => {
283+ ( dims[ 0 ] as usize , ( dims[ 1 ] * dims[ 2 ] * dims[ 3 ] ) as usize )
284+ }
285+ _ if dims. len ( ) >= 2 => {
286+ let rows = dims[ 0 ] as usize ;
287+ let cols: usize = dims[ 1 ..] . iter ( ) . map ( |& d| d as usize ) . product ( ) ;
288+ ( rows, cols)
289+ }
290+ _ => {
291+ let total: usize = dims. iter ( ) . map ( |& d| d as usize ) . product ( ) ;
292+ ( 1 , total)
293+ }
294+ }
295+ }
296+
297+ /// Helper: LayerType → stats array index.
298+ fn layer_type_index ( lt : & LayerType ) -> usize {
299+ match lt {
300+ LayerType :: Attention => 0 ,
301+ LayerType :: FeedForward => 1 ,
302+ LayerType :: Conv2D => 2 ,
303+ LayerType :: Norm => 3 ,
304+ LayerType :: Embedding => 4 ,
305+ LayerType :: Skip => 5 ,
306+ }
307+ }
308+
309+ /// Stream-index a BF16 GGUF with all optimizations.
310+ ///
311+ /// vs stream_index_gguf():
312+ /// - No f32 Vec allocation (saves 283 MB per tensor)
313+ /// - Reusable u16 buffer (one alloc for entire shard)
314+ /// - Strided octave projection (97% fewer conversions when stride>1)
315+ /// - Direct BF16→f64 inline conversion
316+ ///
317+ /// Falls back to f32 path for non-BF16 dtypes.
318+ pub fn stream_index_gguf_bf16 < R : Read + Seek , W : Write > (
319+ reader : & mut R ,
320+ writer : & mut W ,
321+ octave_stride : usize ,
322+ callback : Option < & dyn Fn ( & str , & LayerType , usize , usize ) > ,
323+ ) -> Result < IndexStats , String > {
324+ let gguf_header = gguf:: read_gguf_header ( reader) ?;
325+ let mut stats = IndexStats :: default ( ) ;
326+ stats. tensors_total = gguf_header. tensors . len ( ) ;
327+
328+ writer. write_all ( b"BGZ7" ) . map_err ( |e| e. to_string ( ) ) ?;
329+ writer. write_all ( & ( gguf_header. tensors . len ( ) as u32 ) . to_le_bytes ( ) ) . map_err ( |e| e. to_string ( ) ) ?;
330+
331+ // ONE reusable buffer — grows to largest tensor, never shrinks
332+ let mut bf16_buf: Vec < u16 > = Vec :: new ( ) ;
333+
334+ for tensor in & gguf_header. tensors {
335+ let layer_type = classify_tensor ( & tensor. name , & tensor. dimensions ) ;
336+
337+ if matches ! ( layer_type, LayerType :: Skip | LayerType :: Norm ) {
338+ stats. tensors_skipped += 1 ;
339+ continue ;
340+ }
341+
342+ let is_bf16 = matches ! ( tensor. dtype, gguf:: GgmlType :: BF16 ) ;
343+
344+ if is_bf16 {
345+ // FAST PATH: BF16 direct — no f32 intermediate
346+ let n_elements = read_tensor_bf16_raw ( reader, & gguf_header, tensor, & mut bf16_buf) ?;
347+ let ( n_rows, n_cols) = tensor_to_rows_dims ( & tensor. dimensions , & layer_type) ;
348+
349+ let mut rows = Vec :: with_capacity ( n_rows) ;
350+ for r in 0 ..n_rows {
351+ let start = r * n_cols;
352+ let end = ( start + n_cols) . min ( n_elements) ;
353+ let row_slice = & bf16_buf[ start..end] ;
354+ let b17 = if octave_stride > 1 {
355+ project_row_bf16_strided ( row_slice, octave_stride)
356+ } else {
357+ project_row_bf16_direct ( row_slice)
358+ } ;
359+ rows. push ( b17) ;
360+ }
361+
362+ let orig_bytes = ( n_rows * n_cols * 4 ) as u64 ;
363+ let comp_bytes = ( rows. len ( ) * Base17 :: BYTE_SIZE ) as u64 ;
364+
365+ let ct = CompressedTensor {
366+ name : tensor. name . clone ( ) ,
367+ layer_type : layer_type. clone ( ) ,
368+ original_shape : tensor. dimensions . clone ( ) ,
369+ n_rows,
370+ n_cols,
371+ rows,
372+ } ;
373+ ct. write_to ( writer) ?;
374+
375+ let lt_idx = layer_type_index ( & layer_type) ;
376+ stats. by_type [ lt_idx] . 0 += 1 ;
377+ stats. by_type [ lt_idx] . 1 += orig_bytes;
378+ stats. by_type [ lt_idx] . 2 += comp_bytes;
379+ stats. original_bytes += orig_bytes;
380+ stats. compressed_bytes += comp_bytes;
381+ stats. tensors_indexed += 1 ;
382+
383+ let peak = n_elements as u64 * 2 ;
384+ if peak > stats. peak_tensor_bytes { stats. peak_tensor_bytes = peak; }
385+
386+ if let Some ( cb) = callback {
387+ cb ( & tensor. name , & layer_type, orig_bytes as usize , comp_bytes as usize ) ;
388+ }
389+ } else {
390+ // FALLBACK: non-BF16 — use original f32 path
391+ let data = gguf:: read_tensor_f32 ( reader, & gguf_header, tensor) ?;
392+ let tensor_bytes = data. len ( ) as u64 * 4 ;
393+ if tensor_bytes > stats. peak_tensor_bytes {
394+ stats. peak_tensor_bytes = tensor_bytes;
395+ }
396+
397+ let ( n_rows, n_cols) = tensor_to_rows ( & data, & tensor. dimensions , & layer_type) ;
398+ let mut rows = Vec :: with_capacity ( n_rows) ;
399+ for r in 0 ..n_rows {
400+ let start = r * n_cols;
401+ let end = ( start + n_cols) . min ( data. len ( ) ) ;
402+ rows. push ( project_row_to_base17 ( & data[ start..end] ) ) ;
403+ }
404+
405+ let orig_bytes = ( n_rows * n_cols * 4 ) as u64 ;
406+ let comp_bytes = ( rows. len ( ) * Base17 :: BYTE_SIZE ) as u64 ;
407+
408+ let ct = CompressedTensor {
409+ name : tensor. name . clone ( ) ,
410+ layer_type : layer_type. clone ( ) ,
411+ original_shape : tensor. dimensions . clone ( ) ,
412+ n_rows,
413+ n_cols,
414+ rows,
415+ } ;
416+ ct. write_to ( writer) ?;
417+
418+ let lt_idx = layer_type_index ( & layer_type) ;
419+ stats. by_type [ lt_idx] . 0 += 1 ;
420+ stats. by_type [ lt_idx] . 1 += orig_bytes;
421+ stats. by_type [ lt_idx] . 2 += comp_bytes;
422+ stats. original_bytes += orig_bytes;
423+ stats. compressed_bytes += comp_bytes;
424+ stats. tensors_indexed += 1 ;
425+
426+ if let Some ( cb) = callback {
427+ cb ( & tensor. name , & layer_type, orig_bytes as usize , comp_bytes as usize ) ;
428+ }
429+ }
430+ }
431+
432+ Ok ( stats)
433+ }
434+
147435// ============================================================================
148436// Compressed tensor output
149437// ============================================================================
@@ -711,4 +999,53 @@ mod tests {
711999 #[ test]
7121000 #[ ignore]
7131001 fn test_stream_index_llama4_bf16_shard5 ( ) { run_llama4_shard ( 5 ) ; }
1002+
1003+ // ── BF16-direct optimization tests ──
1004+
1005+ #[ test]
1006+ fn test_halftone_positions_coverage ( ) {
1007+ let positions: Vec < u8 > = HALFTONE_POS . to_vec ( ) ;
1008+ let mut sorted = positions. clone ( ) ;
1009+ sorted. sort ( ) ;
1010+ assert_eq ! ( sorted, vec![ 0 , 1 , 3 , 5 , 6 , 8 , 10 , 13 , 15 ] ) ;
1011+ }
1012+
1013+ #[ test]
1014+ fn test_bf16_to_f64_accuracy ( ) {
1015+ assert_eq ! ( bf16_to_f64( 0x3F80 ) , 1.0 ) ;
1016+ assert_eq ! ( bf16_to_f64( 0x0000 ) , 0.0 ) ;
1017+ assert_eq ! ( bf16_to_f64( 0xBF80 ) , -1.0 ) ;
1018+ let v = bf16_to_f64 ( 0x4049 ) ;
1019+ assert ! ( ( v - 3.140625 ) . abs( ) < 0.01 ) ;
1020+ }
1021+
1022+ #[ test]
1023+ fn test_strided_vs_full_agreement ( ) {
1024+ // Constant BF16 row → stride shouldn't matter
1025+ let row: Vec < u16 > = vec ! [ 0x3F80 ; 5120 ] ; // all 1.0
1026+ let full = project_row_bf16_direct ( & row) ;
1027+ let strided = project_row_bf16_strided ( & row, 16 ) ;
1028+
1029+ for i in 0 ..17 {
1030+ let diff = ( full. dims [ i] as i32 - strided. dims [ i] as i32 ) . abs ( ) ;
1031+ assert ! ( diff <= 1 , "bin {} differs by {}: full={}, strided={}" ,
1032+ i, diff, full. dims[ i] , strided. dims[ i] ) ;
1033+ }
1034+ }
1035+
1036+ #[ test]
1037+ fn test_bf16_direct_matches_f32_path ( ) {
1038+ // Same data in BF16 and f32 should produce identical Base17
1039+ let f32_row: Vec < f32 > = ( 0 ..4096 ) . map ( |i| ( i as f32 ) * 0.001 ) . collect ( ) ;
1040+ let bf16_row: Vec < u16 > = f32_row. iter ( ) . map ( |& v| ( v. to_bits ( ) >> 16 ) as u16 ) . collect ( ) ;
1041+
1042+ let from_f32 = project_row_to_base17 ( & f32_row) ;
1043+ let from_bf16 = project_row_bf16_direct ( & bf16_row) ;
1044+
1045+ // BF16 truncates mantissa, so allow ±1 tolerance per dim
1046+ for i in 0 ..17 {
1047+ let diff = ( from_f32. dims [ i] as i32 - from_bf16. dims [ i] as i32 ) . abs ( ) ;
1048+ assert ! ( diff <= 2 , "bin {} differs by {}" , i, diff) ;
1049+ }
1050+ }
7141051}
0 commit comments