Skip to content

Commit fa27db1

Browse files
committed
add chainrules support
1 parent 3f7d45d commit fa27db1

1 file changed

Lines changed: 57 additions & 22 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,15 @@ for svd_f in (:svd_compact, :svd_full)
188188
end
189189
return USVᴴ, svd_pullback
190190
end
191+
# Reverse-mode rule for the out-of-place `$svd_f(A, alg)`.
# Returns the factorization together with a pullback that maps the output
# cotangents to `(NoTangent(), ΔA, NoTangent())`: no tangent for the function
# itself nor for `alg`.
function ChainRulesCore.rrule(::typeof($svd_f), A, alg)
    F = $(svd_f)(A, alg)
    function svd_pullback(ΔF)
        # `svd_pullback!` is handed a fresh `zero(A)` to fill in.
        ΔA = zero(A)
        cotangents = unthunk.(ΔF)
        MatrixAlgebraKit.svd_pullback!(ΔA, A, F, cotangents)
        return NoTangent(), ΔA, NoTangent()
    end
    return F, svd_pullback
end
191200
end
192201
end
193202

@@ -196,55 +205,81 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
196205
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
197206
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
198207
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
199-
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
200-
end
201-
function _make_svd_trunc_pullback(A, USVᴴ, ind)
202208
function svd_trunc_pullback(ΔUSVᴴϵ)
203-
ΔA = zero(A)
204-
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
205-
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
206-
throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
207-
end
208-
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
209+
ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind)
209210
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
210211
end
211-
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
212-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
212+
return (USVᴴ′..., ϵ), svd_trunc_pullback
213+
end
214+
function ChainRulesCore.rrule(::typeof(svd_trunc), A, alg::TruncatedAlgorithm)
215+
USVᴴ = svd_compact(A, alg.alg)
216+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
217+
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
218+
function svd_trunc_pullback(ΔUSVᴴϵ)
219+
ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind)
220+
return NoTangent(), ΔA, NoTangent()
213221
end
214-
return svd_trunc_pullback
222+
return (USVᴴ′..., ϵ), svd_trunc_pullback
215223
end
216224

217225
function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm)
218226
Ac = copy_input(svd_compact, A)
219227
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
220228
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
221-
return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
222-
end
223-
function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
224229
function svd_trunc_pullback(ΔUSVᴴ)
225-
ΔA = zero(A)
226-
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
227-
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
230+
ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind)
228231
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
229232
end
230-
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
231-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
233+
return USVᴴ′, svd_trunc_pullback
234+
end
235+
function ChainRulesCore.rrule(::typeof(svd_trunc_no_error), A, alg::TruncatedAlgorithm)
236+
USVᴴ = svd_compact(A, alg.alg)
237+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
238+
function svd_trunc_pullback(ΔUSVᴴ)
239+
ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind)
240+
return NoTangent(), ΔA, NoTangent()
232241
end
233-
return svd_trunc_pullback
242+
return USVᴴ′, svd_trunc_pullback
234243
end
235244

245+
# Shared pullback body for the `svd_trunc` and `svd_trunc!` rules.
#
# `ΔUSVᴴϵ` holds the cotangents of the four outputs `(U, S, Vᴴ, ϵ)`. A non-zero
# cotangent for the truncation error `ϵ` is not supported and raises an
# `ArgumentError`; otherwise the first three cotangents are forwarded to
# `_svd_trunc_no_error_pullback`, whose result is returned unchanged.
function _svd_trunc_pullback(ΔUSVᴴϵ, A, USVᴴ, ind)
    # Destructure instead of `last` / `Base.front`: `Base.front` is only defined
    # for `Tuple`s, whereas destructuring works for any iterable cotangent.
    ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
    if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
        throw(
            ArgumentError(
                "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
            )
        )
    end
    return _svd_trunc_no_error_pullback((ΔU, ΔS, ΔVᴴ), A, USVᴴ, ind)
end
251+
# Pullback body shared by the truncated-SVD rules that carry no truncation error.
#
# Unpacks the three output cotangents from `ΔUSVᴴ`, unthunks them, and lets
# `MatrixAlgebraKit.svd_pullback!` fill a fresh `zero(A)`, which is returned.
# `ind` is passed through to `svd_pullback!` unchanged.
function _svd_trunc_no_error_pullback(ΔUSVᴴ, A, USVᴴ, ind)
    ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
    cotangents = map(unthunk, (ΔU, ΔS, ΔVᴴ))
    ΔA = zero(A)
    MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, cotangents, ind)
    return ΔA
end
257+
# Short-circuit method: when all three output cotangents are `ZeroTangent`,
# return `ZeroTangent()` directly without touching `A`, `USVᴴ` or `ind`.
function _svd_trunc_no_error_pullback(::NTuple{3, ZeroTangent}, A, USVᴴ, ind)
    return ZeroTangent()
end
258+
236259
function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
237260
USVᴴ = svd_compact(A, alg)
238261
function svd_vals_pullback(ΔS)
239262
ΔA = zero(A)
240263
MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
241264
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
242265
end
243-
function svd_pullback(::ZeroTangent) # is this extra definition useful?
266+
function svd_vals_pullback(::ZeroTangent) # is this extra definition useful?
244267
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
245268
end
246269
return diagview(USVᴴ[2]), svd_vals_pullback
247270
end
271+
# Reverse-mode rule for the out-of-place `svd_vals(A, alg)`.
#
# The primal is computed through `svd_compact` and returned as the diagonal view
# of its second factor; the full factorization is kept for the pullback. The
# pullback returns `(NoTangent(), ΔA, NoTangent())`.
function ChainRulesCore.rrule(::typeof(svd_vals), A, alg)
    USVᴴ = svd_compact(A, alg)
    S = diagview(USVᴴ[2])

    # A `ZeroTangent` cotangent propagates as `ZeroTangent` for `A`.
    svd_vals_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent())

    function svd_vals_pullback(ΔS)
        ΔA = zero(A)
        MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
        return NoTangent(), ΔA, NoTangent()
    end

    return S, svd_vals_pullback
end
248283

249284
function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg)
250285
Ac = copy_input(left_polar, A)

0 commit comments

Comments
 (0)