Skip to content

Commit 28e60c0

Browse files
authored
Be more careful about maybe-inplace factorizations (#106)
1 parent 1dcc302 commit 28e60c0

File tree

4 files changed

+120
-79
lines changed

4 files changed

+120
-79
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.6.4"
4+
version = "0.6.5"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/MatrixAlgebra.jl

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,82 @@
11
module MatrixAlgebra
22

33
export eigen,
4-
eigen!,
4+
eigen!!,
55
eigvals,
6-
eigvals!,
6+
eigvals!!,
77
factorize,
8-
factorize!,
8+
factorize!!,
99
lq,
10-
lq!,
10+
lq!!,
1111
orth,
12-
orth!,
12+
orth!!,
1313
polar,
14-
polar!,
14+
polar!!,
1515
qr,
16-
qr!,
16+
qr!!,
1717
svd,
18-
svd!,
18+
svd!!,
1919
svdvals,
20-
svdvals!
20+
svdvals!!
2121

2222
using LinearAlgebra: LinearAlgebra, norm
23-
using MatrixAlgebraKit
23+
import MatrixAlgebraKit as MAK
2424

2525
for (f, f_full, f_compact) in (
2626
(:qr, :qr_full, :qr_compact),
27-
(:qr!, :qr_full!, :qr_compact!),
27+
(:qr!!, :qr_full!, :qr_compact!),
2828
(:lq, :lq_full, :lq_compact),
29-
(:lq!, :lq_full!, :lq_compact!),
29+
(:lq!!, :lq_full!, :lq_compact!),
3030
)
3131
@eval begin
3232
function $f(A::AbstractMatrix; full::Bool = false, kwargs...)
33-
f = full ? $f_full : $f_compact
34-
return f(A; kwargs...)
33+
return full ? MAK.$f_full(A; kwargs...) : MAK.$f_compact(A; kwargs...)
3534
end
3635
end
3736
end
3837

3938
for (eigen, eigh_full, eig_full, eigh_trunc, eig_trunc) in (
4039
(:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc),
41-
(:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!),
40+
(:eigen!!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!),
4241
)
4342
@eval begin
4443
function $eigen(A::AbstractMatrix; trunc = nothing, ishermitian = nothing, kwargs...)
4544
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
4645
return if !isnothing(trunc)
4746
if ishermitian
48-
$eigh_trunc(A; trunc, kwargs...)
47+
MAK.$eigh_trunc(A; trunc, kwargs...)
4948
else
50-
$eig_trunc(A; trunc, kwargs...)
49+
MAK.$eig_trunc(A; trunc, kwargs...)
5150
end
5251
else
5352
if ishermitian
54-
$eigh_full(A; kwargs...)
53+
MAK.$eigh_full(A; kwargs...)
5554
else
56-
$eig_full(A; kwargs...)
55+
MAK.$eig_full(A; kwargs...)
5756
end
5857
end
5958
end
6059
end
6160
end
6261

6362
for (eigvals, eigh_vals, eig_vals) in
64-
((:eigvals, :eigh_vals, :eig_vals), (:eigvals!, :eigh_vals!, :eig_vals!))
63+
((:eigvals, :eigh_vals, :eig_vals), (:eigvals!!, :eigh_vals!, :eig_vals!))
6564
@eval begin
6665
function $eigvals(A::AbstractMatrix; ishermitian = nothing, kwargs...)
6766
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
68-
f = (ishermitian ? $eigh_vals : $eig_vals)
69-
return f(A; kwargs...)
67+
return ishermitian ? MAK.$eigh_vals(A; kwargs...) : MAK.$eig_vals(A; kwargs...)
7068
end
7169
end
7270
end
7371

7472
for (svd, svd_trunc, svd_full, svd_compact) in (
7573
(:svd, :svd_trunc, :svd_full, :svd_compact),
76-
(:svd!, :svd_trunc!, :svd_full!, :svd_compact!),
74+
(:svd!!, :svd_trunc!, :svd_full!, :svd_compact!),
7775
)
7876
_svd = Symbol(:_, svd)
7977
@eval begin
8078
function $svd(
81-
A::AbstractMatrix;
82-
full::Union{Bool, Val} = Val(false),
83-
trunc = nothing,
79+
A::AbstractMatrix; full::Union{Bool, Val} = Val(false), trunc = nothing,
8480
kwargs...,
8581
)
8682
return $_svd(full, trunc, A; kwargs...)
@@ -89,13 +85,13 @@ for (svd, svd_trunc, svd_full, svd_compact) in (
8985
return $_svd(Val(full), trunc, A; kwargs...)
9086
end
9187
function $_svd(full::Val{false}, trunc::Nothing, A::AbstractMatrix; kwargs...)
92-
return $svd_compact(A; kwargs...)
88+
return MAK.$svd_compact(A; kwargs...)
9389
end
9490
function $_svd(full::Val{false}, trunc, A::AbstractMatrix; kwargs...)
95-
return $svd_trunc(A; trunc, kwargs...)
91+
return MAK.$svd_trunc(A; trunc, kwargs...)
9692
end
9793
function $_svd(full::Val{true}, trunc::Nothing, A::AbstractMatrix; kwargs...)
98-
return $svd_full(A; kwargs...)
94+
return MAK.$svd_full(A; kwargs...)
9995
end
10096
function $_svd(full::Val{true}, trunc, A::AbstractMatrix; kwargs...)
10197
return throw(
@@ -107,55 +103,52 @@ for (svd, svd_trunc, svd_full, svd_compact) in (
107103
end
108104
end
109105

110-
for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!, :svd_vals!))
106+
for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!!, :svd_vals!))
111107
@eval begin
112108
function $svdvals(A::AbstractMatrix; ishermitian = nothing, kwargs...)
113-
return $svd_vals(A; kwargs...)
109+
return MAK.$svd_vals(A; kwargs...)
114110
end
115111
end
116112
end
117113

118114
for (polar, left_polar, right_polar) in
119-
((:polar, :left_polar, :right_polar), (:polar!, :left_polar!, :right_polar!))
115+
((:polar, :left_polar, :right_polar), (:polar!!, :left_polar!, :right_polar!))
120116
@eval begin
121117
function $polar(A::AbstractMatrix; side = :left, kwargs...)
122-
f = if side == :left
123-
$left_polar
118+
return if side == :left
119+
MAK.$left_polar(A; kwargs...)
124120
elseif side == :right
125-
$right_polar
121+
MAK.$right_polar(A; kwargs...)
126122
else
127-
throw(ArgumentError("`side=$side` not supported."))
123+
throw(ArgumentError("`side = $side` not supported."))
128124
end
129-
return f(A; kwargs...)
130125
end
131126
end
132127
end
133128

134129
for (orth, left_orth, right_orth) in
135-
((:orth, :left_orth, :right_orth), (:orth!, :left_orth!, :right_orth!))
130+
((:orth, :left_orth, :right_orth), (:orth!!, :left_orth!, :right_orth!))
136131
@eval begin
137132
function $orth(A::AbstractMatrix; side = :left, kwargs...)
138-
f = if side == :left
139-
$left_orth
133+
return if side == :left
134+
MAK.$left_orth(A; kwargs...)
140135
elseif side == :right
141-
$right_orth
136+
MAK.$right_orth(A; kwargs...)
142137
else
143-
throw(ArgumentError("`side=$side` not supported."))
138+
throw(ArgumentError("`side = $side` not supported."))
144139
end
145-
return f(A; kwargs...)
146140
end
147141
end
148142
end
149143

150-
for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, :orth!))
144+
for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!!, :orth!!))
151145
@eval begin
152146
function $factorize(A::AbstractMatrix; orth = :left, kwargs...)
153-
f = if orth in (:left, :right)
154-
$orth_f
147+
return if orth in (:left, :right)
148+
$orth_f(A; side = orth, kwargs...)
155149
else
156-
throw(ArgumentError("`orth=$orth` not supported."))
150+
throw(ArgumentError("`orth = $orth` not supported."))
157151
end
158-
return f(A; side = orth, kwargs...)
159152
end
160153
end
161154
end
@@ -190,7 +183,6 @@ function truncdegen(strategy::TruncationStrategy; atol::Real = 0, rtol::Real = 0
190183
end
191184

192185
using MatrixAlgebraKit: findtruncated
193-
194186
function MatrixAlgebraKit.findtruncated(
195187
values::AbstractVector, strategy::TruncationDegenerate
196188
)

src/factorizations.jl

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
using LinearAlgebra: LinearAlgebra
22
using MatrixAlgebraKit: MatrixAlgebraKit
33

4-
for f in (
5-
:qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth,
6-
:factorize,
4+
for (f, f_mat) in (
5+
(:qr, :(MatrixAlgebra.qr)),
6+
(:lq, :(MatrixAlgebra.lq)),
7+
(:left_polar, :(MatrixAlgebraKit.left_polar)),
8+
(:right_polar, :(MatrixAlgebraKit.right_polar)),
9+
(:polar, :(MatrixAlgebra.polar)),
10+
(:left_orth, :(MatrixAlgebraKit.left_orth)),
11+
(:right_orth, :(MatrixAlgebraKit.right_orth)),
12+
(:orth, :(MatrixAlgebra.orth)),
13+
(:factorize, :(MatrixAlgebra.factorize)),
714
)
815
@eval begin
916
function $f(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
1017
A_mat = matricize(style, A, ndims_codomain)
11-
X, Y = MatrixAlgebra.$f(A_mat; kwargs...)
18+
X, Y = $f_mat(A_mat; kwargs...)
1219
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
1320
axes_codomain, axes_domain = blocks(axes(A)[biperm])
1421
axes_X = tuplemortar((axes_codomain, (axes(X, 2),)))
@@ -219,18 +226,25 @@ See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlg
219226
"""
220227
eigen
221228

222-
function eigen(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
229+
function eigen!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
223230
# tensor to matrix
224231
A_mat = matricize(style, A, ndims_codomain)
225-
D, V = MatrixAlgebra.eigen!(A_mat; kwargs...)
232+
D, V = MatrixAlgebra.eigen!!(A_mat; kwargs...)
226233
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
227234
axes_codomain, = blocks(axes(A)[biperm])
228235
axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),)))
229236
# TODO: Make sure `D` has the same basis as `V`.
230237
return D, unmatricize(style, V, axes_V)
231238
end
239+
function eigen!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
240+
return eigen!!(FusionStyle(A), A, ndims_codomain; kwargs...)
241+
end
242+
243+
function eigen(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
244+
return eigen!!(style, copy(A), ndims_codomain; kwargs...)
245+
end
232246
function eigen(A::AbstractArray, ndims_codomain::Val; kwargs...)
233-
return eigen(FusionStyle(A), A, ndims_codomain; kwargs...)
247+
return eigen!!(copy(A), ndims_codomain; kwargs...)
234248
end
235249

236250
"""
@@ -253,12 +267,19 @@ See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`.
253267
"""
254268
eigvals
255269

256-
function eigvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
270+
function eigvals!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
257271
A_mat = matricize(style, A, ndims_codomain)
258-
return MatrixAlgebra.eigvals!(A_mat; kwargs...)
272+
return MatrixAlgebra.eigvals!!(A_mat; kwargs...)
273+
end
274+
function eigvals!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
275+
return eigvals!!(FusionStyle(A), A, ndims_codomain; kwargs...)
276+
end
277+
278+
function eigvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
279+
return eigvals!!(style, copy(A), ndims_codomain; kwargs...)
259280
end
260281
function eigvals(A::AbstractArray, ndims_codomain::Val; kwargs...)
261-
return eigvals(FusionStyle(A), A, ndims_codomain; kwargs...)
282+
return eigvals!!(copy(A), ndims_codomain; kwargs...)
262283
end
263284

264285
"""
@@ -282,17 +303,24 @@ See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `Mat
282303
"""
283304
svd
284305

285-
function svd(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
306+
function svd!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
286307
A_mat = matricize(style, A, ndims_codomain)
287-
U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...)
308+
U, S, Vᴴ = MatrixAlgebra.svd!!(A_mat; kwargs...)
288309
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
289310
axes_codomain, axes_domain = blocks(axes(A)[biperm])
290311
axes_U = tuplemortar((axes_codomain, (axes(U, 2),)))
291312
axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain))
292313
return unmatricize(style, U, axes_U), S, unmatricize(style, Vᴴ, axes_Vᴴ)
293314
end
315+
function svd!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
316+
return svd!!(FusionStyle(A), A, ndims_codomain; kwargs...)
317+
end
318+
319+
function svd(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
320+
return svd!!(style, copy(A), ndims_codomain; kwargs...)
321+
end
294322
function svd(A::AbstractArray, ndims_codomain::Val; kwargs...)
295-
return svd(FusionStyle(A), A, ndims_codomain; kwargs...)
323+
return svd!!(copy(A), ndims_codomain; kwargs...)
296324
end
297325

298326
"""
@@ -309,12 +337,19 @@ See also `MatrixAlgebraKit.svd_vals!`.
309337
"""
310338
svdvals
311339

312-
function svdvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val)
340+
function svdvals!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val)
313341
A_mat = matricize(style, A, ndims_codomain)
314-
return MatrixAlgebra.svdvals!(A_mat)
342+
return MatrixAlgebra.svdvals!!(A_mat)
343+
end
344+
function svdvals!!(A::AbstractArray, ndims_codomain::Val)
345+
return svdvals!!(FusionStyle(A), A, ndims_codomain)
346+
end
347+
348+
function svdvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val)
349+
return svdvals!!(style, copy(A), ndims_codomain)
315350
end
316351
function svdvals(A::AbstractArray, ndims_codomain::Val)
317-
return svdvals(FusionStyle(A), A, ndims_codomain)
352+
return svdvals!!(copy(A), ndims_codomain)
318353
end
319354

320355
"""
@@ -338,16 +373,23 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`.
338373
"""
339374
left_null
340375

341-
function left_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
376+
function left_null!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
342377
A_mat = matricize(style, A, ndims_codomain)
343378
N = MatrixAlgebraKit.left_null!(A_mat; kwargs...)
344379
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
345380
axes_codomain = first(blocks(axes(A)[biperm]))
346381
axes_N = tuplemortar((axes_codomain, (axes(N, 2),)))
347382
return unmatricize(style, N, axes_N)
348383
end
384+
function left_null!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
385+
return left_null!!(FusionStyle(A), A, ndims_codomain; kwargs...)
386+
end
387+
388+
function left_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
389+
return left_null!!(style, copy(A), ndims_codomain; kwargs...)
390+
end
349391
function left_null(A::AbstractArray, ndims_codomain::Val; kwargs...)
350-
return left_null(FusionStyle(A), A, ndims_codomain; kwargs...)
392+
return left_null!!(copy(A), ndims_codomain; kwargs...)
351393
end
352394

353395
"""
@@ -371,14 +413,21 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`.
371413
"""
372414
right_null
373415

374-
function right_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
416+
function right_null!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
375417
A_mat = matricize(style, A, ndims_codomain)
376418
Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...)
377419
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
378420
axes_domain = last(blocks((axes(A)[biperm])))
379421
axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain))
380422
return unmatricize(style, Nᴴ, axes_Nᴴ)
381423
end
424+
function right_null!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
425+
return right_null!!(FusionStyle(A), A, ndims_codomain; kwargs...)
426+
end
427+
428+
function right_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
429+
return right_null!!(style, copy(A), ndims_codomain; kwargs...)
430+
end
382431
function right_null(A::AbstractArray, ndims_codomain::Val; kwargs...)
383-
return right_null(FusionStyle(A), A, ndims_codomain; kwargs...)
432+
return right_null!!(copy(A), ndims_codomain; kwargs...)
384433
end

0 commit comments

Comments
 (0)