Skip to content

Commit 570b221

Browse files
Jutho and lkdvos authored
Some changes to the ad test utils (#180)
* some changes to the ad test utils
* some more changes / lq tests are failing for unknown reasons
* one more attempt
* update enzyme
* unicode subscripts
* remove debug code

---------

Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 2d40186 commit 570b221

File tree

24 files changed

+201
-369
lines changed

24 files changed

+201
-369
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Enzyme = "0.13.118"
3434
EnzymeTestUtils = "0.2.5"
3535
GenericLinearAlgebra = "0.3.19"
3636
GenericSchur = "0.5.6"
37-
JET = "0.9, 0.10"
37+
JET = "0.9, 0.10, 0.11"
3838
LinearAlgebra = "1"
3939
Mooncake = "0.5"
4040
ParallelTestRunner = "2"

src/common/view.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ diagonal(v::AbstractVector) = Diagonal(v)
2424
function lowertriangularind(A::AbstractMatrix)
2525
Base.require_one_based_indexing(A)
2626
m, n = size(A)
27-
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
27+
minmn = min(m, n)
28+
I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (m - minmn))
2829
offset = 0
2930
for j in 1:n
3031
r = (j + 1):m
@@ -37,7 +38,8 @@ end
3738
function uppertriangularind(A::AbstractMatrix)
3839
Base.require_one_based_indexing(A)
3940
m, n = size(A)
40-
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
41+
minmn = min(m, n)
42+
I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (n - minmn))
4143
offset = 0
4244
for i in 1:m
4345
r = (i + 1):n

src/pullbacks/lq.jl

Lines changed: 42 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,26 @@ function check_lq_cotangents(
55
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
66
)
77
minmn = min(size(L, 1), size(Q, 2))
8-
if minmn > p # case where A is rank-deficient
9-
Δgauge = abs(zero(eltype(Q)))
10-
if !iszerotangent(ΔQ)
11-
# in this case the number of Householder reflections will
12-
# change upon small variations, and all of the remaining
13-
# rows of ΔQ should be zero for a gauge-invariant
14-
# cost function
15-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
16-
Δgauge_Q = norm(ΔQ2, Inf)
17-
Δgauge = max(Δgauge, Δgauge_Q)
18-
end
19-
if !iszerotangent(ΔL)
20-
ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
21-
Δgauge_L = norm(ΔL22, Inf)
22-
Δgauge = max(Δgauge, Δgauge_L)
23-
end
24-
Δgauge ≤ gauge_atol ||
25-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
8+
Δgauge = abs(zero(eltype(Q)))
9+
if !iszerotangent(ΔQ)
10+
ΔQ₂ = view(ΔQ, (p + 1):minmn, :)
11+
ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :]
12+
Δgauge_Q = norm(ΔQ₂, Inf)
13+
Q₁ = view(Q, 1:p, :)
14+
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
15+
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
16+
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
17+
Δgauge = max(Δgauge, Δgauge_Q)
18+
end
19+
if !iszerotangent(ΔL)
20+
ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
21+
Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf)
22+
Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf))
23+
Δgauge = max(Δgauge, Δgauge_L)
2624
end
27-
return
28-
end
29-
30-
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
31-
# in the case where A is full rank, but there are more columns in Q than in A
32-
# (the case of `lq_full`), there is gauge-invariant information in the
33-
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
34-
# matrix. As the number of Householder reflections is fixed in the full rank
35-
# case, Q is expected to rotate smoothly (we might even be able to predict) also
36-
# how the full Q2 will change, but this we omit for now, and we consider
37-
# Q2' * ΔQ2 as a gauge dependent quantity.
38-
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
3925
Δgauge ≤ gauge_atol ||
40-
@warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
41-
return
26+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
27+
return nothing
4228
end
4329

4430
"""
@@ -67,54 +53,53 @@ function lq_pullback!(
6753
L, Q = LQ
6854
m = size(L, 1)
6955
n = size(Q, 2)
56+
minmn = min(m, n)
7057
p = lq_rank(L; rank_atol)
7158

7259
ΔL, ΔQ = ΔLQ
7360

74-
Q1 = view(Q, 1:p, :)
75-
Q2 = view(Q, (p + 1):size(Q, 1), :)
76-
L11 = view(L, 1:p, 1:p)
77-
ΔA1 = view(ΔA, 1:p, :)
78-
ΔA2 = view(ΔA, (p + 1):m, :)
61+
Q₁ = view(Q, 1:p, :)
62+
L₁₁ = LowerTriangular(view(L, 1:p, 1:p))
63+
ΔA₁ = view(ΔA, 1:p, :)
64+
ΔA₂ = view(ΔA, (p + 1):m, :)
7965

8066
check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)
8167

8268
ΔQ̃ = zero!(similar(Q, (p, n)))
8369
if !iszerotangent(ΔQ)
84-
ΔQ1 = view(ΔQ, 1:p, :)
85-
copy!(ΔQ̃, ΔQ1)
86-
if p < size(Q, 1)
87-
Q2 = view(Q, (p + 1):size(Q, 1), :)
88-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
89-
ΔQ2Q1ᴴ = ΔQ2 * Q1'
90-
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
91-
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
70+
ΔQ₁ = view(ΔQ, 1:p, :)
71+
copy!(ΔQ̃, ΔQ₁)
72+
if minmn < size(Q, 1)
73+
ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :)
74+
Q₃ = view(Q, (minmn + 1):size(Q, 1), :)
75+
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
76+
ΔQ̃ = mul!(ΔQ̃, ΔQ₃Q₁ᴴ', Q₃, -1, 1)
9277
end
9378
end
9479
if !iszerotangent(ΔL) && m > p
95-
L21 = view(L, (p + 1):m, 1:p)
96-
ΔL21 = view(ΔL, (p + 1):m, 1:p)
97-
ΔQ̃ = mul!(ΔQ̃, L21' * ΔL21, Q1, -1, 1)
98-
# Adding ΔA2 contribution
99-
ΔA2 = mul!(ΔA2, ΔL21, Q1, 1, 1)
80+
L₂₁ = view(L, (p + 1):m, 1:p)
81+
ΔL₂₁ = view(ΔL, (p + 1):m, 1:p)
82+
ΔQ̃ = mul!(ΔQ̃, L₂₁' * ΔL₂₁, Q₁, -1, 1)
83+
# Adding ΔA₂ contribution
84+
ΔA₂ = mul!(ΔA₂, ΔL₂₁, Q₁, 1, 1)
10085
end
10186

10287
# construct M
10388
M = zero!(similar(L, (p, p)))
10489
if !iszerotangent(ΔL)
105-
ΔL11 = view(ΔL, 1:p, 1:p)
106-
M = mul!(M, L11', ΔL11, 1, 1)
90+
ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p))
91+
M = mul!(M, L₁₁', ΔL₁₁, 1, 1)
10792
end
108-
M = mul!(M, ΔQ̃, Q1', -1, 1)
93+
M = mul!(M, ΔQ̃, Q₁', -1, 1)
10994
view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M)))
11095
if eltype(M) <: Complex
11196
Md = diagview(M)
11297
Md .= real.(Md)
11398
end
114-
ldiv!(LowerTriangular(L11)', M)
115-
ldiv!(LowerTriangular(L11)', ΔQ̃)
116-
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
117-
ΔA1 .+= ΔQ̃
99+
ldiv!(L₁₁', M)
100+
ldiv!(L₁₁', ΔQ̃)
101+
ΔA₁ = mul!(ΔA₁, M, Q₁, +1, 1)
102+
ΔA₁ .+= ΔQ̃
118103
return ΔA
119104
end
120105

src/pullbacks/qr.jl

Lines changed: 42 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,26 @@ function check_qr_cotangents(
66
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
77
)
88
minmn = min(size(Q, 1), size(R, 2))
9-
if minmn > p # case where A is rank-deficient
10-
Δgauge = abs(zero(eltype(Q)))
11-
if !iszerotangent(ΔQ)
12-
# in this case the number of Householder reflections will
13-
# change upon small variations, and all of the remaining
14-
# columns of ΔQ should be zero for a gauge-invariant
15-
# cost function
16-
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
17-
Δgauge_Q = norm(ΔQ2, Inf)
18-
Δgauge = max(Δgauge, Δgauge_Q)
19-
end
20-
if !iszerotangent(ΔR)
21-
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
22-
Δgauge_R = norm(ΔR22, Inf)
23-
Δgauge = max(Δgauge, Δgauge_R)
24-
end
25-
Δgauge ≤ gauge_atol ||
26-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
9+
Δgauge = abs(zero(eltype(Q)))
10+
if !iszerotangent(ΔQ)
11+
ΔQ₂ = view(ΔQ, :, (p + 1):minmn)
12+
ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full
13+
Δgauge_Q = norm(ΔQ₂, Inf)
14+
Q₁ = view(Q, :, 1:p)
15+
Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
16+
mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1)
17+
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
18+
Δgauge = max(Δgauge, Δgauge_Q)
19+
end
20+
if !iszerotangent(ΔR)
21+
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
22+
Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf)
23+
Δgauge_R = max(Δgauge_R, norm(view(ΔR22, diagind(ΔR22)), Inf))
24+
Δgauge = max(Δgauge, Δgauge_R)
2725
end
28-
return
29-
end
30-
31-
function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
32-
# in the case where A is full rank, but there are more columns in Q than in A
33-
# (the case of `qr_full`), there is gauge-invariant information in the
34-
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
35-
# matrix. As the number of Householder reflections is fixed in the full rank
36-
# case, Q is expected to rotate smoothly (we might even be able to predict) also
37-
# how the full Q2 will change, but this we omit for now, and we consider
38-
# Q2' * ΔQ2 as a gauge dependent quantity.
39-
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
4026
Δgauge ≤ gauge_atol ||
41-
@warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
42-
return
27+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
28+
return nothing
4329
end
4430

4531
"""
@@ -69,53 +55,54 @@ function qr_pullback!(
6955
Q, R = QR
7056
m = size(Q, 1)
7157
n = size(R, 2)
58+
minmn = min(m, n)
7259
Rd = diagview(R)
7360
p = qr_rank(R; rank_atol)
7461

7562
ΔQ, ΔR = ΔQR
7663

77-
Q1 = view(Q, :, 1:p)
78-
R11 = view(R, 1:p, 1:p)
79-
ΔA1 = view(ΔA, :, 1:p)
80-
ΔA2 = view(ΔA, :, (p + 1):n)
64+
Q₁ = view(Q, :, 1:p)
65+
R₁₁ = UpperTriangular(view(R, 1:p, 1:p))
66+
ΔA₁ = view(ΔA, :, 1:p)
67+
ΔA₂ = view(ΔA, :, (p + 1):n)
8168

8269
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)
8370

8471
ΔQ̃ = zero!(similar(Q, (m, p)))
8572
if !iszerotangent(ΔQ)
86-
copy!(ΔQ̃, view(ΔQ, :, 1:p))
87-
if p < size(Q, 2)
88-
Q2 = view(Q, :, (p + 1):size(Q, 2))
89-
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
90-
Q1dΔQ2 = Q1' * ΔQ2
91-
check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol)
92-
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
73+
ΔQ₁ = view(ΔQ, :, 1:p)
74+
copy!(ΔQ̃, ΔQ₁)
75+
if minmn < size(Q, 2)
76+
ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full
77+
Q₃ = view(Q, :, (minmn + 1):size(Q, 2))
78+
Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
79+
ΔQ̃ = mul!(ΔQ̃, Q₃, Q₁ᴴΔQ₃', -1, 1)
9380
end
9481
end
9582
if !iszerotangent(ΔR) && n > p
96-
R12 = view(R, 1:p, (p + 1):n)
97-
ΔR12 = view(ΔR, 1:p, (p + 1):n)
98-
ΔQ̃ = mul!(ΔQ̃, Q1, ΔR12 * R12', -1, 1)
99-
# Adding ΔA2 contribution
100-
ΔA2 = mul!(ΔA2, Q1, ΔR12, 1, 1)
83+
R₁₂ = view(R, 1:p, (p + 1):n)
84+
ΔR₁₂ = view(ΔR, 1:p, (p + 1):n)
85+
ΔQ̃ = mul!(ΔQ̃, Q₁, ΔR₁₂ * R₁₂', -1, 1)
86+
# Adding ΔA₂ contribution
87+
ΔA₂ = mul!(ΔA₂, Q₁, ΔR₁₂, 1, 1)
10188
end
10289

10390
# construct M
10491
M = zero!(similar(R, (p, p)))
10592
if !iszerotangent(ΔR)
106-
ΔR11 = view(ΔR, 1:p, 1:p)
107-
M = mul!(M, ΔR11, R11', 1, 1)
93+
ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p))
94+
M = mul!(M, ΔR₁₁, R₁₁', 1, 1)
10895
end
109-
M = mul!(M, Q1', ΔQ̃, -1, 1)
96+
M = mul!(M, Q₁', ΔQ̃, -1, 1)
11097
view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
11198
if eltype(M) <: Complex
11299
Md = diagview(M)
113100
Md .= real.(Md)
114101
end
115-
rdiv!(M, UpperTriangular(R11)')
116-
rdiv!(ΔQ̃, UpperTriangular(R11)')
117-
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
118-
ΔA1 .+= ΔQ̃
102+
rdiv!(M, R₁₁') # R₁₁ is upper triangular
103+
rdiv!(ΔQ̃, R₁₁')
104+
ΔA₁ = mul!(ΔA₁, Q₁, M, +1, 1)
105+
ΔA₁ .+= ΔQ̃
119106
return ΔA
120107
end
121108

src/pullbacks/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true)
1+
svd_rank(S; rank_atol = default_pullback_rank_atol(S)) = searchsortedlast(S, rank_atol; rev = true)
22

33
function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV))
44
mask = abs.(Sr' .- Sr) .< degeneracy_atol
@@ -43,7 +43,7 @@ function svd_pullback!(
4343
minmn = min(m, n)
4444
S = diagview(Smat)
4545
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
46-
r = svd_rank(S, rank_atol)
46+
r = svd_rank(S; rank_atol)
4747
Ur = view(U, :, 1:r)
4848
Vᴴr = view(Vᴴ, 1:r, :)
4949
Sr = view(S, 1:r)

test/enzyme/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1414

1515
m = 19
1616
for T in (BLASFloats..., GenericFloats...)
17-
TestSuite.seed_rng!(123)
17+
TestSuite.seed_rng!(1234)
1818
if !is_buildkite
1919
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
2020
end

test/enzyme/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1414

1515
m = 19
1616
for T in (BLASFloats..., GenericFloats...)
17-
TestSuite.seed_rng!(123)
17+
TestSuite.seed_rng!(1234)
1818
if !is_buildkite
1919
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
2020
end

test/enzyme/lq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using LinearAlgebra: Diagonal
44
using CUDA, AMDGPU
55

6-
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
6+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
77
GenericFloats = ()
88
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
99
using .TestSuite
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
end

test/enzyme/orthnull.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using LinearAlgebra: Diagonal
44
using CUDA, AMDGPU
55

6-
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
6+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
77
GenericFloats = ()
88
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
99
using .TestSuite
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
end

test/enzyme/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using LinearAlgebra: Diagonal
44
using CUDA, AMDGPU
55

6-
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
6+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
77
GenericFloats = ()
88
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
99
using .TestSuite
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
end

0 commit comments

Comments
 (0)