@@ -188,6 +188,15 @@ for svd_f in (:svd_compact, :svd_full)
188188 end
189189 return USVᴴ, svd_pullback
190190 end
191+ function ChainRulesCore. rrule (:: typeof ($ svd_f), A, alg)
192+ USVᴴ = $ (svd_f)(A, alg)
193+ function svd_pullback (ΔUSVᴴ)
194+ ΔA = zero (A)
195+ MatrixAlgebraKit. svd_pullback! (ΔA, A, USVᴴ, unthunk .(ΔUSVᴴ))
196+ return NoTangent (), ΔA, NoTangent ()
197+ end
198+ return USVᴴ, svd_pullback
199+ end
191200 end
192201end
193202
@@ -196,55 +205,81 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
196205 USVᴴ = svd_compact! (Ac, USVᴴ, alg. alg)
197206 USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
198207 ϵ = truncation_error (diagview (USVᴴ[2 ]), ind)
199- return (USVᴴ′... , ϵ), _make_svd_trunc_pullback (A, USVᴴ, ind)
200- end
201- function _make_svd_trunc_pullback (A, USVᴴ, ind)
202208 function svd_trunc_pullback (ΔUSVᴴϵ)
203- ΔA = zero (A)
204- ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
205- if ! MatrixAlgebraKit. iszerotangent (Δϵ) && ! iszero (unthunk (Δϵ))
206- throw (ArgumentError (" Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" ))
207- end
208- MatrixAlgebraKit. svd_pullback! (ΔA, A, USVᴴ, unthunk .((ΔU, ΔS, ΔVᴴ)), ind)
209+ ΔA = _svd_trunc_pullback (unthunk (ΔUSVᴴϵ), A, USVᴴ, ind)
209210 return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
210211 end
211- function svd_trunc_pullback (:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent} ) # is this extra definition useful?
212- return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
212+ return (USVᴴ′... , ϵ), svd_trunc_pullback
213+ end
214+ function ChainRulesCore. rrule (:: typeof (svd_trunc), A, alg:: TruncatedAlgorithm )
215+ USVᴴ = svd_compact (A, alg. alg)
216+ USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
217+ ϵ = truncation_error (diagview (USVᴴ[2 ]), ind)
218+ function svd_trunc_pullback (ΔUSVᴴϵ)
219+ ΔA = _svd_trunc_pullback (unthunk (ΔUSVᴴϵ), A, USVᴴ, ind)
220+ return NoTangent (), ΔA, NoTangent ()
213221 end
214- return svd_trunc_pullback
222+ return (USVᴴ′ ... , ϵ), svd_trunc_pullback
215223end
216224
217225function ChainRulesCore. rrule (:: typeof (svd_trunc_no_error!), A, USVᴴ, alg:: TruncatedAlgorithm )
218226 Ac = copy_input (svd_compact, A)
219227 USVᴴ = svd_compact! (Ac, USVᴴ, alg. alg)
220228 USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
221- return USVᴴ′, _make_svd_trunc_no_error_pullback (A, USVᴴ, ind)
222- end
223- function _make_svd_trunc_no_error_pullback (A, USVᴴ, ind)
224229 function svd_trunc_pullback (ΔUSVᴴ)
225- ΔA = zero (A)
226- ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
227- MatrixAlgebraKit. svd_pullback! (ΔA, A, USVᴴ, unthunk .((ΔU, ΔS, ΔVᴴ)), ind)
230+ ΔA = _svd_trunc_no_error_pullback (unthunk (ΔUSVᴴ), A, USVᴴ, ind)
228231 return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
229232 end
230- function svd_trunc_pullback (:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent} ) # is this extra definition useful?
231- return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
233+ return USVᴴ′, svd_trunc_pullback
234+ end
235+ function ChainRulesCore. rrule (:: typeof (svd_trunc_no_error), A, alg:: TruncatedAlgorithm )
236+ USVᴴ = svd_compact (A, alg. alg)
237+ USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
238+ function svd_trunc_pullback (ΔUSVᴴ)
239+ ΔA = _svd_trunc_no_error_pullback (unthunk (ΔUSVᴴ), A, USVᴴ, ind)
240+ return NoTangent (), ΔA, NoTangent ()
232241 end
233- return svd_trunc_pullback
242+ return USVᴴ′, svd_trunc_pullback
234243end
235244
245+ function _svd_trunc_pullback (ΔUSVᴴϵ, A, USVᴴ, ind)
246+ Δϵ = last (ΔUSVᴴϵ)
247+ ! MatrixAlgebraKit. iszerotangent (Δϵ) && ! iszero (unthunk (Δϵ)) &&
248+ throw (ArgumentError (" Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" ))
249+ return _make_svd_trunc_no_error_pullback (Base. front (ΔUSVᴴ), A, USVᴴ, ind)
250+ end
251+ function _svd_trunc_no_error_pullback (ΔUSVᴴ, A, USVᴴ, ind)
252+ ΔA = zero (A)
253+ ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
254+ MatrixAlgebraKit. svd_pullback! (ΔA, A, USVᴴ, unthunk .((ΔU, ΔS, ΔVᴴ)), ind)
255+ return ΔA
256+ end
257+ _svd_trunc_no_error_pullback (:: NTuple{3, ZeroTangent} , A, USVᴴ, ind) = ZeroTangent ()
258+
236259function ChainRulesCore. rrule (:: typeof (svd_vals!), A, S, alg)
237260 USVᴴ = svd_compact (A, alg)
238261 function svd_vals_pullback (ΔS)
239262 ΔA = zero (A)
240263 MatrixAlgebraKit. svd_vals_pullback! (ΔA, A, USVᴴ, unthunk (ΔS))
241264 return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
242265 end
243- function svd_pullback (:: ZeroTangent ) # is this extra definition useful?
266+ function svd_vals_pullback (:: ZeroTangent ) # is this extra definition useful?
244267 return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
245268 end
246269 return diagview (USVᴴ[2 ]), svd_vals_pullback
247270end
271+ function ChainRulesCore. rrule (:: typeof (svd_vals), A, alg)
272+ USVᴴ = svd_compact (A, alg)
273+ function svd_vals_pullback (ΔS)
274+ ΔA = zero (A)
275+ MatrixAlgebraKit. svd_vals_pullback! (ΔA, A, USVᴴ, unthunk (ΔS))
276+ return NoTangent (), ΔA, NoTangent ()
277+ end
278+ function svd_vals_pullback (:: ZeroTangent ) # is this extra definition useful?
279+ return NoTangent (), ZeroTangent (), NoTangent ()
280+ end
281+ return diagview (USVᴴ[2 ]), svd_vals_pullback
282+ end
248283
249284function ChainRulesCore. rrule (:: typeof (left_polar!), A, WP, alg)
250285 Ac = copy_input (left_polar, A)
0 commit comments