diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl
index 0106b0c..16ef670 100644
--- a/src/reverse_mode.jl
+++ b/src/reverse_mode.jl
@@ -247,6 +247,28 @@ function _forward_eval(
                 tmp_dot += v1 * v2
             end
             @s f.forward_storage[k] = tmp_dot
+        elseif node.index == 12 # hcat
+            idx1, idx2 = children_indices
+            ix1 = children_arr[idx1]
+            ix2 = children_arr[idx2]
+            nb_cols1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
+            col_size = f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
+            for j in _eachindex(f.sizes, ix1)
+                @j f.partials_storage[ix1] = one(T)
+                val = @j f.forward_storage[ix1]
+                @j f.forward_storage[k] = val
+            end
+            for j in _eachindex(f.sizes, ix2)
+                @j f.partials_storage[ix2] = one(T)
+                val = @j f.forward_storage[ix2]
+                _setindex!(
+                    f.forward_storage,
+                    val,
+                    f.sizes,
+                    k,
+                    j + nb_cols1 * col_size,
+                )
+            end
         elseif node.index == 14 # norm
             ix = children_arr[children_indices[1]]
             tmp_norm_squared = zero(T)
@@ -395,6 +417,50 @@ function _reverse_eval(f::_SubexpressionStorage)
                        end
                    end
                    continue
+                elseif op == :hcat
+                    idx1, idx2 = children_indices
+                    ix1 = children_arr[idx1]
+                    ix2 = children_arr[idx2]
+                    nb_cols1 =
+                        f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
+                    col_size =
+                        f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
+                    for j in _eachindex(f.sizes, ix1)
+                        partial = @j f.partials_storage[ix1]
+                        val = ifelse(
+                            _getindex(f.reverse_storage, f.sizes, k, j) ==
+                            0.0 && !isfinite(partial),
+                            _getindex(f.reverse_storage, f.sizes, k, j),
+                            _getindex(f.reverse_storage, f.sizes, k, j) *
+                            partial,
+                        )
+                        @j f.reverse_storage[ix1] = val
+                    end
+                    for j in _eachindex(f.sizes, ix2)
+                        partial = @j f.partials_storage[ix2]
+                        val = ifelse(
+                            _getindex(
+                                f.reverse_storage,
+                                f.sizes,
+                                k,
+                                j + nb_cols1 * col_size,
+                            ) == 0.0 && !isfinite(partial),
+                            _getindex(
+                                f.reverse_storage,
+                                f.sizes,
+                                k,
+                                j + nb_cols1 * col_size,
+                            ),
+                            _getindex(
+                                f.reverse_storage,
+                                f.sizes,
+                                k,
+                                j + nb_cols1 * col_size,
+                            ) * partial,
+                        )
+                        @j f.reverse_storage[ix2] = val
+                    end
+                    continue
                 elseif op == :norm
                     # Node `k` is scalar, the jacobian w.r.t. the vectorized input
                     # child is a row vector whose entries are stored in `f.partials_storage`
@@ -408,7 +474,7 @@ function _reverse_eval(f::_SubexpressionStorage)
                             rev_parent,
                             rev_parent * partial,
                         )
-                        @j f.reverse_storage[ix] = val 
+                        @j f.reverse_storage[ix] = val
                     end
                     continue
                 end
diff --git a/src/sizes.jl b/src/sizes.jl
index 9fcb562..47aefbd 100644
--- a/src/sizes.jl
+++ b/src/sizes.jl
@@ -186,6 +186,25 @@ function _infer_sizes(
     elseif op == :+ || op == :-
         # TODO assert all arguments have same size
         _copy_size!(sizes, k, children_arr[first(children_indices)])
+    elseif op == :hcat
+        total_cols = 0
+        for c_idx in children_indices
+            total_cols +=
+                sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
+                _size(sizes, children_arr[c_idx], 2)
+        end
+        if sizes.ndims[children_arr[first(children_indices)]] == 0
+            shape = (1, total_cols)
+        else
+            @assert sizes.ndims[children_arr[first(
+                children_indices,
+            )]] <= 2 "Hcat with ndims > 2 is not supported yet"
+            shape = (
+                _size(sizes, children_arr[first(children_indices)], 1),
+                total_cols,
+            )
+        end
+        _add_size!(sizes, k, tuple(shape...))
     elseif op == :*
         # TODO assert compatible sizes and all ndims should be 0 or 2
         first_matrix = findfirst(children_indices) do i
diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl
index d7b8efe..4bd270b 100644
--- a/test/ArrayDiff.jl
+++ b/test/ArrayDiff.jl
@@ -64,6 +64,59 @@ function test_objective_dot_bivariate()
     return
 end
 
+function test_objective_hcat_0dim()
+    model = Nonlinear.Model()
+    x1 = MOI.VariableIndex(1)
+    x2 = MOI.VariableIndex(2)
+    x3 = MOI.VariableIndex(3)
+    x4 = MOI.VariableIndex(4)
+    Nonlinear.set_objective(model, :(dot([$x1 $x3], [$x2 $x4])))
+    evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
+    MOI.initialize(evaluator, [:Grad])
+    sizes = evaluator.backend.objective.expr.sizes
+    @test sizes.ndims == [0, 2, 0, 0, 2, 0, 0]
+    @test sizes.size_offset == [0, 2, 0, 0, 0, 0, 0]
+    @test sizes.size == [1, 2, 1, 2]
+    @test sizes.storage_offset == [0, 1, 3, 4, 5, 7, 8, 9]
+    x1 = 1.0
+    x2 = 2.0
+    x3 = 3.0
+    x4 = 4.0
+    @test MOI.eval_objective(evaluator, [x1, x2, x3, x4]) == 14.0
+    g = ones(4)
+    MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3, x4])
+    @test g == [2.0, 1.0, 4.0, 3.0]
+    return
+end
+
+function test_objective_hcat_1dim()
+    model = Nonlinear.Model()
+    x1 = MOI.VariableIndex(1)
+    x2 = MOI.VariableIndex(2)
+    x3 = MOI.VariableIndex(3)
+    x4 = MOI.VariableIndex(4)
+    Nonlinear.set_objective(
+        model,
+        :(dot(hcat([$x1], [$x3]), hcat([$x2], [$x4]))),
+    )
+    evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
+    MOI.initialize(evaluator, [:Grad])
+    sizes = evaluator.backend.objective.expr.sizes
+    @test sizes.ndims == [0, 2, 1, 0, 1, 0, 2, 1, 0, 1, 0]
+    @test sizes.size_offset == [0, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0]
+    @test sizes.size == [1, 1, 1, 2, 1, 1, 1, 2]
+    @test sizes.storage_offset == [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13]
+    x1 = 1.0
+    x2 = 2.0
+    x3 = 3.0
+    x4 = 4.0
+    @test MOI.eval_objective(evaluator, [x1, x2, x3, x4]) == 14.0
+    g = ones(4)
+    MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3, x4])
+    @test g == [2.0, 1.0, 4.0, 3.0]
+    return
+end
+
 function test_objective_norm_univariate()
     model = Nonlinear.Model()
     x = MOI.VariableIndex(1)
@@ -110,4 +163,4 @@ end
 
 end # module
 
-TestArrayDiff.runtests()
\ No newline at end of file
+TestArrayDiff.runtests()