@@ -251,6 +251,103 @@ pub fn project_row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 {
251251 Base17 { dims }
252252}
253253
254+ // ── SIMD 8-row-parallel tensor projection ──
255+
256+ /// Project an entire BF16 tensor to Base17 using F64x8 SIMD.
257+ ///
258+ /// Processes 8 rows in parallel per SIMD batch. Each of the 9 halftone bins
259+ /// holds an F64x8 accumulator (8 rows × 9 bins = 72 f64 lanes = 9 zmm registers).
260+ ///
261+ /// Per sampled octave: 9 halftone positions × 8 bf16_to_f64 gathers → 9 vaddpd.
262+ /// For 5120-col rows at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch.
263+ pub fn project_tensor_bf16_simd (
264+ buf : & [ u16 ] ,
265+ n_rows : usize ,
266+ n_cols : usize ,
267+ octave_stride : usize ,
268+ ) -> Vec < Base17 > {
269+ use crate :: simd:: F64x8 ;
270+
271+ let n_octaves = ( n_cols + BASE_DIM - 1 ) / BASE_DIM ;
272+ let mut result = Vec :: with_capacity ( n_rows) ;
273+
274+ // Process 8 rows at a time
275+ let full_batches = n_rows / 8 ;
276+ let remainder = n_rows % 8 ;
277+
278+ for batch in 0 ..full_batches {
279+ let base_row = batch * 8 ;
280+
281+ // 9 halftone bins × F64x8 accumulators (8 rows per lane)
282+ let mut half_sum = [ F64x8 :: splat ( 0.0 ) ; 9 ] ;
283+ let mut half_count = [ 0u32 ; 9 ] ; // same count for all 8 rows (same n_cols)
284+
285+ let mut octave = 0 ;
286+ while octave < n_octaves {
287+ for hi in 0 ..9 {
288+ let dim = octave * BASE_DIM + HALFTONE_POS [ hi] as usize ;
289+ if dim < n_cols {
290+ // Gather 8 BF16 values (one per row) at column `dim`
291+ let vals = F64x8 :: from_array ( [
292+ bf16_to_f64 ( buf[ ( base_row + 0 ) * n_cols + dim] ) ,
293+ bf16_to_f64 ( buf[ ( base_row + 1 ) * n_cols + dim] ) ,
294+ bf16_to_f64 ( buf[ ( base_row + 2 ) * n_cols + dim] ) ,
295+ bf16_to_f64 ( buf[ ( base_row + 3 ) * n_cols + dim] ) ,
296+ bf16_to_f64 ( buf[ ( base_row + 4 ) * n_cols + dim] ) ,
297+ bf16_to_f64 ( buf[ ( base_row + 5 ) * n_cols + dim] ) ,
298+ bf16_to_f64 ( buf[ ( base_row + 6 ) * n_cols + dim] ) ,
299+ bf16_to_f64 ( buf[ ( base_row + 7 ) * n_cols + dim] ) ,
300+ ] ) ;
301+ half_sum[ hi] = half_sum[ hi] + vals;
302+ if batch == 0 || octave == 0 {
303+ // Count is same for all batches with same n_cols
304+ }
305+ half_count[ hi] += 1 ;
306+ }
307+ }
308+ octave += octave_stride;
309+ }
310+
311+ // Finalize: convert 9 SIMD accumulators → 8 Base17 results
312+ // Even bins: mean × FP_SCALE, clamped to i16
313+ let mut even_dims = [ [ 0i16 ; BASE_DIM ] ; 8 ] ;
314+
315+ for hi in 0 ..9 {
316+ if half_count[ hi] > 0 {
317+ let count_v = F64x8 :: splat ( half_count[ hi] as f64 ) ;
318+ let scale_v = F64x8 :: splat ( FP_SCALE ) ;
319+ let mean_v = half_sum[ hi] / count_v;
320+ let scaled = mean_v * scale_v;
321+ let arr = scaled. to_array ( ) ;
322+ let bin = HALFTONE_TO_BIN [ hi] as usize ;
323+ for lane in 0 ..8 {
324+ even_dims[ lane] [ bin] =
325+ arr[ lane] . round ( ) . clamp ( -32768.0 , 32767.0 ) as i16 ;
326+ }
327+ }
328+ }
329+
330+ // Odd bins: interpolate from neighbors
331+ for lane in 0 ..8 {
332+ for odd in ( 1 ..BASE_DIM ) . step_by ( 2 ) {
333+ let left = even_dims[ lane] [ odd - 1 ] as i32 ;
334+ let right = even_dims[ lane] [ ( odd + 1 ) % BASE_DIM ] as i32 ;
335+ even_dims[ lane] [ odd] = ( ( left + right) / 2 ) as i16 ;
336+ }
337+ result. push ( Base17 { dims : even_dims[ lane] } ) ;
338+ }
339+ }
340+
341+ // Scalar tail for remaining rows (< 8)
342+ for r in ( full_batches * 8 ) ..n_rows {
343+ let start = r * n_cols;
344+ let end = ( start + n_cols) . min ( buf. len ( ) ) ;
345+ result. push ( project_row_bf16_strided ( & buf[ start..end] , octave_stride) ) ;
346+ }
347+
348+ result
349+ }
350+
254351/// Read a BF16 tensor as raw u16 values. NO f32 conversion.
255352/// `buf` is reusable — caller allocates once, passes to every tensor.
256353pub fn read_tensor_bf16_raw < R : Read + Seek > (
@@ -346,18 +443,19 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
346443 let n_elements = read_tensor_bf16_raw ( reader, & gguf_header, tensor, & mut bf16_buf) ?;
347444 let ( n_rows, n_cols) = tensor_to_rows_dims ( & tensor. dimensions , & layer_type) ;
348445
349- let mut rows = Vec :: with_capacity ( n_rows) ;
350- for r in 0 ..n_rows {
351- let start = r * n_cols;
352- let end = ( start + n_cols) . min ( n_elements) ;
353- let row_slice = & bf16_buf[ start..end] ;
354- let b17 = if octave_stride > 1 {
355- project_row_bf16_strided ( row_slice, octave_stride)
356- } else {
357- project_row_bf16_direct ( row_slice)
358- } ;
359- rows. push ( b17) ;
360- }
446+ // F64x8: 8 rows parallel, SIMD accumulation per halftone bin
447+ let rows = if octave_stride > 1 {
448+ project_tensor_bf16_simd ( & bf16_buf[ ..n_elements] , n_rows, n_cols, octave_stride)
449+ } else {
450+ // Full precision: scalar per-row (stride=1 doesn't benefit from SIMD halftone)
451+ let mut rows = Vec :: with_capacity ( n_rows) ;
452+ for r in 0 ..n_rows {
453+ let start = r * n_cols;
454+ let end = ( start + n_cols) . min ( n_elements) ;
455+ rows. push ( project_row_bf16_direct ( & bf16_buf[ start..end] ) ) ;
456+ }
457+ rows
458+ } ;
361459
362460 let orig_bytes = ( n_rows * n_cols * 4 ) as u64 ;
363461 let comp_bytes = ( rows. len ( ) * Base17 :: BYTE_SIZE ) as u64 ;
0 commit comments