Skip to content

Commit 9c19a20

Browse files
authored
Merge pull request #105 from AdaWorldAPI/claude/teleport-session-setup-wMZfb
feat(hpc): WHT polyfill, i2 quantization, pub kmeans
2 parents 6609f10 + eb27a23 commit 9c19a20

7 files changed

Lines changed: 184 additions & 6 deletions

File tree

.claude/agents/l3-strategist.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description: >
66
mapping rustynum capabilities to ndarray's trait system, identifying
77
architectural gaps, or planning multi-phase implementation roadmaps.
88
tools: Read, Glob, Grep, Bash
9-
model: sonnet
9+
model: opus
1010
---
1111

1212
You are the L3_STRATEGIST for Project NDARRAY Expansion.

.claude/agents/migration-tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ description: >
55
Updates blackboard. Identifies gaps. Prevents duplication.
66
READ ONLY — never writes code.
77
tools: Read, Glob, Grep, Bash
8-
model: sonnet
8+
model: opus
99
---
1010

1111
# Migration Tracker

.claude/agents/product-engineer.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description: >
66
finalizing API surface, writing doc comments, designing error types,
77
managing feature flags, or ensuring the crate is publishable.
88
tools: Read, Glob, Grep, Bash, Edit, Write
9-
model: sonnet
9+
model: opus
1010
---
1111

1212
You are the PRODUCT_ENGINEER for Project NDARRAY Expansion.

.claude/agents/vector-synthesis.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description: >
66
similarity search kernels, ndarray↔vector store bridges,
77
distance metrics (cosine, L2, dot product), or batch vector operations.
88
tools: Read, Glob, Grep, Bash, Edit
9-
model: sonnet
9+
model: opus
1010
---
1111

1212
You are the VECTOR_SYNTHESIS_EXPERT for Project NDARRAY Expansion.

src/hpc/cam_pq.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ pub fn train_hybrid(
460460
/// For 16D subvectors (CAM-PQ subspace dimension), this is one F32x16
461461
/// load-subtract-multiply-reduce. Consumer never sees hardware details.
462462
#[inline(always)]
463-
fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
463+
pub fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
464464
debug_assert_eq!(a.len(), b.len());
465465
let n = a.len();
466466

@@ -518,7 +518,7 @@ fn jaccard_similarity(a: &[String], b: &[String]) -> f32 {
518518
/// Simple k-means clustering.
519519
///
520520
/// Returns `k` centroid vectors of length `dim`.
521-
fn kmeans(data: &[Vec<f32>], k: usize, dim: usize, iterations: usize) -> Vec<Vec<f32>> {
521+
pub fn kmeans(data: &[Vec<f32>], k: usize, dim: usize, iterations: usize) -> Vec<Vec<f32>> {
522522
let n = data.len();
523523
if n == 0 || k == 0 {
524524
return vec![vec![0.0; dim]; k];

src/hpc/fft.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,103 @@ pub fn rfft_f32(input: &[f32]) -> Vec<f32> {
136136
complex[..2 * out_len].to_vec()
137137
}
138138

139+
// ── Walsh-Hadamard Transform ──────────────────────────────────────
140+
//
141+
// The WHT is to quantization codecs what FFT is to signal processing:
142+
// an O(n log n) orthogonal rotation that spreads energy uniformly
143+
// across all coefficients. Unlike SVD (data-adaptive, O(n²k) training),
144+
// the Hadamard rotation is deterministic, free, and self-inverse.
145+
//
146+
// Used by the HadCascade codec (bgz-tensor) for residual rotation
147+
// before i4/i2 quantization. ICC 1.0000 on real model weights.
148+
149+
/// In-place Walsh-Hadamard Transform (normalized, self-inverse).
150+
///
151+
/// `data` length must be a power of 2. After transform, `||WHT(x)|| = ||x||`
152+
/// (energy-preserving). Applying WHT twice returns the original vector.
153+
///
154+
/// SIMD: uses F32x16 butterfly for blocks ≥ 16 elements.
155+
///
156+
/// # Example
157+
///
158+
/// ```
159+
/// use ndarray::hpc::fft::wht_f32;
160+
///
161+
/// let mut x = vec![1.0f32, 0.0, 0.0, 0.0];
162+
/// wht_f32(&mut x);
163+
/// assert!((x[0] - 0.5).abs() < 1e-6); // 1/sqrt(4) * 1 = 0.5
164+
///
165+
/// // Self-inverse: WHT(WHT(x)) = x
166+
/// wht_f32(&mut x);
167+
/// assert!((x[0] - 1.0).abs() < 1e-5);
168+
/// ```
169+
pub fn wht_f32(data: &mut [f32]) {
170+
let n = data.len();
171+
assert!(n.is_power_of_two(), "WHT length must be a power of 2");
172+
173+
let mut h = 1;
174+
while h < n {
175+
if h >= 16 {
176+
wht_butterfly_simd(data, n, h);
177+
} else {
178+
for i in (0..n).step_by(h * 2) {
179+
for j in i..i + h {
180+
let x = data[j];
181+
let y = data[j + h];
182+
data[j] = x + y;
183+
data[j + h] = x - y;
184+
}
185+
}
186+
}
187+
h *= 2;
188+
}
189+
190+
let norm = 1.0 / (n as f32).sqrt();
191+
let mut i = 0;
192+
while i + 16 <= n {
193+
use crate::simd::F32x16;
194+
let v = F32x16::from_slice(&data[i..]);
195+
let scaled = v * F32x16::splat(norm);
196+
scaled.copy_to_slice(&mut data[i..i + 16]);
197+
i += 16;
198+
}
199+
while i < n {
200+
data[i] *= norm;
201+
i += 1;
202+
}
203+
}
204+
205+
/// WHT butterfly for one level, SIMD-accelerated for h ≥ 16.
206+
fn wht_butterfly_simd(data: &mut [f32], n: usize, h: usize) {
207+
use crate::simd::F32x16;
208+
for i in (0..n).step_by(h * 2) {
209+
let mut j = 0;
210+
while j + 16 <= h {
211+
let a = F32x16::from_slice(&data[i + j..]);
212+
let b = F32x16::from_slice(&data[i + j + h..]);
213+
let sum = a + b;
214+
let diff = a - b;
215+
sum.copy_to_slice(&mut data[i + j..i + j + 16]);
216+
diff.copy_to_slice(&mut data[i + j + h..i + j + h + 16]);
217+
j += 16;
218+
}
219+
while j < h {
220+
let x = data[i + j];
221+
let y = data[i + j + h];
222+
data[i + j] = x + y;
223+
data[i + j + h] = x - y;
224+
j += 1;
225+
}
226+
}
227+
}
228+
229+
/// Convenience: WHT on a new vector (non-mutating).
230+
pub fn wht_f32_new(input: &[f32]) -> Vec<f32> {
231+
let mut out = input.to_vec();
232+
wht_f32(&mut out);
233+
out
234+
}
235+
139236
// ── Helpers ────────────────────────────────────────────────────────
140237

141238
fn bit_reverse_f32(data: &mut [f32], n: usize) {
@@ -197,6 +294,44 @@ mod tests {
197294
}
198295
}
199296

297+
#[test]
298+
fn test_wht_self_inverse() {
299+
let original = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
300+
let mut data = original.clone();
301+
wht_f32(&mut data);
302+
wht_f32(&mut data);
303+
for (a, b) in original.iter().zip(data.iter()) {
304+
assert!((a - b).abs() < 1e-5, "self-inverse: {} vs {}", a, b);
305+
}
306+
}
307+
308+
#[test]
309+
fn test_wht_energy_preservation() {
310+
let mut data = vec![1.0f32, -2.0, 3.0, -4.0];
311+
let norm_before: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
312+
wht_f32(&mut data);
313+
let norm_after: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
314+
assert!((norm_before - norm_after).abs() < 1e-4,
315+
"energy: {} vs {}", norm_before, norm_after);
316+
}
317+
318+
#[test]
319+
fn test_wht_large_simd() {
320+
let mut data: Vec<f32> = (0..1024).map(|i| (i as f32 * 0.618).sin()).collect();
321+
let original = data.clone();
322+
wht_f32(&mut data);
323+
// Norm preservation at 1024-d (hits SIMD path)
324+
let n_orig: f32 = original.iter().map(|x| x * x).sum::<f32>().sqrt();
325+
let n_wht: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
326+
assert!((n_orig - n_wht).abs() / n_orig < 1e-4,
327+
"SIMD WHT norm: {} vs {}", n_orig, n_wht);
328+
// Self-inverse
329+
wht_f32(&mut data);
330+
let max_err = original.iter().zip(data.iter())
331+
.map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
332+
assert!(max_err < 1e-3, "SIMD self-inverse max_err: {}", max_err);
333+
}
334+
200335
#[test]
201336
fn test_rfft() {
202337
let input = vec![1.0f32, 2.0, 3.0, 4.0];

src/hpc/quantized.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,49 @@ pub fn dequantize_i4_to_f32(packed: &[u8], params: &QuantParams, len: usize) ->
374374
result
375375
}
376376

377+
/// Quantize f32 to i2 (packed: four i2 values per byte, signed ±1).
378+
///
379+
/// Each value is clamped to {-1, 0, +1} after scaling by abs_max.
380+
/// Packing: 4 crumbs per byte, low bits first.
381+
/// Symmetric quantization with zero_point = 0.
382+
pub fn quantize_f32_to_i2(data: &[f32]) -> (Vec<u8>, QuantParams) {
383+
let abs_max = data.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
384+
let scale = if abs_max > 0.0 { abs_max } else { 1.0 };
385+
386+
let packed_len = (data.len() + 3) / 4;
387+
let mut packed = vec![0u8; packed_len];
388+
389+
for (i, &v) in data.iter().enumerate() {
390+
let q = (v / scale).round().clamp(-1.0, 1.0) as i8;
391+
let u = (q + 1) as u8; // map {-1,0,1} to {0,1,2}
392+
let shift = (i % 4) * 2;
393+
packed[i / 4] |= (u & 0x03) << shift;
394+
}
395+
396+
(
397+
packed,
398+
QuantParams {
399+
scale,
400+
zero_point: 0,
401+
min_val: -abs_max,
402+
max_val: abs_max,
403+
},
404+
)
405+
}
406+
407+
/// Dequantize i2 (packed) to f32.
408+
pub fn dequantize_i2_to_f32(packed: &[u8], params: &QuantParams, len: usize) -> Vec<f32> {
409+
let mut result = Vec::with_capacity(len);
410+
for i in 0..len {
411+
let byte = packed[i / 4];
412+
let shift = (i % 4) * 2;
413+
let u = (byte >> shift) & 0x03;
414+
let q = u as i8 - 1; // map {0,1,2} back to {-1,0,1}
415+
result.push(q as f32 * params.scale);
416+
}
417+
result
418+
}
419+
377420
#[cfg(test)]
378421
mod tests {
379422
use super::*;

0 commit comments

Comments
 (0)