
RFC: accept Tuple Tangent for arrays? #444

Closed
mcabbott wants to merge 1 commit into JuliaDiff:main from mcabbott:tuplevec


@mcabbott
Member

This PR does the following:

julia> ProjectTo([1 2; 3 4])(Tangent{Tuple}(0.1,2,30,400))
2×2 Matrix{Float64}:
 0.1   30.0
 2.0  400.0

julia> ProjectTo([1,2,3,4]')(Tangent{Tuple}(0.1,2,30,400))
ERROR: DimensionMismatch("array with ndims(x) == 1 >  0 cannot have dx::Number")

as a step towards solving this:

julia> Zygote.gradient(x -> +(x...), [1 2; 3 4])[1]
(1, 1, 1, 1)

I'm not entirely sure this should be handled here rather than in Zygote. Thoughts?
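
For readers unfamiliar with the projection step, here is a minimal sketch of the behaviour shown in the first example above, written in plain Julia without ChainRulesCore. The helper name `tuple_to_array` is hypothetical and this is not the code in the PR; it assumes the tuple entries are laid out in column-major order and converted to floating point, which matches the `Matrix{Float64}` output shown.

```julia
# Hypothetical helper, NOT the PR's implementation: rebuild an array-shaped
# gradient from the flat tuple produced by splatting an array `x`.
function tuple_to_array(x::AbstractArray, t::Tuple)
    # A splat of `x` has exactly length(x) arguments, so the tuple must match.
    length(t) == length(x) ||
        throw(DimensionMismatch("tuple has $(length(t)) entries, array has $(length(x))"))
    # Convert entries to floats (assumption, mirroring the Float64 result above),
    # then restore the shape; reshape fills in column-major order.
    return reshape(collect(float.(t)), size(x))
end

A = tuple_to_array([1 2; 3 4], (0.1, 2, 30, 400))
```

Under these assumptions `A` has the same entries as the 2×2 result above. The sketch says nothing about wrapper types such as the adjoint vector in the second example, which need their own handling.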

@codecov-commenter

codecov-commenter commented Aug 19, 2021

Codecov Report

❌ Patch coverage is 0% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 92.02%. Comparing base (2208660) to head (abde997).
⚠️ Report is 314 commits behind head on main.

Files with missing lines | Patch % | Lines
src/projection.jl | 0.00% | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #444      +/-   ##
==========================================
- Coverage   92.37%   92.02%   -0.36%     
==========================================
  Files          14       14              
  Lines         787      790       +3     
==========================================
  Hits          727      727              
- Misses         60       63       +3     

@mzgubic added the labels "ProjectTo" (related to the projection functionality) and "Structural Tangent" (related to the `Tangent` type for structured (composite) values) on Aug 19, 2021
end

# Accept the Tangent corresponding to a Tuple -- Zygote's splats produce these
function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Any, <:Tuple})
Member

I think I prefer this to be written as:

Suggested change
- function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Any, <:Tuple})
+ function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Tuple})

so we don't need to mention how this is backed.

Member Author

I think Zygote managed to produce a Tangent{Any, Tuple{...}}, which was why my pirate method specified that, here:

https://github.com/FluxML/Zygote.jl/pull/1044/files#diff-e0bc7da8f1a33a59f5ecfa67257c04038f0b4915b3f74bdf39780818fd0010a2R162

But I was squashing bugs as fast as I could & didn't track it down any further. Can circle back now that things basically pass.
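
To make the difference between the two signatures concrete, here is a small sketch using a hypothetical stand-in type `Tan{P,B}` (primal type `P`, backing type `B`) rather than the real `Tangent`. It only illustrates how Julia dispatch treats the two patterns; the function names are invented for the example.

```julia
# Hypothetical stand-in for Tangent{P,B}; not the ChainRulesCore type.
struct Tan{P,B}
    backing::B
end
# Convenience constructor: the primal type is given, the backing type is inferred.
Tan{P}(xs...) where {P} = Tan{P,typeof(xs)}(xs)

# Pattern from the suggestion: constrain the primal type parameter.
by_primal(::Tan{<:Tuple}) = true
by_primal(::Tan) = false

# Pattern from the PR: leave the primal type free, constrain the backing.
by_backing(::Tan{<:Any,<:Tuple}) = true
by_backing(::Tan) = false

typed = Tan{Tuple{Int,Int,Int}}(0, 0, 1)  # primal type known
untyped = Tan{Any}(0, 0, 1)               # primal type lost, backing is still a tuple
```

In this sketch `by_primal(untyped)` is `false` while `by_backing(untyped)` is `true`, which is the situation described above when the primal type arrives as `Any`.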

Member

right, fair enough.
If we have to do it this way that is also fine.
Maybe just a comment is worth adding to clarify that this is the tangent for something with primal type Tuple.

Zygote does lose primal type information, that is a thing
We might well be able to teach it a bit more about tuples more easily than the general case.

Member Author

BTW, easy examples of this Tangent{Any} problem:

julia> using Zygote

julia> function Zygote.ChainRulesCore.rrule(::typeof(identity), x::AbstractArray)
         x, dx -> (Zygote.NoTangent(), @show dx)
       end

julia> Zygote.gradient(x -> max(identity(x)...), [1,2,3])
dx = Tangent{Any}(0, 0, 1)
((0, 0, 1),)

julia> Zygote.gradient(x -> sum(identity(x).parent), [1,2,3]')
dx = Tangent{Any}(parent = 3-element Fill{Int64}: entries equal to 1,)
((parent = 3-element Fill{Int64}: entries equal to 1,),)

@oxinabox
Member

oxinabox commented Aug 19, 2021

I think I am ok with this.
Sometimes you end up with the wrong kind of iterator.
A common incorrect type of iterator is a tuple.

We might want to demand that it is only going into an AbstractVector.
But that might be too restrictive, since we might also want to handle 1-row matrices etc.?

@mcabbott
Member Author

Thanks for thinking. For Zygote at least I do think it wants to allow matrices, which also splat to tuples.

It's possible that this should be (project::ProjectTo{<:AbstractArray})(dx::Tangent{...}) so that it also applies to splatted row vectors, for example. But this might land us in dispatch hell, and might be best explored after #430.

Dispatch hell might be another reason to do this within Zygote rather than here. Diffractor also has difficulty with splats, but not the same difficulty, so it's not obvious whether this is more widely useful.
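
A short sketch of why `ProjectTo{AbstractArray}` and `ProjectTo{<:AbstractArray}` select different methods, using a hypothetical stand-in `Proj{P}` instead of the real `ProjectTo`. Julia type parameters are invariant, so the first form matches only the exact parameter `AbstractArray`, while the second matches any subtype such as `Adjoint`.

```julia
using LinearAlgebra  # only for the Adjoint wrapper type

# Hypothetical stand-in for ProjectTo{P}; only the type parameter matters here.
struct Proj{P} end

# Matches only Proj{AbstractArray}, because type parameters are invariant.
describe(::Proj{AbstractArray}) = "exact AbstractArray"
# Matches Proj{P} for every P <: AbstractArray, including wrapper types.
describe(::Proj{<:AbstractArray}) = "some AbstractArray subtype"

plain = describe(Proj{AbstractArray}())
row = describe(Proj{Adjoint}())
```

Here `plain` picks the more specific exact method and `row` falls through to the subtype method. Widening a method in this way means it competes with every other method defined for specific wrapper types, which is one way the dispatch concerns above can arise.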

@mcabbott mcabbott closed this Oct 29, 2021
