Skip to content

Commit 5ac1aec

Browse files
committed
fix(reductions): mean_f32/f64 return Option per A10 spec
Spec for sprint A10 calls for: pub fn mean_f32(s: &[f32]) -> Option<f32>; // None on empty Previously mean_f32/mean_f64 panicked on empty input. This change returns None for empty slices, matching argmax_f32 / max_f32 / min_f32 which already use the Option convention. Tests: - mean_f32_empty_is_none — verifies None on empty input - mean_f64_empty_is_none — verifies None on empty input - mean_f32_basic — non-empty case via .expect()
1 parent 6a52b78 commit 5ac1aec

1 file changed

Lines changed: 34 additions & 42 deletions

File tree

src/hpc/reductions.rs

Lines changed: 34 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
//!
1616
//! # Empty-slice convention
1717
//!
18-
//! Unbounded reductions (`max`, `min`, `argmax`, `argmin`) return
19-
//! [`Option`]. Sums and norms have a well-defined zero element and
20-
//! return `0.0` (or `0.0` for `nrm2`). [`mean_f32`] panics on an empty
21-
//! slice — there is no meaningful answer.
18+
//! Unbounded reductions (`max`, `min`, `argmax`, `argmin`, `mean`) return
19+
//! [`Option`] — they have no defined value on an empty slice. Sums and
20+
//! norms have a well-defined zero element and return `0.0`.
2221
//!
2322
//! # Numerical notes
2423
//!
@@ -100,24 +99,24 @@ pub fn sum_f64(s: &[f64]) -> f64 {
10099
sum
101100
}
102101

103-
/// Arithmetic mean of all elements.
104-
///
105-
/// # Panics
106-
/// Panics if `s` is empty.
102+
/// Arithmetic mean of all elements. Returns `None` for an empty slice.
107103
#[inline]
108-
pub fn mean_f32(s: &[f32]) -> f32 {
109-
assert!(!s.is_empty(), "mean_f32: empty slice has no mean");
110-
sum_f32(s) / s.len() as f32
104+
pub fn mean_f32(s: &[f32]) -> Option<f32> {
105+
if s.is_empty() {
106+
None
107+
} else {
108+
Some(sum_f32(s) / s.len() as f32)
109+
}
111110
}
112111

113-
/// Arithmetic mean of all elements as `f64`.
114-
///
115-
/// # Panics
116-
/// Panics if `s` is empty.
112+
/// Arithmetic mean of all elements as `f64`. Returns `None` for an empty slice.
117113
#[inline]
118-
pub fn mean_f64(s: &[f64]) -> f64 {
119-
assert!(!s.is_empty(), "mean_f64: empty slice has no mean");
120-
sum_f64(s) / s.len() as f64
114+
pub fn mean_f64(s: &[f64]) -> Option<f64> {
115+
if s.is_empty() {
116+
None
117+
} else {
118+
Some(sum_f64(s) / s.len() as f64)
119+
}
121120
}
122121

123122
// ===========================================================================
@@ -229,10 +228,7 @@ pub fn argmax_f32(s: &[f32]) -> Option<usize> {
229228
best_vals = mask.select(v, best_vals);
230229
// Update indices via f32 bit-blend (U32x16 has no native blend
231230
// helper but f32-mask blend is bit-exact for any 32-bit pattern).
232-
let new_idx_f = mask.select(
233-
F32x16::from_bits(current_indices),
234-
F32x16::from_bits(best_idx_bits),
235-
);
231+
let new_idx_f = mask.select(F32x16::from_bits(current_indices), F32x16::from_bits(best_idx_bits));
236232
best_idx_bits = new_idx_f.to_bits();
237233

238234
// Advance lane indices by 16 for the next chunk.
@@ -297,10 +293,7 @@ pub fn argmin_f32(s: &[f32]) -> Option<usize> {
297293

298294
let mask = v.simd_lt(best_vals);
299295
best_vals = mask.select(v, best_vals);
300-
let new_idx_f = mask.select(
301-
F32x16::from_bits(current_indices),
302-
F32x16::from_bits(best_idx_bits),
303-
);
296+
let new_idx_f = mask.select(F32x16::from_bits(current_indices), F32x16::from_bits(best_idx_bits));
304297
best_idx_bits = new_idx_f.to_bits();
305298
current_indices = current_indices + lane_step;
306299
}
@@ -435,13 +428,18 @@ mod tests {
435428
#[test]
436429
fn mean_f32_basic() {
437430
let v = [1.0_f32, 2.0, 3.0, 4.0];
438-
assert!((mean_f32(&v) - 2.5).abs() < 1e-6);
431+
let m = mean_f32(&v).expect("non-empty");
432+
assert!((m - 2.5).abs() < 1e-6);
433+
}
434+
435+
#[test]
436+
fn mean_f32_empty_is_none() {
437+
assert_eq!(mean_f32(&[]), None);
439438
}
440439

441440
#[test]
442-
#[should_panic(expected = "empty slice has no mean")]
443-
fn mean_f32_empty_panics() {
444-
let _ = mean_f32(&[]);
441+
fn mean_f64_empty_is_none() {
442+
assert_eq!(mean_f64(&[]), None);
445443
}
446444

447445
// ---- max_f32 / min_f32 ------------------------------------------------
@@ -478,7 +476,9 @@ mod tests {
478476
#[test]
479477
fn max_min_misaligned() {
480478
for &n in &[1_usize, 7, 16, 17, 31, 33, 64, 65, 127, 256, 1023] {
481-
let v: Vec<f32> = (0..n).map(|i| ((i as i32) - (n as i32) / 2) as f32).collect();
479+
let v: Vec<f32> = (0..n)
480+
.map(|i| ((i as i32) - (n as i32) / 2) as f32)
481+
.collect();
482482
let expected_max = v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
483483
let expected_min = v.iter().copied().fold(f32::INFINITY, f32::min);
484484
assert_eq!(max_f32(&v), Some(expected_max), "max_f32 n={}", n);
@@ -513,17 +513,9 @@ mod tests {
513513
#[test]
514514
fn argmax_f32_misaligned_tail() {
515515
// Place the maximum at a position straddling the SIMD/tail boundary.
516-
for &(n, peak) in &[
517-
(17_usize, 16),
518-
(17, 0),
519-
(17, 8),
520-
(33, 32),
521-
(33, 17),
522-
(65, 64),
523-
(65, 32),
524-
(127, 100),
525-
(1000, 999),
526-
] {
516+
for &(n, peak) in
517+
&[(17_usize, 16), (17, 0), (17, 8), (33, 32), (33, 17), (65, 64), (65, 32), (127, 100), (1000, 999)]
518+
{
527519
let mut v: Vec<f32> = vec![0.0; n];
528520
v[peak] = 1.0;
529521
assert_eq!(argmax_f32(&v), Some(peak), "n={}, peak={}", n, peak);

0 commit comments

Comments
 (0)