@@ -6,40 +6,26 @@ function check_qr_cotangents(
66 gauge_atol:: Real = default_pullback_gauge_atol (ΔQ)
77 )
88 minmn = min (size (Q, 1 ), size (R, 2 ))
9- if minmn > p # case where A is rank-deficient
10- Δgauge = abs (zero (eltype (Q)))
11- if ! iszerotangent (ΔQ)
12- # in this case the number Householder reflections will
13- # change upon small variations, and all of the remaining
14- # columns of ΔQ should be zero for a gauge-invariant
15- # cost function
16- ΔQ2 = view (ΔQ, :, (p + 1 ): size (Q, 2 ))
17- Δgauge_Q = norm (ΔQ2, Inf )
18- Δgauge = max (Δgauge, Δgauge_Q)
19- end
20- if ! iszerotangent (ΔR)
21- ΔR22 = view (ΔR, (p + 1 ): minmn, (p + 1 ): size (R, 2 ))
22- Δgauge_R = norm (ΔR22, Inf )
23- Δgauge = max (Δgauge, Δgauge_R)
24- end
25- Δgauge ≤ gauge_atol ||
26- @warn " `qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
9+ Δgauge = abs (zero (eltype (Q)))
10+ if ! iszerotangent (ΔQ)
11+ ΔQ₂ = view (ΔQ, :, (p + 1 ): minmn)
12+ ΔQ₃ = ΔQ[:, (minmn + 1 ): size (Q, 2 )] # extra columns in the case of qr_full
13+ Δgauge_Q = norm (ΔQ₂, Inf )
14+ Q₁ = view (Q, :, 1 : p)
15+ Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
16+ mul! (ΔQ₃, Q₁, Q₁ᴴΔQ₃, - 1 , 1 )
17+ Δgauge_Q = max (Δgauge_Q, norm (ΔQ₃, Inf ))
18+ Δgauge = max (Δgauge, Δgauge_Q)
19+ end
20+ if ! iszerotangent (ΔR)
21+ ΔR22 = view (ΔR, (p + 1 ): minmn, (p + 1 ): size (R, 2 ))
22+ Δgauge_R = norm (view (ΔR22, uppertriangularind (ΔR22)), Inf )
23+ Δgauge_R = max (Δgauge_R, norm (view (ΔR22, diagind (ΔR22)), Inf ))
24+ Δgauge = max (Δgauge, Δgauge_R)
2725 end
28- return
29- end
30-
31- function check_qr_full_cotangents (Q1, ΔQ2, Q1dΔQ2; gauge_atol:: Real = default_pullback_gauge_atol (ΔQ2))
32- # in the case where A is full rank, but there are more columns in Q than in A
33- # (the case of `qr_full`), there is gauge-invariant information in the
34- # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
35- # matrix. As the number of Householder reflections is in fixed in the full rank
36- # case, Q is expected to rotate smoothly (we might even be able to predict) also
37- # how the full Q2 will change, but this we omit for now, and we consider
38- # Q2' * ΔQ2 as a gauge dependent quantity.
39- Δgauge = norm (mul! (copy (ΔQ2), Q1, Q1dΔQ2, - 1 , 1 ), Inf )
4026 Δgauge ≤ gauge_atol ||
41- @warn " `qr_full ` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
42- return
27+ @warn " `qr ` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
28+ return nothing
4329end
4430
4531"""
@@ -69,53 +55,54 @@ function qr_pullback!(
6955 Q, R = QR
7056 m = size (Q, 1 )
7157 n = size (R, 2 )
58+ minmn = min (m, n)
7259 Rd = diagview (R)
7360 p = qr_rank (R; rank_atol)
7461
7562 ΔQ, ΔR = ΔQR
7663
77- Q1 = view (Q, :, 1 : p)
78- R11 = view (R, 1 : p, 1 : p)
79- ΔA1 = view (ΔA, :, 1 : p)
80- ΔA2 = view (ΔA, :, (p + 1 ): n)
64+ Q₁ = view (Q, :, 1 : p)
65+ R₁₁ = UpperTriangular ( view (R, 1 : p, 1 : p) )
66+ ΔA₁ = view (ΔA, :, 1 : p)
67+ ΔA₂ = view (ΔA, :, (p + 1 ): n)
8168
8269 check_qr_cotangents (Q, R, ΔQ, ΔR, p; gauge_atol)
8370
8471 ΔQ̃ = zero! (similar (Q, (m, p)))
8572 if ! iszerotangent (ΔQ)
86- copy! (ΔQ̃, view (ΔQ, :, 1 : p) )
87- if p < size (Q, 2 )
88- Q2 = view (Q, :, (p + 1 ) : size (Q, 2 ) )
89- ΔQ2 = view (ΔQ, :, (p + 1 ): size (Q , 2 ))
90- Q1dΔQ2 = Q1 ' * ΔQ2
91- check_qr_full_cotangents (Q1, ΔQ2, Q1dΔQ2; gauge_atol)
92- ΔQ̃ = mul! (ΔQ̃, Q2, Q1dΔQ2 ' , - 1 , 1 )
73+ ΔQ₁ = view (ΔQ, :, 1 : p)
74+ copy! (ΔQ̃, ΔQ₁ )
75+ if minmn < size (Q, 2 )
76+ ΔQ₃ = view (ΔQ, :, (minmn + 1 ): size (ΔQ , 2 )) # extra columns in the case of qr_full
77+ Q₃ = view (Q, :, (minmn + 1 ) : size (Q, 2 ))
78+ Q₁ᴴΔQ₃ = Q₁ ' * ΔQ₃
79+ ΔQ̃ = mul! (ΔQ̃, Q₃, Q₁ᴴΔQ₃ ' , - 1 , 1 )
9380 end
9481 end
9582 if ! iszerotangent (ΔR) && n > p
96- R12 = view (R, 1 : p, (p + 1 ): n)
97- ΔR12 = view (ΔR, 1 : p, (p + 1 ): n)
98- ΔQ̃ = mul! (ΔQ̃, Q1, ΔR12 * R12 ' , - 1 , 1 )
99- # Adding ΔA2 contribution
100- ΔA2 = mul! (ΔA2, Q1, ΔR12 , 1 , 1 )
83+ R₁₂ = view (R, 1 : p, (p + 1 ): n)
84+ ΔR₁₂ = view (ΔR, 1 : p, (p + 1 ): n)
85+ ΔQ̃ = mul! (ΔQ̃, Q₁, ΔR₁₂ * R₁₂ ' , - 1 , 1 )
86+ # Adding ΔA₂ contribution
87+ ΔA₂ = mul! (ΔA₂, Q₁, ΔR₁₂ , 1 , 1 )
10188 end
10289
10390 # construct M
10491 M = zero! (similar (R, (p, p)))
10592 if ! iszerotangent (ΔR)
106- ΔR11 = view (ΔR, 1 : p, 1 : p)
107- M = mul! (M, ΔR11, R11 ' , 1 , 1 )
93+ ΔR₁₁ = UpperTriangular ( view (ΔR, 1 : p, 1 : p) )
94+ M = mul! (M, ΔR₁₁, R₁₁ ' , 1 , 1 )
10895 end
109- M = mul! (M, Q1 ' , ΔQ̃, - 1 , 1 )
96+ M = mul! (M, Q₁ ' , ΔQ̃, - 1 , 1 )
11097 view (M, lowertriangularind (M)) .= conj .(view (M, uppertriangularind (M)))
11198 if eltype (M) <: Complex
11299 Md = diagview (M)
113100 Md .= real .(Md)
114101 end
115- rdiv! (M, UpperTriangular (R11) ' )
116- rdiv! (ΔQ̃, UpperTriangular (R11) ' )
117- ΔA1 = mul! (ΔA1, Q1 , M, + 1 , 1 )
118- ΔA1 .+ = ΔQ̃
102+ rdiv! (M, R₁₁ ' ) # R₁₁ is upper triangular
103+ rdiv! (ΔQ̃, R₁₁ ' )
104+ ΔA₁ = mul! (ΔA₁, Q₁ , M, + 1 , 1 )
105+ ΔA₁ .+ = ΔQ̃
119106 return ΔA
120107end
121108
0 commit comments