Skip to content

Commit b5a63fc

Browse files
committed
fix(hpc/activations): sigmoid_f32 stride mismatch (Codex PR #154)
Codex flagged: same-shaped contiguous views with different memory orders (C-order input + F-order output) both succeeded at as_slice_memory_order but with mismatched logical indexing — the flat SIMD primitive wrote sigmoid values into the wrong output coordinates. Fix: add the same strides-equality guard that hpc/vml.rs already uses in dispatch_unary_contig / dispatch_binary_contig. Mismatched-stride inputs now route to the stride-aware Zip cold path. Adds test_sigmoid_f32_c_in_f_out_mismatched_strides regression: 2x2 C-order input, F-order zero-init output, asserts logical coordinates carry correct sigmoid values. Activations test count: 16 -> 17. Reductions are unaffected (read-only commutative/associative — memory order doesn't change the scalar result). vml unary/binary already guarded via dispatch_*_contig.
1 parent 66a8a81 commit b5a63fc

1 file changed

Lines changed: 39 additions & 4 deletions

File tree

src/hpc/activations.rs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,19 @@ where
9393
/// ```
9494
pub fn sigmoid_f32<D: Dimension>(x: ArrayView<f32, D>, mut out: ArrayViewMut<f32, D>) {
9595
assert_eq!(x.shape(), out.shape(), "sigmoid_f32: shape mismatch (x={:?} out={:?})", x.shape(), out.shape());
96-
if let (Some(xs), Some(os)) = (x.as_slice_memory_order(), out.as_slice_memory_order_mut()) {
97-
sigmoid_f32_slice(xs, os);
98-
return;
96+
// HOT PATH guard: input + output must share strides AND each be contiguous
97+
// in their own memory order. Without the strides-equality check a C-order
98+
// input + F-order output (same shape, both individually contiguous) would
99+
// both succeed at `as_slice_memory_order` but with mismatched logical
100+
// indexing — writing the wrong sigmoid value into each output coordinate.
101+
// Matches the dispatch_unary_contig guard in `hpc/vml.rs`.
102+
if x.strides() == out.strides() {
103+
if let (Some(xs), Some(os)) = (x.as_slice_memory_order(), out.as_slice_memory_order_mut()) {
104+
sigmoid_f32_slice(xs, os);
105+
return;
106+
}
99107
}
100-
// Cold path: non-contiguous views (sliced/transposed) — stride-aware scalar.
108+
// Cold path: non-contiguous views OR mismatched memory orders — stride-aware scalar.
101109
Zip::from(&mut out)
102110
.and(x)
103111
.for_each(|o, &v| *o = 1.0 / (1.0 + (-v).exp()));
@@ -406,6 +414,33 @@ mod tests {
406414
}
407415
}
408416

417+
#[test]
418+
fn test_sigmoid_f32_c_in_f_out_mismatched_strides() {
419+
// Regression for Codex PR #154 finding: same-shaped contig views with
420+
// different memory orders (C-order input + F-order output) both pass
421+
// `as_slice_memory_order` but with mismatched logical indexing. Without
422+
// the strides-equality guard, the flat SIMD primitive writes sigmoid
423+
// values into the wrong output coordinates. The fix re-routes such
424+
// cases to the stride-aware Zip cold path.
425+
use crate::{Array, Array2, ShapeBuilder};
426+
let x: Array2<f32> = arr2(&[[0.0_f32, 100.0], [-100.0, 0.0]]); // C-order
427+
// F-order output of the same shape, both individually contiguous,
428+
// but `x.strides() != out.strides()`.
429+
let mut out: Array2<f32> = Array::zeros((2, 2).f());
430+
assert!(x.as_slice_memory_order().is_some());
431+
assert!(out.as_slice_memory_order().is_some());
432+
assert_ne!(x.strides(), out.strides(), "test setup: strides must differ");
433+
434+
sigmoid_f32(x.view(), out.view_mut());
435+
436+
// Logical coordinates must carry the right sigmoid values regardless
437+
// of the underlying memory order.
438+
assert!((out[[0, 0]] - 0.5).abs() < 1e-6, "sigmoid(0) at [0,0] = {}", out[[0, 0]]);
439+
assert!((out[[0, 1]] - 1.0).abs() < 1e-4, "sigmoid(100) at [0,1] = {}", out[[0, 1]]);
440+
assert!((out[[1, 0]] - 0.0).abs() < 1e-4, "sigmoid(-100) at [1,0] = {}", out[[1, 0]]);
441+
assert!((out[[1, 1]] - 0.5).abs() < 1e-6, "sigmoid(0) at [1,1] = {}", out[[1, 1]]);
442+
}
443+
409444
#[test]
410445
fn test_sigmoid_f32_2d() {
411446
// Generic-D verification: 2-D contiguous input works

0 commit comments

Comments
 (0)