@@ -388,18 +388,8 @@ fn log_softmax_f32_scalar(x: ArrayView1<f32>, mut out: ArrayViewMut1<f32>) {
388388/// assert!((out[[1, 0]] - 1.0 / 3.0).abs() < 1e-5);
389389/// ```
390390pub fn softmax_axis_f32 ( x : ArrayView2 < f32 > , mut out : ArrayViewMut2 < f32 > , axis : Axis ) {
391- assert ! (
392- axis. index( ) < 2 ,
393- "softmax_axis_f32: axis {} is out of bounds for a 2-D array" ,
394- axis. index( )
395- ) ;
396- assert_eq ! (
397- x. shape( ) ,
398- out. shape( ) ,
399- "softmax_axis_f32: shape mismatch (x={:?} out={:?})" ,
400- x. shape( ) ,
401- out. shape( )
402- ) ;
391+ assert ! ( axis. index( ) < 2 , "softmax_axis_f32: axis {} is out of bounds for a 2-D array" , axis. index( ) ) ;
392+ assert_eq ! ( x. shape( ) , out. shape( ) , "softmax_axis_f32: shape mismatch (x={:?} out={:?})" , x. shape( ) , out. shape( ) ) ;
403393 // `lanes(axis)` yields 1-D views ALONG `axis`; `lanes_mut(axis)` yields
404394 // the corresponding mutable 1-D views of `out`. Zipping them visits every
405395 // lane exactly once.
@@ -450,11 +440,7 @@ pub fn softmax_axis_f32(x: ArrayView2<f32>, mut out: ArrayViewMut2<f32>, axis: A
450440/// assert!((out[[1, 2]] - expected).abs() < 1e-5);
451441/// ```
452442pub fn log_softmax_axis_f32 ( x : ArrayView2 < f32 > , mut out : ArrayViewMut2 < f32 > , axis : Axis ) {
453- assert ! (
454- axis. index( ) < 2 ,
455- "log_softmax_axis_f32: axis {} is out of bounds for a 2-D array" ,
456- axis. index( )
457- ) ;
443+ assert ! ( axis. index( ) < 2 , "log_softmax_axis_f32: axis {} is out of bounds for a 2-D array" , axis. index( ) ) ;
458444 assert_eq ! (
459445 x. shape( ) ,
460446 out. shape( ) ,
@@ -776,7 +762,13 @@ mod tests {
776762 softmax_f32 ( row. view ( ) , out_1d. view_mut ( ) ) ;
777763
778764 for j in 0 ..4 {
779- assert ! ( ( out_axis[ [ 0 , j] ] - out_1d[ j] ) . abs( ) < 1e-6 , "j={}: axis={} vs 1d={}" , j, out_axis[ [ 0 , j] ] , out_1d[ j] ) ;
765+ assert ! (
766+ ( out_axis[ [ 0 , j] ] - out_1d[ j] ) . abs( ) < 1e-6 ,
767+ "j={}: axis={} vs 1d={}" ,
768+ j,
769+ out_axis[ [ 0 , j] ] ,
770+ out_1d[ j]
771+ ) ;
780772 }
781773 }
782774
@@ -850,7 +842,10 @@ mod tests {
850842 assert ! (
851843 ( out_axis[ [ i, j] ] - out_1d[ j] ) . abs( ) < 1e-5 ,
852844 "row={} j={}: axis={} vs 1d={}" ,
853- i, j, out_axis[ [ i, j] ] , out_1d[ j]
845+ i,
846+ j,
847+ out_axis[ [ i, j] ] ,
848+ out_1d[ j]
854849 ) ;
855850 }
856851 }
0 commit comments