Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
:norm,
:sum,
:row,
:reduce,
]

function _validate_register_assumptions(
Expand Down
34 changes: 33 additions & 1 deletion src/parse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ end

function _parse_expression(stack, data, expr, x, parent_index)
if Meta.isexpr(x, :call)
if length(x.args) == 2 && !Meta.isexpr(x.args[2], :...)
if x.args[1] == :reduce
_parse_reduce_expression(stack, data, expr, x, parent_index)
elseif length(x.args) == 2 && !Meta.isexpr(x.args[2], :...)
_parse_univariate_expression(stack, data, expr, x, parent_index)
else
# The call is either n-ary, or it is a splat, in which case we
Expand Down Expand Up @@ -278,6 +280,36 @@ function _parse_vcat_expression(
return
end

"""
    _parse_reduce_expression(stack, data, expr, x, parent_index)

Rewrite a `reduce(op, [a, b, ...])` call into the left-associated binary chain
`op(op(a, b), ...)` and push the result onto `stack` so the regular expression
parser handles it on a later iteration.

Only non-empty vector literals (`:vect` expressions) are supported as the
collection. The `data` and `expr` arguments are accepted for signature
consistency with the sibling `_parse_*_expression` helpers; they are not read
here.
"""
function _parse_reduce_expression(stack, data, expr, x, parent_index)
    length(x.args) == 3 ||
        error("Unsupported reduce expression: $x. Expected reduce(op, collection).")
    op, collection = x.args[2], x.args[3]
    Meta.isexpr(collection, :vect) ||
        error("Unsupported reduce collection: $collection. Expected a vector literal.")
    elements = collection.args
    isempty(elements) && error("Unsupported reduce on empty collection.")
    # Left fold, matching Base.reduce for these operators. A single-element
    # collection pushes the element unchanged.
    acc = elements[1]
    for element in elements[2:end]
        acc = Expr(:call, op, acc, element)
    end
    push!(stack, (parent_index, acc))
    return
end

function _parse_inequality_expression(
stack::Vector{Tuple{Int,Any}},
data::Model,
Expand Down
74 changes: 74 additions & 0 deletions test/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,80 @@ function test_objective_broadcasted_tanh()
return
end

# reduce(+, [x1, x2, x3]) should parse to the scalar chain (x1 + x2) + x3 and
# evaluate/differentiate like a plain sum.
function test_objective_reduce_sum()
    model = ArrayDiff.Model()
    vars = MOI.VariableIndex.(1:3)
    obj = :(reduce(+, [$(vars[1]), $(vars[2]), $(vars[3])]))
    ArrayDiff.set_objective(model, obj)
    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), vars)
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    # Five scalar nodes: two binary + calls and three variables.
    @test sizes.ndims == fill(0, 5)
    @test sizes.size_offset == fill(0, 5)
    @test sizes.size == []
    @test sizes.storage_offset == collect(0:5)
    point = [1.0, 2.0, 3.0]
    @test MOI.eval_objective(evaluator, point) == 6.0
    g = ones(3)
    MOI.eval_objective_gradient(evaluator, g, point)
    # d(sum)/dxi == 1 for every variable.
    @test g == [1.0, 1.0, 1.0]
    return
end

# reduce(*, [x1, x2, x3]) should parse to the scalar chain (x1 * x2) * x3 and
# evaluate/differentiate like a plain product.
function test_objective_reduce_prod()
    model = ArrayDiff.Model()
    x1 = MOI.VariableIndex(1)
    x2 = MOI.VariableIndex(2)
    x3 = MOI.VariableIndex(3)
    ArrayDiff.set_objective(model, :(reduce(*, [$x1, $x2, $x3])))
    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3])
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    @test sizes.ndims == [0, 0, 0, 0, 0]
    @test sizes.size_offset == [0, 0, 0, 0, 0]
    @test sizes.size == []
    @test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
    # Use (2, 3, 4) rather than (1, 2, 3): with the latter the product and the
    # sum are both 6.0, so the objective check could not tell reduce(*) apart
    # from reduce(+). Here the product (24.0) differs from the sum (9.0), and
    # all values below are exactly representable floats.
    v1 = 2.0
    v2 = 3.0
    v3 = 4.0
    p = v1 * v2 * v3
    @test MOI.eval_objective(evaluator, [v1, v2, v3]) == p
    g = ones(3)
    MOI.eval_objective_gradient(evaluator, g, [v1, v2, v3])
    # d(prod)/dxi == prod / xi; every quotient is exact for these values.
    @test g == [p / v1, p / v2, p / v3]
    return
end

# reduce(atan, [x1, x2, x3]) must fold left: atan(atan(x1, x2), x3), exercising
# a non-commutative two-argument operator.
function test_objective_reduce_atan()
    model = ArrayDiff.Model()
    vars = MOI.VariableIndex.(1:3)
    obj = :(reduce(atan, [$(vars[1]), $(vars[2]), $(vars[3])]))
    ArrayDiff.set_objective(model, obj)
    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), vars)
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    # Five scalar nodes: two binary atan calls and three variables.
    @test sizes.ndims == fill(0, 5)
    @test sizes.size_offset == fill(0, 5)
    @test sizes.size == []
    @test sizes.storage_offset == collect(0:5)
    a, b, c = 1.0, 2.0, 3.0
    @test MOI.eval_objective(evaluator, [a, b, c]) == atan(atan(a, b), c)
    g = ones(3)
    MOI.eval_objective_gradient(evaluator, g, [a, b, c])
    # Chain rule through atan(inner, c) with inner = atan(a, b):
    # d/da inner = b / (a^2 + b^2), d/db inner = -a / (a^2 + b^2).
    inner = atan(a, b)
    outer_denom = c^2 + inner^2
    @test g ≈ [
        b * c / ((a^2 + b^2) * outer_denom),
        -a * c / ((a^2 + b^2) * outer_denom),
        -inner / outer_denom,
    ]
    return
end

end # module

TestArrayDiff.runtests()
Loading