Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ ChainRulesCoreSparseArraysExt = "SparseArrays"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"]
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "Random", "SparseArrays", "StaticArrays"]

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
184 changes: 183 additions & 1 deletion ext/ChainRulesCoreSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,41 @@ module ChainRulesCoreSparseArraysExt

using ChainRulesCore
using ChainRulesCore: project_type, _projection_mismatch
using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals
using LinearAlgebra: Hermitian, Symmetric, tril, triu
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LinearAlgebra is not available in the SparseArrays extension. Only ChainRulesCore and SparseArrays are.

using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals, getcolptr, nonzeros

const HermSparse{T, I} = Hermitian{T, SparseMatrixCSC{T, I}}
const SymSparse{T, I} = Symmetric{T, SparseMatrixCSC{T, I}}
const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}}

const SparseProjectToData{T, I} = NamedTuple{
(:element, :axes, :rowval, :nzranges, :colptr),
Tuple{
ProjectTo{T, NamedTuple{(), Tuple{}}},
Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}},
Vector{I},
Vector{UnitRange{Int64}},
Vector{I},
},
}

const SparseProjectTo{T, I} = ProjectTo{SparseMatrixCSC, SparseProjectToData{T, I}}

const HermSparseProjectTo{T, I} = ProjectTo{
Hermitian,
NamedTuple{
(:uplo, :parent),
Tuple{Symbol, SparseProjectTo{T, I}},
},
}

const SymSparseProjectTo{T, I} = ProjectTo{
Symmetric,
NamedTuple{
(:uplo, :parent),
Tuple{Symbol, SparseProjectTo{T, I}},
},
}

ChainRulesCore.is_inplaceable_destination(::SparseVector) = true
ChainRulesCore.is_inplaceable_destination(::SparseMatrixCSC) = true
Expand Down Expand Up @@ -100,4 +134,152 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
end
end

#####
##### Hermitian/Symmetric sparse projection
#####

function project!(A::SparseMatrixCSC{T, I}, B::SparseMatrixCSC{<:Any, J}, uplo::Char) where {T, I, J}
@assert size(A) == size(B)

@inbounds for j in axes(A, 2)
p = getcolptr(A)[j]
pstop = getcolptr(A)[j + 1]
q = getcolptr(B)[j]
qstop = getcolptr(B)[j + 1]

while p < pstop
i = rowvals(A)[p]

if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j)
while q < qstop && rowvals(B)[q] < i
q += one(J)
end

if q < qstop && rowvals(B)[q] == i
nonzeros(A)[p] = nonzeros(B)[q]
else
nonzeros(A)[p] = zero(T)
end
end

p += one(I)
end
end

return A
end

function project!(A::HermOrSymSparse, B::HermOrSymSparse)
if A.uplo == B.uplo
project!(parent(A), parent(B), A.uplo)
elseif A.uplo == 'L'
project!(parent(A), tril(B), A.uplo)
else
project!(parent(A), triu(B), A.uplo)
end

return A
end

function sparse_from_project(P::SparseProjectTo{T, I}) where {T, I}
m, n = map(length, P.axes)
return SparseMatrixCSC(m, n, P.colptr, P.rowval, zeros(T, length(P.rowval)))
end

function sparse_from_project(P::HermSparseProjectTo)
return Hermitian(sparse_from_project(P.parent), P.uplo)
end

function sparse_from_project(P::SymSparseProjectTo)
return Symmetric(sparse_from_project(P.parent), P.uplo)
end

function checkpatternsym(n, Acolptr::Vector{IA}, Bcolptr::Vector{IB}, Arowval::AbstractVector, Browval::AbstractVector, uplo::Char) where {IA, IB}
for j in 1:n
pa = Acolptr[j]
pb = Bcolptr[j]
pastop = Acolptr[j + 1]
pbstop = Bcolptr[j + 1]

while pa < pastop && pb < pbstop
ia = Arowval[pa]
ib = Browval[pb]

if (uplo == 'L' && ia < j) || (uplo == 'U' && ia > j)
pa += one(IA)
elseif (uplo == 'L' && ib < j) || (uplo == 'U' && ib > j)
pb += one(IB)
elseif ia == ib
pa += one(IA)
pb += one(IB)
else
return false
end
end

while pa < pastop
ia = Arowval[pa]

if (uplo == 'L' && ia >= j) || (uplo == 'U' && ia <= j)
return false
end

pa += one(IA)
end

while pb < pbstop
ib = Browval[pb]

if (uplo == 'L' && ib >= j) || (uplo == 'U' && ib <= j)
return false
end

pb += one(IB)
end
end

return true
end

function checkpatternsym(P, dX)
return false
end

function checkpatternsym(P::Union{HermSparseProjectTo{T, I}, SymSparseProjectTo{T, I}}, dX::HermOrSymSparse{T, I}) where {T, I}
dXP = parent(dX)
return Symbol(dX.uplo) == P.uplo && checkpatternsym(size(dXP, 2), P.parent.colptr, dXP.colptr, P.parent.rowval, dXP.rowval, dX.uplo)
end

function (P::HermSparseProjectTo{T, I})(dX::HermSparse) where {T, I}
if checkpatternsym(P, dX)
return dX
else
return project!(sparse_from_project(P), dX)
end
end

function (P::SymSparseProjectTo{T, I})(dX::SymSparse) where {T, I}
if checkpatternsym(P, dX)
return dX
else
return project!(sparse_from_project(P), dX)
end
end

function (P::HermSparseProjectTo{T, I})(dX::SymSparse{T, I}) where {T <: Real, I}
if checkpatternsym(P, dX)
return Hermitian(parent(dX), P.uplo)
else
return project!(sparse_from_project(P), dX)
end
end

function (P::SymSparseProjectTo{T, I})(dX::HermSparse{T, I}) where {T <: Real, I}
if checkpatternsym(P, dX)
return Symmetric(parent(dX), P.uplo)
else
return project!(sparse_from_project(P), dX)
end
end

end # module
57 changes: 57 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ChainRulesCore, Test
using LinearAlgebra, SparseArrays
using OffsetArrays, StaticArrays, BenchmarkTools
using Random: rand!

# Like ForwardDiff.jl's Dual
struct Dual{T<:Real} <: Real
Expand Down Expand Up @@ -355,6 +356,62 @@ struct NoSuperType end
@test_throws DimensionMismatch pm(ones(Int, 5, 20))
end

@testset "SparseArrays: Hermitian/Symmetric" begin
n = 100

function rand_sparse(SymHerm, T, n, uplo; density=0.3)
A = sprand(T, n, n, density)
if uplo == :U
return SymHerm(triu(A), uplo)
else
return SymHerm(tril(A), uplo)
end
end

function rand_tangent(A, uplo=Symbol(A.uplo))
dA = similar(A)
rand!(nonzeros(parent(dA)))
return typeof(A).name.wrapper(parent(dA), uplo)
end

function nzmatch(A, B)
I, J, _ = findnz(parent(A))
return all(A[i, j] == B[i, j] for (i, j) in zip(I, J))
end

@testset "$(SymHerm){$T}, uplo=:$uplo" for
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

A = rand_sparse(SymHerm, T, n, uplo)
P = ProjectTo(A)

# Same pattern
dA = rand_tangent(A)
@test P(dA) == dA

# Different uplo
other = uplo == :U ? :L : :U
dA2 = rand_tangent(A, other)
@test P(dA2) isa SymHerm{T, <:SparseMatrixCSC}
@test nzmatch(P(dA2), dA2)

# Different pattern
B = rand_sparse(SymHerm, T, n, uplo; density=0.5)
@test P(B) isa SymHerm{T, <:SparseMatrixCSC}
@test nzmatch(P(B), B)
end

@testset "Cross-type (real)" begin
AH = rand_sparse(Hermitian, Float64, n, :U)
AS = rand_sparse(Symmetric, Float64, n, :U)

@test ProjectTo(AH)(rand_tangent(AS)) isa Hermitian{Float64, <:SparseMatrixCSC}
@test ProjectTo(AS)(rand_tangent(AH)) isa Symmetric{Float64, <:SparseMatrixCSC}
end
end

#####
##### `OffsetArrays`
#####
Expand Down
Loading