Skip to content

Commit 763351b

Browse files
committed
F64x8 SIMD 8-row-parallel tensor projection
project_tensor_bf16_simd: processes 8 rows per SIMD batch using F64x8 accumulators (9 halftone bins × 8 lanes = 9 zmm registers). Per octave: 9 gather+vaddpd ops. For 5120-col at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch (vs 2.35M scalar ops). Integrated into stream_index_gguf_bf16 BF16 fast path. Scalar tail handles remainder rows (<8). https://claude.ai/code/session_01HmdXNPit7QsTCfhJFef3Ee
1 parent 8d2d372 commit 763351b

1 file changed

Lines changed: 110 additions & 12 deletions

File tree

src/hpc/gguf_indexer.rs

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,103 @@ pub fn project_row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 {
251251
Base17 { dims }
252252
}
253253

254+
// ── SIMD 8-row-parallel tensor projection ──
255+
256+
/// Project an entire BF16 tensor to Base17 using F64x8 SIMD.
257+
///
258+
/// Processes 8 rows in parallel per SIMD batch. Each of the 9 halftone bins
259+
/// holds an F64x8 accumulator (8 rows × 9 bins = 72 f64 lanes = 9 zmm registers).
260+
///
261+
/// Per sampled octave: 9 halftone positions × 8 bf16_to_f64 gathers → 9 vaddpd.
262+
/// For 5120-col rows at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch.
263+
pub fn project_tensor_bf16_simd(
264+
buf: &[u16],
265+
n_rows: usize,
266+
n_cols: usize,
267+
octave_stride: usize,
268+
) -> Vec<Base17> {
269+
use crate::simd::F64x8;
270+
271+
let n_octaves = (n_cols + BASE_DIM - 1) / BASE_DIM;
272+
let mut result = Vec::with_capacity(n_rows);
273+
274+
// Process 8 rows at a time
275+
let full_batches = n_rows / 8;
276+
let remainder = n_rows % 8;
277+
278+
for batch in 0..full_batches {
279+
let base_row = batch * 8;
280+
281+
// 9 halftone bins × F64x8 accumulators (8 rows per lane)
282+
let mut half_sum = [F64x8::splat(0.0); 9];
283+
let mut half_count = [0u32; 9]; // same count for all 8 rows (same n_cols)
284+
285+
let mut octave = 0;
286+
while octave < n_octaves {
287+
for hi in 0..9 {
288+
let dim = octave * BASE_DIM + HALFTONE_POS[hi] as usize;
289+
if dim < n_cols {
290+
// Gather 8 BF16 values (one per row) at column `dim`
291+
let vals = F64x8::from_array([
292+
bf16_to_f64(buf[(base_row + 0) * n_cols + dim]),
293+
bf16_to_f64(buf[(base_row + 1) * n_cols + dim]),
294+
bf16_to_f64(buf[(base_row + 2) * n_cols + dim]),
295+
bf16_to_f64(buf[(base_row + 3) * n_cols + dim]),
296+
bf16_to_f64(buf[(base_row + 4) * n_cols + dim]),
297+
bf16_to_f64(buf[(base_row + 5) * n_cols + dim]),
298+
bf16_to_f64(buf[(base_row + 6) * n_cols + dim]),
299+
bf16_to_f64(buf[(base_row + 7) * n_cols + dim]),
300+
]);
301+
half_sum[hi] = half_sum[hi] + vals;
302+
if batch == 0 || octave == 0 {
303+
// Count is same for all batches with same n_cols
304+
}
305+
half_count[hi] += 1;
306+
}
307+
}
308+
octave += octave_stride;
309+
}
310+
311+
// Finalize: convert 9 SIMD accumulators → 8 Base17 results
312+
// Even bins: mean × FP_SCALE, clamped to i16
313+
let mut even_dims = [[0i16; BASE_DIM]; 8];
314+
315+
for hi in 0..9 {
316+
if half_count[hi] > 0 {
317+
let count_v = F64x8::splat(half_count[hi] as f64);
318+
let scale_v = F64x8::splat(FP_SCALE);
319+
let mean_v = half_sum[hi] / count_v;
320+
let scaled = mean_v * scale_v;
321+
let arr = scaled.to_array();
322+
let bin = HALFTONE_TO_BIN[hi] as usize;
323+
for lane in 0..8 {
324+
even_dims[lane][bin] =
325+
arr[lane].round().clamp(-32768.0, 32767.0) as i16;
326+
}
327+
}
328+
}
329+
330+
// Odd bins: interpolate from neighbors
331+
for lane in 0..8 {
332+
for odd in (1..BASE_DIM).step_by(2) {
333+
let left = even_dims[lane][odd - 1] as i32;
334+
let right = even_dims[lane][(odd + 1) % BASE_DIM] as i32;
335+
even_dims[lane][odd] = ((left + right) / 2) as i16;
336+
}
337+
result.push(Base17 { dims: even_dims[lane] });
338+
}
339+
}
340+
341+
// Scalar tail for remaining rows (< 8)
342+
for r in (full_batches * 8)..n_rows {
343+
let start = r * n_cols;
344+
let end = (start + n_cols).min(buf.len());
345+
result.push(project_row_bf16_strided(&buf[start..end], octave_stride));
346+
}
347+
348+
result
349+
}
350+
254351
/// Read a BF16 tensor as raw u16 values. NO f32 conversion.
255352
/// `buf` is reusable — caller allocates once, passes to every tensor.
256353
pub fn read_tensor_bf16_raw<R: Read + Seek>(
@@ -346,18 +443,19 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
346443
let n_elements = read_tensor_bf16_raw(reader, &gguf_header, tensor, &mut bf16_buf)?;
347444
let (n_rows, n_cols) = tensor_to_rows_dims(&tensor.dimensions, &layer_type);
348445

349-
let mut rows = Vec::with_capacity(n_rows);
350-
for r in 0..n_rows {
351-
let start = r * n_cols;
352-
let end = (start + n_cols).min(n_elements);
353-
let row_slice = &bf16_buf[start..end];
354-
let b17 = if octave_stride > 1 {
355-
project_row_bf16_strided(row_slice, octave_stride)
356-
} else {
357-
project_row_bf16_direct(row_slice)
358-
};
359-
rows.push(b17);
360-
}
446+
// F64x8: 8 rows parallel, SIMD accumulation per halftone bin
447+
let rows = if octave_stride > 1 {
448+
project_tensor_bf16_simd(&bf16_buf[..n_elements], n_rows, n_cols, octave_stride)
449+
} else {
450+
// Full precision: scalar per-row (stride=1 doesn't benefit from SIMD halftone)
451+
let mut rows = Vec::with_capacity(n_rows);
452+
for r in 0..n_rows {
453+
let start = r * n_cols;
454+
let end = (start + n_cols).min(n_elements);
455+
rows.push(project_row_bf16_direct(&bf16_buf[start..end]));
456+
}
457+
rows
458+
};
361459

362460
let orig_bytes = (n_rows * n_cols * 4) as u64;
363461
let comp_bytes = (rows.len() * Base17::BYTE_SIZE) as u64;

0 commit comments

Comments
 (0)