Skip to content

Commit ea40a3a

Browse files
committed
start overloading ground state algoriths
1 parent 7468886 commit ea40a3a

File tree

15 files changed

+370
-61
lines changed

15 files changed

+370
-61
lines changed

LocalPreferences.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[MPIPreferences]
2+
__clear__ = ["libmpi", "abi", "mpiexec", "cclibs", "preloads_env_switch"]
3+
_format = "1.0"
4+
binary = "MPICH_jll"
5+
preloads = []

Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,30 @@ version = "0.1.0"
44
authors = ["<Andreas Feuerpfeil|andreas.feuerpfeil@gmail.com>"]
55

66
[deps]
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
9+
MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
810
MPSKit = "bb1c41ca-d63c-52ed-829e-0820dda26502"
911
MPSKitModels = "ca635005-6f8c-4cd1-b51d-8491250ef2ab"
1012
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1113
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
14+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1215
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
1316
TensorKitManifolds = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
17+
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1418

1519
[compat]
20+
LinearAlgebra = "1.12.0"
1621
MPI = "0.20.23"
22+
MPIPreferences = "0.1.11"
1723
MPSKit = "0.13.8"
1824
MPSKitModels = "0.4.4"
1925
MacroTools = "0.5.16"
2026
OhMyThreads = "0.8.3"
27+
Pkg = "1.12.0"
2128
TensorKit = "0.15.2"
2229
TensorKitManifolds = "0.7.3"
30+
VectorInterface = "0.5.0"
2331
julia = "1.6.7"
2432

2533
[extras]

examples/Ising.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Pkg
2+
Pkg.activate("/home/afeuerpfeil/.julia/dev/MPSKitParallel")
13
using MPSKit
24
using MPSKitModels
35
using TensorKit
@@ -10,12 +12,15 @@ H = heisenberg_XXX(symmetry, chain; J, spin);
1012
physical_space = SU2Space(1 => 1);
1113
virtual_space_inf = Rep[SU₂](1 // 2 => 16, 3 // 2 => 16, 5 // 2 => 8, 7 // 2 => 4);
1214
ψ₀_inf = InfiniteMPS([physical_space], [virtual_space_inf]);
13-
ψ_inf, envs_inf, delta_inf = find_groundstate(ψ₀_inf, H; verbosity = 3);
15+
ψ_inf, envs_inf, delta_inf = find_groundstate(ψ₀_inf, 2*H; verbosity = 3);
1416

1517

1618
using MPSKitParallel
19+
using MPSKitParallel: mpi_rank, mpi_size
1720
using MPI
1821
H_mpi = MPIOperator(H);
1922
MPI.Init()
23+
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes.")
2024
ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(ψ₀_inf, H_mpi; verbosity = 3);
21-
abs(dot(ψ_inf, ψ_infmpi))
25+
26+
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")

examples/mpi_test.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using Pkg
2+
Pkg.activate("/home/afeuerpfeil/.julia/dev/MPSKitParallel")
3+
using LinearAlgebra
4+
5+
include("../src/multiprocessing/mpi/mpi_buffers.jl")
6+
using MPI
7+
MPI.Init()
8+
comm = MPI.COMM_WORLD
9+
rank=MPI.Comm_rank(comm)
10+
println("Hello world, I am $(MPI.Comm_rank(comm)) of $(MPI.Comm_size(comm))")
11+
MPI.Barrier(comm)
12+
13+
println("I am $(MPI.Comm_rank(comm)) of $(MPI.Comm_size(comm)), we are now testing the overloaded MPI functions")
14+
println("We begin with small data, so that no chunking is necessary:")
15+
A=rand(10,10)
16+
println("Rank $rank has matrix A with norm $(norm(A))")
17+
if rank!=0
18+
large_send(A, comm; dest=0, tag=0)
19+
else
20+
for i in 1:MPI.Comm_size(comm)-1
21+
A_received=large_receive(comm;source=i, tag=0)
22+
println("Rank 0 received matrix from rank $i with norm $(norm(A_received))")
23+
end
24+
end
25+
# println("Now we test big data")
26+
# A=rand(ComplexF64,16000,10000)
27+
# println("Rank $rank has matrix A with norm $(norm(A))")
28+
# if rank!=0
29+
# large_send(A, comm; dest=0, tag=0)
30+
# else
31+
# for i in 1:MPI.Comm_size(comm)-1
32+
# A_received=large_receive(comm;source=i, tag=0)
33+
# println("Rank 0 received matrix from rank $i with norm $(norm(A_received))")
34+
# end
35+
# end

src/MPIOperator/mpioperator.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ end
3535
@forward MPIOperator.parent Base.getindex, Base.size, Base.length, Base.iterate, Base.eltype, Base.axes, Base.similar, Base.eachindex, Base.lastindex, Base.setindex!, Base.isfinite
3636
@forward MPIOperator.parent LinearAlgebra.norm
3737
@forward MPIOperator.parent TensorKit.spacetype, TensorKit.sectortype,TensorKit.storagetype
38-
@forward MPIOperator.parent MPSKit.eachsite, MPSKit.left_virtualspace, MPSKit.right_virtualspace, MPSKit.physicsalspace
38+
@forward MPIOperator.parent MPSKit.eachsite, MPSKit.left_virtualspace, MPSKit.right_virtualspace, MPSKit.physicalspace
3939
@forward_astype MPIOperator.parent MPSKit.remove_orphans!
40-
@forward_astype MPIOperator.parent Base.:+, Base.:-, Base.:*, Base.:/, Base.:\, Base.:(^), Base.conj!, Base.conj, Base.copy,
40+
@forward_astype MPIOperator.parent Base.:+, Base.:-, Base.:*, Base.:/, Base.:\, Base.:(^), Base.conj!, Base.conj, Base.copy
4141
@forward_1_1_astype MPIOperator.parent Base.:*
4242
@forward_astype MPIOperator.parent VectorInterface.scale
4343
@forward2 MPIOperator.parent MPSKit._fuse_mpo_mpo

src/MPSKitParallel.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,15 @@ using TensorKit
1313
using MPSKit
1414
using MPI
1515
using MacroTools
16+
using LinearAlgebra
17+
using VectorInterface
1618

19+
import LinearAlgebra: norm
20+
import VectorInterface: scale
1721
import MPSKit: environments, AbstractMPSEnvironments, InfiniteEnvironments
1822
import MPSKit: C_hamiltonian, AC_hamiltonian, AC2_hamiltonian, C_projection, AC_projection, AC2_projection
19-
import MPSKit: C_hamiltonian, AC_hamiltonian, AC2_hamiltonian, C_projection, AC_projection, AC2_projection
23+
import MPSKit: exact_diagonalization
24+
2025

2126
include("includes.jl")
2227

src/SharedMPS/sharedmps.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/algorithms/ED.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function exact_diagonalization(
1+
function MPSKit.exact_diagonalization(
22
H::MPIOperator{<:FiniteMPOHamiltonian};
33
sector = one(sectortype(H)), num::Int = 1, which::Symbol = :SR,
44
alg = Defaults.alg_eigsolve(; dynamic_tols = false)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
function MPSKit._localupdate_sweep_idmrg!::AbstractMPS, H::MPIOperator, envs, alg_eigsolve, ::IDMRG)
2+
C_old = ψ.C[0]
3+
# left to right sweep
4+
for pos in 1:length(ψ)
5+
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
6+
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
7+
if pos == length(ψ)
8+
# AC needed in next sweep
9+
ψ.AL[pos], ψ.C[pos] = mpi_left_orth.AC[pos])
10+
else
11+
ψ.AL[pos], ψ.C[pos] = mpi_left_orth!.AC[pos])
12+
end
13+
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
14+
end
15+
16+
# right to left sweep
17+
for pos in length(ψ):-1:1
18+
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
19+
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
20+
21+
ψ.C[pos - 1], temp = mpi_right_orth!(_transpose_tail.AC[pos]; copy = (pos == 1)))
22+
ψ.AR[pos] = _transpose_front(temp)
23+
24+
transfer_rightenv!(envs, ψ, H, ψ, pos - 1)
25+
end
26+
return ψ, envs, C_old
27+
end
28+
29+
function MPSKit._localupdate_sweep_idmrg!::AbstractMPS, H::MPIOperator, envs, alg_eigsolve, alg::IDMRG2)
30+
# sweep from left to right
31+
for pos in 1:(length(ψ) - 1)
32+
ac2 = AC2(ψ, pos; kind = :ACAR)
33+
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
34+
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
35+
36+
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
37+
normalize!(c)
38+
39+
ψ.AL[pos] = al
40+
ψ.C[pos] = complex(c)
41+
ψ.AR[pos + 1] = _transpose_front(ar)
42+
ψ.AC[pos + 1] = _transpose_front(c * ar)
43+
44+
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
45+
transfer_rightenv!(envs, ψ, H, ψ, pos)
46+
end
47+
48+
# update the edge
49+
ψ.AL[end] = ψ.AC[end] / ψ.C[end]
50+
ψ.AC[1] = _mul_tail.AL[1], ψ.C[1])
51+
ac2 = AC2(ψ, 0; kind = :ALAC)
52+
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
53+
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
54+
55+
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
56+
normalize!(c)
57+
58+
ψ.AL[end] = al
59+
ψ.C[end] = complex(c)
60+
ψ.AR[1] = _transpose_front(ar)
61+
62+
ψ.AC[end] = _mul_tail(al, c)
63+
ψ.AC[1] = _transpose_front(c * ar)
64+
ψ.AL[1] = ψ.AC[1] / ψ.C[1]
65+
66+
C_old = complex(c)
67+
68+
# update environments
69+
transfer_leftenv!(envs, ψ, H, ψ, 1)
70+
transfer_rightenv!(envs, ψ, H, ψ, 0)
71+
72+
# sweep from right to left
73+
for pos in (length(ψ) - 1):-1:1
74+
ac2 = AC2(ψ, pos; kind = :ALAC)
75+
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
76+
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
77+
78+
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
79+
normalize!(c)
80+
81+
ψ.AL[pos] = al
82+
ψ.AC[pos] = _mul_tail(al, c)
83+
ψ.C[pos] = complex(c)
84+
ψ.AR[pos + 1] = _transpose_front(ar)
85+
ψ.AC[pos + 1] = _transpose_front(c * ar)
86+
87+
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
88+
transfer_rightenv!(envs, ψ, H, ψ, pos)
89+
end
90+
91+
# update the edge
92+
ψ.AC[end] = _mul_front.C[end - 1], ψ.AR[end])
93+
ψ.AR[1] = _transpose_front.C[end] \ _transpose_tail.AC[1]))
94+
ac2 = AC2(ψ, 0; kind = :ACAR)
95+
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
96+
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
97+
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
98+
normalize!(c)
99+
100+
ψ.AL[end] = al
101+
ψ.C[end] = complex(c)
102+
ψ.AR[1] = _transpose_front(ar)
103+
104+
ψ.AR[end] = _transpose_front.C[end - 1] \ _transpose_tail(al * c))
105+
ψ.AC[1] = _transpose_front(c * ar)
106+
107+
transfer_leftenv!(envs, ψ, H, ψ, 1)
108+
transfer_rightenv!(envs, ψ, H, ψ, 0)
109+
110+
return ψ, envs, C_old
111+
end
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
2+
function MPSKit.localupdate_step!(
3+
it::IterativeSolver{<:VUMPS{S, MPIOperator{O}, E}}, state, scheduler = MPSKit.Defaults.scheduler[]
4+
) where {S, O, E}
5+
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
6+
alg_orth = MPSKit.Defaults.alg_qr()
7+
8+
mps = state.mps
9+
src_Cs = mps isa Multiline ? eachcol(mps.C) : mps.C
10+
src_ACs = mps isa Multiline ? eachcol(mps.AC) : mps.AC
11+
ACs = similar(mps.AC)
12+
dst_ACs = mps isa Multiline ? eachcol(ACs) : ACs
13+
14+
tforeach(eachsite(mps), src_ACs, src_Cs; scheduler) do site, AC₀, C₀
15+
dst_ACs[site] = MPSKit._localupdate_vumps_step!(
16+
site, mps, state.operator, state.envs, AC₀, C₀;
17+
parallel = false, alg_orth, state.which, alg_eigsolve
18+
)
19+
return nothing
20+
end
21+
22+
return ACs
23+
end
24+
25+
function MPSKit._localupdate_vumps_step!(
26+
site, mps, operator::MPIOperator, envs, AC₀, C₀;
27+
parallel::Bool = false, alg_orth = MPSKit.Defaults.alg_qr(),
28+
alg_eigsolve = MPSKit.Defaults.eigsolver, which
29+
)
30+
if !parallel
31+
Hac = AC_hamiltonian(site, mps, operator, mps, envs)
32+
_, AC = fixedpoint(Hac, AC₀, which, alg_eigsolve)
33+
Hc = C_hamiltonian(site, mps, operator, mps, envs)
34+
_, C = fixedpoint(Hc, C₀, which, alg_eigsolve)
35+
return mpi_regauge!(AC, C; alg = alg_orth)
36+
end
37+
38+
local AC, C
39+
@sync begin
40+
@spawn begin
41+
Hac = AC_hamiltonian(site, mps, operator, mps, envs)
42+
_, AC = fixedpoint(Hac, AC₀, which, alg_eigsolve)
43+
end
44+
@spawn begin
45+
Hc = C_hamiltonian(site, mps, operator, mps, envs)
46+
_, C = fixedpoint(Hc, C₀, which, alg_eigsolve)
47+
end
48+
end
49+
return mpi_regauge!(AC, C; alg = alg_orth)
50+
end
51+
52+
function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS{S, MPIOperator{O}, E}}, state, ACs::AbstractVector) where {S, O, E}
53+
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
54+
if mpi_is_root()
55+
psi = InfiniteMPS(ACs, state.mps.C[end]; alg_gauge.tol, alg_gauge.maxiter)
56+
else
57+
psi = nothing
58+
end
59+
psi = large_bcast(psi, 0, MPI.COMM_WORLD)
60+
return psi
61+
end
62+
63+
function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS{S, MPIOperator{O}, E}}, state, ACs::AbstractMatrix) where {S, O, E}
64+
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
65+
if mpi_is_root()
66+
psi = MultilineMPS(ACs, @view(state.mps.C[:, end]); alg_gauge.tol, alg_gauge.maxiter)
67+
else
68+
psi = nothing
69+
end
70+
psi = large_bcast(psi, 0, MPI.COMM_WORLD)
71+
return psi
72+
end

0 commit comments

Comments
 (0)