diff --git a/src/operators.jl b/src/operators.jl
index 7a88b9f..93dc98b 100644
--- a/src/operators.jl
+++ b/src/operators.jl
@@ -18,6 +18,7 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
     :norm,
     :sum,
     :row,
+    :reduce,
 ]
 
 function _validate_register_assumptions(
diff --git a/src/parse.jl b/src/parse.jl
index 8382d2d..e6686e6 100644
--- a/src/parse.jl
+++ b/src/parse.jl
@@ -158,7 +158,9 @@ end
 
 function _parse_expression(stack, data, expr, x, parent_index)
     if Meta.isexpr(x, :call)
-        if length(x.args) == 2 && !Meta.isexpr(x.args[2], :...)
+        if x.args[1] == :reduce
+            _parse_reduce_expression(stack, data, expr, x, parent_index)
+        elseif length(x.args) == 2 && !Meta.isexpr(x.args[2], :...)
             _parse_univariate_expression(stack, data, expr, x, parent_index)
         else
             # The call is either n-ary, or it is a splat, in which case we
@@ -278,6 +280,36 @@ function _parse_vcat_expression(
     return
 end
 
+function _parse_reduce_expression(stack, data, expr, x, parent_index)
+    if length(x.args) != 3
+        error("Unsupported reduce expression: $x. Expected reduce(op, collection).")
+    end
+
+    op = x.args[2]
+    collection = x.args[3]
+
+    if !Meta.isexpr(collection, :vect)
+        error("Unsupported reduce collection: $collection. Expected a vector literal.")
+    end
+
+    args = collection.args
+
+    if isempty(args)
+        error("Unsupported reduce on empty collection.")
+    elseif length(args) == 1
+        push!(stack, (parent_index, args[1]))
+        return
+    end
+
+    folded = Expr(:call, op, args[1], args[2])
+    for i in 3:length(args)
+        folded = Expr(:call, op, folded, args[i])
+    end
+
+    push!(stack, (parent_index, folded))
+    return
+end
+
 function _parse_inequality_expression(
     stack::Vector{Tuple{Int,Any}},
     data::Model,
diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl
index 2a02822..c70fa87 100644
--- a/test/ArrayDiff.jl
+++ b/test/ArrayDiff.jl
@@ -634,6 +634,80 @@ function test_objective_broadcasted_tanh()
     return
 end
 
+function test_objective_reduce_sum()
+    model = ArrayDiff.Model()
+    x1 = MOI.VariableIndex(1)
+    x2 = MOI.VariableIndex(2)
+    x3 = MOI.VariableIndex(3)
+    ArrayDiff.set_objective(model, :(reduce(+, [$x1, $x2, $x3])))
+    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3])
+    MOI.initialize(evaluator, [:Grad])
+    sizes = evaluator.backend.objective.expr.sizes
+    @test sizes.ndims == [0, 0, 0, 0, 0]
+    @test sizes.size_offset == [0, 0, 0, 0, 0]
+    @test sizes.size == []
+    @test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
+    x1 = 1.0
+    x2 = 2.0
+    x3 = 3.0
+    @test MOI.eval_objective(evaluator, [x1, x2, x3]) == 6.0
+    g = ones(3)
+    MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3])
+    @test g == [1.0, 1.0, 1.0]
+    return
+end
+
+function test_objective_reduce_prod()
+    model = ArrayDiff.Model()
+    x1 = MOI.VariableIndex(1)
+    x2 = MOI.VariableIndex(2)
+    x3 = MOI.VariableIndex(3)
+    ArrayDiff.set_objective(model, :(reduce(*, [$x1, $x2, $x3])))
+    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3])
+    MOI.initialize(evaluator, [:Grad])
+    sizes = evaluator.backend.objective.expr.sizes
+    @test sizes.ndims == [0, 0, 0, 0, 0]
+    @test sizes.size_offset == [0, 0, 0, 0, 0]
+    @test sizes.size == []
+    @test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
+    x1 = 1.0
+    x2 = 2.0
+    x3 = 3.0
+    @test MOI.eval_objective(evaluator, [x1, x2, x3]) == 6.0
+    g = ones(3)
+    MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3])
+    @test g == [6.0 / x1, 6.0 / x2, 6.0 / x3]
+    return
+end
+
+function test_objective_reduce_atan()
+    model = ArrayDiff.Model()
+    x1 = MOI.VariableIndex(1)
+    x2 = MOI.VariableIndex(2)
+    x3 = MOI.VariableIndex(3)
+    ArrayDiff.set_objective(model, :(reduce(atan, [$x1, $x2, $x3])))
+    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3])
+    MOI.initialize(evaluator, [:Grad])
+    sizes = evaluator.backend.objective.expr.sizes
+    @test sizes.ndims == [0, 0, 0, 0, 0]
+    @test sizes.size_offset == [0, 0, 0, 0, 0]
+    @test sizes.size == []
+    @test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
+    x1 = 1.0
+    x2 = 2.0
+    x3 = 3.0
+    @test MOI.eval_objective(evaluator, [x1, x2, x3]) ==
+          atan(atan(x1, x2), x3)
+    g = ones(3)
+    MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3])
+    @test g ≈ [
+        x2 * x3 / ((x1^2 + x2^2) * (x3^2 + atan(x1, x2)^2)),
+        -x1 * x3 / ((x1^2 + x2^2) * (x3^2 + atan(x1, x2)^2)),
+        -atan(x1, x2) / (x3^2 + atan(x1, x2)^2),
+    ]
+    return
+end
+
 end  # module
 
 TestArrayDiff.runtests()