-
Notifications
You must be signed in to change notification settings - Fork 18
test_rrule() fails in iterate() #263
Copy link
Copy link
Open
Labels
bug (Something isn't working)
Description
It looks like the finite-difference implementation has a hard time going through `iterate` (see the MWE and the full stacktrace below):
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch: second dimension of A, 2, does not match length of x, 1
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
...
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
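...

Below I provide an `rrule()` implementation for `iterate` on tuples for convenience, but perhaps the example can be narrowed down to a direct invocation of `_make_j′vp_call()`. I also see the same error when testing with arrays. Frame [6] of the full stacktrace suggests a possible cause. The primal output `(3.0, 2)` has two entries. The cotangent `Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}` apparently flattens to only one entry. That would match the "2 vs 1" dimension mismatch.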
MWE
using ChainRulesCore
import ChainRulesCore.rrule
using ChainRulesTestUtils
"""
    ungetfield(dy, s::Tuple, f::Int)

Return a `Tangent{typeof(s)}` whose `f`-th component is `dy` and whose every
other component is `ZeroTangent()`. If `f` is not a valid index of `s`, every
component is `ZeroTangent()`.
"""
function ungetfield(dy, s::Tuple, f::Int)
    T = typeof(s)
    # Build a tuple rather than a Vector. Splatting a Vector into varargs hides
    # its length from the compiler and allocates. `length(s)` of a tuple is part
    # of its type, so `Val(length(s))` lets `ntuple` unroll.
    entries = ntuple(i -> i == f ? dy : ZeroTangent(), Val(length(s)))
    return Tangent{T}(entries...)
end
"""
    rrule(::typeof(iterate), t::Tuple)

Reverse rule for the first step of iterating over the tuple `t`. The primal
result is `iterate(t)`, i.e. `(t[1], 2)`, or `nothing` when `t` is empty.

The pullback maps a cotangent `dy` of that pair to `(NoTangent(), dt)`. Here
`dt` carries `dy[1]` in component 1 of `t` and zeros elsewhere.
"""
function rrule(::typeof(iterate), t::Tuple)
    y = iterate(t)
    function iterate_pullback(dy)
        dy = unthunk(dy)
        # A zero cotangent carries no gradient, so do not index into it. This
        # happens when the output is unused, or when `y === nothing`.
        dy isa AbstractZero && return NoTangent(), ZeroTangent()
        # Only the element `t[1]` is differentiable; the integer state is not.
        return NoTangent(), ungetfield(dy[1], t, 1)
    end
    return y, iterate_pullback
end
"""
    rrule(::typeof(iterate), t::Tuple, i::Integer)

Reverse rule for a later iteration step over the tuple `t` from state `i`. The
primal result is `iterate(t, i)`, i.e. `(t[i], i + 1)`, or `nothing` when `i`
is past the end.

The pullback maps a cotangent `dy` of that pair to
`(NoTangent(), dt, NoTangent())`. Here `dt` carries `dy[1]` in component `i`
of `t`. The integer state `i` is non-differentiable, hence `NoTangent()`.
"""
function rrule(::typeof(iterate), t::Tuple, i::Integer)
    y = iterate(t, i)
    function iterate_pullback(dy)
        dy = unthunk(dy)
        # A zero cotangent carries no gradient, so do not index into it. This
        # happens when the output is unused, or when `y === nothing`.
        dy isa AbstractZero && return NoTangent(), ZeroTangent(), NoTangent()
        return NoTangent(), ungetfield(dy[1], t, i), NoTangent()
    end
    return y, iterate_pullback
end
test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)

Complete stacktrace
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch: second dimension of A, 2, does not match length of x, 1
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
[2] mul!
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
[3] mul!
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
[4] *(tA::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{NamedTuple{(), Tuple{}}}, Tuple{typeof(iterate), Tuple{Float64, Float64}}, Tuple{Bool, Bool}}, ȳ::Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}, x::Tuple{Float64, Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:73
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
[8] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:224 [inlined]
[9] macro expansion
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/Test/src/Test.jl:1357 [inlined]
[10] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
[11] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
[12] top-level scope
@ REPL[1]:1
[13] eval
@ ./boot.jl:368 [inlined]
[14] eval
@ ./Base.jl:65 [inlined]
[15] repleval(m::Module, code::Expr, #unused#::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:222
[16] (::VSCodeServer.var"#107#109"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:186
[17] with_logstate(f::Function, logstate::Any)
@ Base.CoreLogging ./logging.jl:511
[18] with_logger
@ ./logging.jl:623 [inlined]
[19] (::VSCodeServer.var"#106#108"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:187
[20] #invokelatest#2
@ ./essentials.jl:729 [inlined]
[21] invokelatest(::Any)
@ Base ./essentials.jl:726
[22] macro expansion
@ ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
[23] (::VSCodeServer.var"#61#62")()
@ VSCodeServer ./task.jl:484
Test Summary: | Pass Error Total Time
test_rrule: iterate on Float64,Float64 | 3 1 4 0.0s
ERROR: Some tests did not pass: 3 passed, 0 failed, 1 errored, 0 broken.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug (Something isn't working)