Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,6 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "JET", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]

[sources]
Strided = {url = "https://github.com/QuantumKitHub/Strided.jl/", rev = "ksh/copyto"}
7 changes: 5 additions & 2 deletions ext/TensorKitAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
data′ = adapt(to, x.data)
return DiagonalTensorMap(data′, x.domain)
end
function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}}
return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint)
# Adapting a `BraidingTensor` to a scalar type `T` changes only the element
# type: the space and adjoint flag are kept, and the storage type is recomputed
# via `similarstoragetype` so the array family `A` (e.g. a GPU vector type) is
# preserved with the new eltype.
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
    return BraidingTensor(space(x), TensorKit.similarstoragetype(A, T), x.adjoint)
end
# Adapting a `BraidingTensor` to a dense array type `TA` swaps the storage
# type wholesale (e.g. `Vector` -> `CuVector`); space and adjoint flag are
# untouched.
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {TA <: DenseArray{<:Number}, T, S, A}
    return BraidingTensor(space(x), TA, x.adjoint)
end

end
20 changes: 19 additions & 1 deletion ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ module TensorKitCUDAExt
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
using CUDA: @allowscalar
using cuTENSOR: cuTENSOR
using Strided: StridedViews
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!

using CUDA: KernelAbstractions
using CUDA.KernelAbstractions: @kernel, @index

using TensorKit
using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
import TensorKit: randisometry, rand, randn
import TensorKit: randisometry, rand, randn, similarmatrixtype, _set_subblock!

using TensorKit: MatrixAlgebraKit

Expand All @@ -19,4 +23,18 @@ using Random
include("cutensormap.jl")
include("truncation.jl")

TensorKit.similarmatrixtype(::Type{A}) where {T <: Number, M, A <: CuVector{T, M}} = CuMatrix{T, M}

# GPU version of `_set_subblock!`: writes `val` at every position
# `data[i, j, j, i]` using a KernelAbstractions kernel, one work-item per
# `(i, j)` pair, instead of the scalar loop of the CPU fallback (scalar
# indexing into CuArrays is disallowed/slow).
# NOTE(review): for the 2-D `CuMatrix` member of the union, the 4-index write
# only succeeds when the trailing indices are 1 — confirm callers never hit
# that path with `d1 > 1` or `d2 > 1`.
function TensorKit._set_subblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
    @kernel function fill_subblock_kernel!(subblock, val)
        idx = @index(Global, Cartesian)
        @inbounds subblock[idx[1], idx[2], idx[2], idx[1]] = val
    end
    # instantiate the kernel for the backend that owns `data`
    kernel = fill_subblock_kernel!(KernelAbstractions.get_backend(data))
    d1 = size(data, 1)
    d2 = size(data, 2)
    # launch over the (i, j) grid; each thread writes one entry
    kernel(data, val; ndrange = (d1, d2))
    return data
end

end
20 changes: 20 additions & 0 deletions ext/TensorKitCUDAExt/cutensormap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,23 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
return tf
end
end


# CUDA override of the non-threaded add kernel for `CuTensorMap`s: iterates
# the subtransformers of a `GenericTreeTransformer`, applying the cheap
# single-block path directly and, for the multi-block path, first uploading
# the subtransformer's coefficient array to the GPU as a `CuArray`.
function TensorKit.add_kernel_nonthreaded!(
        ::TensorKit.FusionStyle,
        tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
    )
    # preallocate buffers
    buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)

    for subtransformer in transformer.data
        # Special case without intermediate buffers whenever there is only a single block
        if length(subtransformer[1]) == 1
            TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
        else
            # move the coefficient array to device memory; keep the remaining
            # tuple entries (tree metadata) as-is
            cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...)
            TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...)
        end
    end
    return nothing
end
13 changes: 13 additions & 0 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,19 @@ similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:Ab
# default storage type for numbers
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}

@doc """
    similarmatrixtype(T::Type{<:Number}) -> Matrix{T}
    similarmatrixtype(A::Type{<:DenseVector{T}}) -> Matrix{T}

For a given dense vector type `A` or number type `T`, compute an appropriate
**matrix** storage type for tensors. This function is used internally for
[`BraidingTensor`](@ref) to determine the output storage format for indexing
and other operations with other tensor types.
""" similarmatrixtype

# Fallback matrix storage types: both a bare number type and any dense CPU
# vector with element type `T` map to a plain `Matrix{T}`. GPU extensions
# add their own specializations (e.g. CuVector -> CuMatrix).
function similarmatrixtype(::Type{T}) where {T <: Number}
    return Matrix{T}
end
function similarmatrixtype(::Type{V}) where {T <: Number, V <: DenseVector{T}}
    return Matrix{T}
end

@doc """
promote_storagetype([T], A, B, C...)
promote_storagetype([T], TA, TB, TC...)
Expand Down
150 changes: 89 additions & 61 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,91 @@
# special (2,2) tensor that implements a standard braiding operation
#====================================================================#
"""
struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2}
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool=false) where {S<:IndexSpace, A <: DenseVector{<:Number}}

Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
braids the first input over the second input; its inverse can be obtained as the adjoint.

It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA`
controls the array type of the braiding tensor used when indexing
and multiplying with other tensors.
"""
struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2}
struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
V1::S
V2::S
adjoint::Bool
function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
for a in sectors(V1)
for b in sectors(V2)
for c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
end
function BraidingTensor{T, S, A}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
for a in sectors(V1), b in sectors(V2), c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
return new{T, S}(V1, V2, adjoint)
return new{T, S, A}(V1, V2, adjoint)
# partial construction: only construct rowr and colr when needed
end
end
function BraidingTensor{T, S}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A}
return BraidingTensor{T, S, A}(V1, V2, A, adjoint)
end
function BraidingTensor{T}(V1::S, V2::S, A, adjoint::Bool = false) where {T, S <: IndexSpace}
return BraidingTensor{T, S}(V1, V2, A, adjoint)
end
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
return BraidingTensor{T, S}(V1, V2, adjoint)
return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
end
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, A, adjoint::Bool = false) where {T}
return BraidingTensor{T}(promote(V1, V2)..., A, adjoint)
end
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T}
return BraidingTensor{T}(promote(V1, V2)..., adjoint)
return BraidingTensor{T}(V1, V2, Vector{T}, adjoint)
end
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{A}, adjoint::Bool = false) where {T, A <: DenseVector{T}}
return BraidingTensor{T}(promote(V1, V2)..., A, adjoint)
end
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{T}, adjoint::Bool = false) where {T}
return BraidingTensor{T}(promote(V1, V2)..., Vector{T}, adjoint)
end
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false)
return BraidingTensor(promote(V1, V2)..., adjoint)
end
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
return BraidingTensor{T, S}(V1, V2, adjoint)
return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
end
function BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {S <: IndexSpace, A <: AbstractArray}
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
A′ = similarstoragetype(A, T)
return BraidingTensor{T, S}(V1, V2, A′, adjoint)
end
function BraidingTensor(V::HomSpace, adjoint::Bool = false)
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor(V[2], V[1], adjoint)
end
function BraidingTensor(V::HomSpace, ::Type{A}, adjoint::Bool = false) where {A}
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor(V[2], V[1], A, adjoint)
end
function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor{T}(V[2], V[1], adjoint)
end
function Base.adjoint(b::BraidingTensor{T, S}) where {T, S}
return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint)
# The adjoint (inverse) braiding is represented lazily: same spaces and
# storage type, with the `adjoint` flag flipped.
function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A}
    return BraidingTensor{T, S, A}(b.V1, b.V2, A, !b.adjoint)
end

storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A
space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2

# specializations to ignore the storagetype of BraidingTensor
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A)

promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} =
similarstoragetype(B, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} =
similarstoragetype(A, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} =
similarstoragetype(A, T)
promote_storagetype(::Type{B}, ::Type{T}) where {B <: BraidingTensor, T <: AbstractTensorMap} =
promote_storagetype(storagetype(B), storagetype(T))
promote_storagetype(::Type{T}, ::Type{B}) where {B <: BraidingTensor, T <: AbstractTensorMap} =
promote_storagetype(storagetype(B), storagetype(T))
promote_storagetype(::Type{BA}, ::Type{BB}) where {BA <: BraidingTensor, BB <: BraidingTensor} =
promote_storagetype(storagetype(BA), storagetype(BB))

function Base.getindex(b::BraidingTensor)
sectortype(b) === Trivial || throw(SectorMismatch())
Expand Down Expand Up @@ -99,6 +118,14 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
return r
end

"""
    _set_subblock!(data, val)

Write `val` into every entry of the form `data[i, j, j, i]`, with `i` ranging
over the first axis and `j` over the second axis of `data`. All other entries
are left untouched. Returns `data`.
"""
function _set_subblock!(data, val)
    rows = axes(data, 1)
    cols = axes(data, 2)
    @inbounds for col in cols, row in rows
        data[row, col, col, row] = val
    end
    return data
end


@inline function subblock(
b::BraidingTensor, (f₁, f₂)::Tuple{FusionTree{I, 2}, FusionTree{I, 2}}
) where {I <: Sector}
Expand All @@ -115,15 +142,12 @@ end
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
n1 = d[1] * d[2]
n2 = d[3] * d[4]
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
data_t = similarmatrixtype(storagetype(b))(undef, (n1, n2))
data = sreshape(StridedView(data_t), d)
fill!(data, zero(eltype(b)))

r = _braiding_factor(f₁, f₂, b.adjoint)
if !isnothing(r)
@inbounds for i in axes(data, 1), j in axes(data, 2)
data[i, j, j, i] = r
end
end
!isnothing(r) && _set_subblock!(data, r)
return data
end

Expand All @@ -134,8 +158,31 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b)
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)

Base.complex(b::BraidingTensor{<:Complex}) = b
function Base.complex(b::BraidingTensor)
return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint)
# Promote a braiding tensor to a complex scalar type: complexify `T`, derive
# the matching storage type from `A`, and keep space and adjoint flag.
function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A}
    Ac = similarstoragetype(A, complex(T))
    return BraidingTensor(space(b), Ac, b.adjoint)
end

# Fill the (single) block of a trivial-sector braiding tensor: view `data`
# as a 4-index array of size (d1, d2, d2, d1) and set the braiding entries
# `[i, j, j, i]` to one. Returns `data`.
function _trivial_subblock!(data, b::BraidingTensor)
    V1, V2 = codomain(b)
    d1, d2 = dim(V1), dim(V2)
    subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
    _set_subblock!(subblock, one(eltype(b)))
    return data
end

# Fill the block of sector `s` for a nontrivial-sector braiding tensor:
# iterate all fusion-tree pairs coupled to `s` and write each nonzero
# braiding factor into the corresponding strided subblock of `data`.
function _nontrivial_subblock!(data, b::BraidingTensor, s::Sector)
    # offset of sector `s`'s block inside the full data vector; subtracted
    # below so subblock offsets become relative to `data`
    base_offset = first(blockstructure(b)[s][2]) - 1

    for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b)))
        (f₁.coupled == f₂.coupled == s) || continue
        r = _braiding_factor(f₁, f₂, b.adjoint)
        # `nothing` means a zero contribution: leave those entries at zero
        isnothing(r) && continue
        # change offset to account for single block
        subblock = StridedView(data, sz, str, off - base_offset)
        _set_subblock!(subblock, r)
    end
    return data
end

function block(b::BraidingTensor, s::Sector)
Expand All @@ -145,36 +192,17 @@ function block(b::BraidingTensor, s::Sector)
# TODO: probably always square?
m = blockdim(codomain(b), s)
n = blockdim(domain(b), s)
data = Matrix{eltype(b)}(undef, (m, n))

length(data) == 0 && return data # s ∉ blocksectors(b)
m * n == 0 && return similarmatrixtype(storagetype(b))(undef, (m, n)) # s ∉ blocksectors(b)

data = similarmatrixtype(storagetype(b))(undef, (m, n))
data = fill!(data, zero(eltype(b)))

V1, V2 = codomain(b)
if sectortype(b) === Trivial
d1, d2 = dim(V1), dim(V2)
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
subblock[i, j, j, i] = one(eltype(b))
end
return data
end

base_offset = first(blockstructure(b)[s][2]) - 1

for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b)))
(f₁.coupled == f₂.coupled == s) || continue
r = _braiding_factor(f₁, f₂, b.adjoint)
isnothing(r) && continue
# change offset to account for single block
subblock = StridedView(data, sz, str, off - base_offset)
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
subblock[i, j, j, i] = r
end
return _trivial_subblock!(data, b)
else
return _nontrivial_subblock!(data, b, s)
end

return data
end

# Index manipulations
Expand Down
Loading
Loading