Skip to content

Commit ee479b3

Browse files
authored
centralize QR/LQ gaugefixing (#203)
1 parent e6e844f commit ee479b3

5 files changed

Lines changed: 46 additions & 80 deletions

File tree

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt
33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
55
using MatrixAlgebraKit: GLA, Driver
6-
import MatrixAlgebraKit: gesvd!, heev!
6+
import MatrixAlgebraKit: qr_householder!, qr_null_householder!, gesvd!, heev!
77
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
88
using LinearAlgebra: I, Diagonal, lmul!
99

@@ -53,7 +53,7 @@ function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix;
5353
return Dd, V
5454
end
5555

56-
function MatrixAlgebraKit.qr_householder!(
56+
function qr_householder!(
5757
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
5858
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
5959
)
@@ -68,36 +68,20 @@ function MatrixAlgebraKit.qr_householder!(
6868

6969
# compute QR
7070
Q̃, R̃ = qr!(A)
71-
lmul!(Q̃, MatrixAlgebraKit.one!(Q))
72-
73-
if positive
74-
@inbounds for j in 1:minmn
75-
s = sign_safe(R̃[j, j])
76-
@simd for i in 1:m
77-
Q[i, j] *= s
78-
end
79-
end
80-
end
71+
lmul!(Q̃, one!(Q))
8172

8273
if computeR
83-
if positive
84-
@inbounds for j in n:-1:1
85-
@simd for i in 1:min(minmn, j)
86-
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
87-
end
88-
@simd for i in (min(minmn, j) + 1):size(R, 1)
89-
R[i, j] = zero(eltype(R))
90-
end
91-
end
92-
else
93-
R[1:minmn, :] .=
94-
MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :]))
95-
end
74+
copyto!(view(R, 1:minmn, :), R̃)
75+
zero!(view(R, (minmn + 1):size(R, 1), :))
76+
positive && gaugefix!(qr_householder!, Q, R, diagview(R̃))
77+
elseif positive
78+
gaugefix!(qr_householder!, Q, nothing, diagview(R̃))
9679
end
80+
9781
return Q, R
9882
end
9983

100-
function MatrixAlgebraKit.qr_null_householder!(
84+
function qr_null_householder!(
10185
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
10286
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
10387
)

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ include("interface/schur.jl")
9999
include("interface/polar.jl")
100100
include("interface/orthnull.jl")
101101

102-
include("common/gauge.jl") # needs to be defined after the functions are
103-
104102
include("implementations/projections.jl")
105103
include("implementations/truncation.jl")
106104
include("implementations/qr.jl")
@@ -113,6 +111,8 @@ include("implementations/schur.jl")
113111
include("implementations/polar.jl")
114112
include("implementations/orthnull.jl")
115113

114+
include("common/gauge.jl") # needs to be defined after the functions are
115+
116116
include("pullbacks/qr.jl")
117117
include("pullbacks/lq.jl")
118118
include("pullbacks/eig.jl")

src/common/gauge.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,28 @@ is real and positive.
1212
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
1313
_largest(x, y) = abs(x) < abs(y) ? y : x
1414

15+
function gaugefix!(::typeof(qr_householder!), Q, R, Rd)
16+
ax = Base.OneTo(length(Rd))
17+
Qf = view(Q, axes(Q, 1), ax)
18+
Qf .*= sign_safe.(transpose(Rd))
19+
if !isnothing(R)
20+
Rf = view(R, ax, axes(R, 2))
21+
Rf .*= conj.(sign_safe.(Rd))
22+
end
23+
return Q, R
24+
end
25+
26+
function gaugefix!(::typeof(lq_householder!), L, Q, Ld)
27+
ax = Base.OneTo(length(Ld))
28+
Qf = view(Q, ax, axes(Q, 2))
29+
Qf .*= sign_safe.(Ld)
30+
if !isnothing(L)
31+
Lf = view(L, axes(L, 1), ax)
32+
Lf .*= conj.(sign_safe.(transpose(Ld)))
33+
end
34+
return L, Q
35+
end
36+
1537
function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix)
1638
for j in axes(V, 2)
1739
v = view(V, :, j)

src/implementations/lq.jl

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -163,27 +163,14 @@ function lq_householder!(
163163
end
164164
end
165165

166-
if positive # already fix Q even if we do not need L
167-
@inbounds for j in 1:n
168-
@simd for i in 1:minmn
169-
s = sign_safe(A[i, i])
170-
Q[i, j] *= s
171-
end
172-
end
173-
end
174-
175166
if computeL
176167
= lowertriangular!(view(A, axes(L)...))
177-
if positive
178-
@inbounds for j in 1:minmn
179-
s = conj(sign_safe(L̃[j, j]))
180-
@simd for i in j:m
181-
L̃[i, j] = L̃[i, j] * s
182-
end
183-
end
184-
end
168+
positive && gaugefix!(lq_householder!, L̃, Q, diagview(A))
185169
copyto!(L, L̃)
170+
else
171+
gaugefix!(lq_householder!, nothing, Q, diagview(A))
186172
end
173+
187174
return L, Q
188175
end
189176
function lq_householder!(

src/implementations/qr.jl

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -174,43 +174,16 @@ function qr_householder!(
174174
end
175175
end
176176

177-
if positive # already fix Q even if we do not need R
178-
if driver === LAPACK()
179-
@inbounds for j in 1:minmn
180-
s = sign_safe(A[j, j])
181-
@simd for i in 1:m
182-
Q[i, j] *= s
183-
end
184-
end
185-
else
186-
# guaranteed τ exists and no longer needed
187-
τ .= sign_safe.(diagview(A))
188-
Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q
189-
Qf .= Qf .* transpose(τ)
190-
end
191-
end
192-
193177
if computeR
194-
= uppertriangular!(view(A, axes(R)...))
195-
if positive
196-
if driver === LAPACK()
197-
@inbounds for j in n:-1:1
198-
@simd for i in 1:min(minmn, j)
199-
R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
200-
end
201-
end
202-
else
203-
R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R
204-
R̃f .= conj.(τ) .* R̃f
205-
end
206-
end
207-
if !pivoted
208-
copyto!(R, R̃)
209-
else
210-
# probably very inefficient in terms of memory access
211-
copyto!(view(R, :, jpvt), R̃)
212-
end
178+
# we need to first copy then gaugefix - avoiding aliasing between R and Rd for broadcast
179+
Rd = diagview(A)
180+
Rf = pivoted ? view(R, :, jpvt) : R
181+
copyto!(Rf, uppertriangular!(view(A, axes(R)...)))
182+
positive && gaugefix!(qr_householder!, Q, Rf, Rd)
183+
elseif positive
184+
gaugefix!(qr_householder!, Q, nothing, diagview(A))
213185
end
186+
214187
return Q, R
215188
end
216189
function qr_householder!(

0 commit comments

Comments
 (0)