diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index baab908..fff3d74 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -323,6 +323,7 @@ function _forward_eval( f.partials_storage[rhs] = zero(T) end end + @assert f.sizes.ndims[1] == 0 "Final result must be scalar, got ndims = $(f.sizes.ndims[1])" return f.forward_storage[1] end