Skip to content

Commit e95d486

Browse files
kshyattlkdvos
andauthored
Support for Diagonal in orthnull (#144)
* Support orthnull for Diagonal * Update src/interface/orthnull.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Update src/interface/orthnull.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent d707681 commit e95d486

4 files changed

Lines changed: 10 additions & 4 deletions

File tree

src/interface/orthnull.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,9 @@ left_orth_alg(alg::LeftOrthAlgorithm) = alg
443443
left_orth_alg(alg::QRAlgorithms) = LeftOrthViaQR(alg)
444444
left_orth_alg(alg::PolarAlgorithms) = LeftOrthViaPolar(alg)
445445
left_orth_alg(alg::SVDAlgorithms) = LeftOrthViaSVD(alg)
446+
left_orth_alg(alg::DiagonalAlgorithm) = LeftOrthViaQR(alg)
446447
left_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD(alg)
448+
left_orth_alg(alg::TruncatedAlgorithm{<:DiagonalAlgorithm}) = LeftOrthViaSVD(alg)
447449

448450
"""
449451
right_orth_alg(alg::AbstractAlgorithm) -> RightOrthAlgorithm
@@ -478,7 +480,9 @@ right_orth_alg(alg::RightOrthAlgorithm) = alg
478480
right_orth_alg(alg::LQAlgorithms) = RightOrthViaLQ(alg)
479481
right_orth_alg(alg::PolarAlgorithms) = RightOrthViaPolar(alg)
480482
right_orth_alg(alg::SVDAlgorithms) = RightOrthViaSVD(alg)
483+
right_orth_alg(alg::DiagonalAlgorithm) = RightOrthViaLQ(alg)
481484
right_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD(alg)
485+
right_orth_alg(alg::TruncatedAlgorithm{<:DiagonalAlgorithm}) = RightOrthViaSVD(alg)
482486

483487
"""
484488
left_null_alg(alg::AbstractAlgorithm) -> LeftNullAlgorithm

test/linearmap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ module LinearMaps
3838
# Using AbstractAlgorithm here would be ambiguous since neither A-type nor alg-type would
3939
# be strictly more specific.
4040
for f! in (:svd_compact!, :svd_full!)
41-
for Alg in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Bisection, :Jacobi, :SVDViaPolar)
41+
for Alg in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Bisection, :Jacobi, :SVDViaPolar, :DiagonalAlgorithm)
4242
@eval MAK.$f!(A::LinearMap, USVᴴ, alg::MAK.$Alg) =
4343
LinearMap.(MAK.$f!(parent(A), parent.(USVᴴ), alg))
4444
end

test/orthnull.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 27)
1919
if T BLASFloats
2020
if CUDA.functional()
2121
TestSuite.test_orthnull(CuMatrix{T}, (m, n); test_nullity = false)
22-
n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m; test_orthnull = false)
22+
n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m)
2323
end
2424
if AMDGPU.functional()
2525
TestSuite.test_orthnull(ROCMatrix{T}, (m, n); test_nullity = false)
26-
n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m; test_orthnull = false)
26+
n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m)
2727
end
2828
end
2929
if !is_buildkite
3030
TestSuite.test_orthnull(T, (m, n))
3131
AT = Diagonal{T, Vector{T}}
32-
TestSuite.test_orthnull(AT, m; test_orthnull = false)
32+
TestSuite.test_orthnull(AT, m)
3333
end
3434
end

test/testsuite/TestSuite.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@ is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_pivoted(alg.qr_alg)
7676
isleftcomplete(V, N) = V * V' + N * N' I
7777
isleftcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isleftcomplete(collect(V), collect(N))
7878
isleftcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isleftcomplete(collect(V), collect(N))
79+
isleftcomplete(V::Diagonal{TV, <:AnyROCVector}, N::AnyROCMatrix) where {TV} = isleftcomplete(Diagonal(collect(V.diag)), collect(N))
7980
isrightcomplete(Vᴴ, Nᴴ) = Vᴴ' * Vᴴ + Nᴴ' * Nᴴ I
8081
isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), collect(N))
8182
isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N))
83+
isrightcomplete(V::Diagonal{TV, <:AnyROCVector}, N::AnyROCMatrix) where {TV} = isrightcomplete(Diagonal(collect(V.diag)), collect(N))
8284

8385
instantiate_unitary(T, A, sz) = qr_compact(randn!(similar(A, eltype(T), sz, sz)))[1]
8486
# AMDGPU can't generate ComplexF32 random numbers

0 commit comments

Comments
 (0)