From 7730590f14713fe8df139acd56b03cd19065b2d0 Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Tue, 31 Mar 2026 11:30:10 -0400 Subject: [PATCH] ProjectTo for symmetric sparse matrices. --- Project.toml | 3 +- ext/ChainRulesCoreSparseArraysExt.jl | 184 ++++++++++++++++++++++++++- test/projection.jl | 57 +++++++++ 3 files changed, 242 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 52ab392b1..d5ff22707 100644 --- a/Project.toml +++ b/Project.toml @@ -22,12 +22,13 @@ ChainRulesCoreSparseArraysExt = "SparseArrays" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"] +test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "Random", "SparseArrays", "StaticArrays"] [weakdeps] SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/ext/ChainRulesCoreSparseArraysExt.jl b/ext/ChainRulesCoreSparseArraysExt.jl index e4714f4c1..3e0c0757f 100644 --- a/ext/ChainRulesCoreSparseArraysExt.jl +++ b/ext/ChainRulesCoreSparseArraysExt.jl @@ -2,7 +2,41 @@ module ChainRulesCoreSparseArraysExt using ChainRulesCore using ChainRulesCore: project_type, _projection_mismatch -using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals +using LinearAlgebra: Hermitian, Symmetric, tril, triu +using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals, getcolptr, nonzeros + +const HermSparse{T, I} = Hermitian{T, SparseMatrixCSC{T, I}} +const SymSparse{T, I} = Symmetric{T, SparseMatrixCSC{T, I}} +const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}} + +const SparseProjectToData{T, I} = NamedTuple{ + (:element, :axes, :rowval, :nzranges, :colptr), + Tuple{ + ProjectTo{T, NamedTuple{(), Tuple{}}}, + Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, + Vector{I}, + Vector{UnitRange{Int64}}, + Vector{I}, + }, +} + +const SparseProjectTo{T, I} = ProjectTo{SparseMatrixCSC, SparseProjectToData{T, I}} + +const HermSparseProjectTo{T, I} = ProjectTo{ + Hermitian, + NamedTuple{ + (:uplo, :parent), + Tuple{Symbol, SparseProjectTo{T, I}}, + }, +} + +const SymSparseProjectTo{T, I} = ProjectTo{ + Symmetric, + NamedTuple{ + (:uplo, :parent), + Tuple{Symbol, SparseProjectTo{T, I}}, + }, +} ChainRulesCore.is_inplaceable_destination(::SparseVector) = true ChainRulesCore.is_inplaceable_destination(::SparseMatrixCSC) = true @@ -100,4 +134,152 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) end end +##### +##### Hermitian/Symmetric sparse projection +##### + +function project!(A::SparseMatrixCSC{T, I}, B::SparseMatrixCSC{<:Any, J}, uplo::Char) where {T, I, J} + @assert size(A) == size(B) + + @inbounds for j in axes(A, 2) + p = getcolptr(A)[j] + pstop = getcolptr(A)[j + 1] + q = getcolptr(B)[j] + qstop = getcolptr(B)[j + 1] + + while p < pstop + i = rowvals(A)[p] + + if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j) + while q < qstop && rowvals(B)[q] < i + q += one(J) + end + + if q < qstop && rowvals(B)[q] == i + nonzeros(A)[p] = nonzeros(B)[q] + else + nonzeros(A)[p] = zero(T) + end + end + + p += one(I) + end + end + + return A +end + +function project!(A::HermOrSymSparse, B::HermOrSymSparse) + if A.uplo == B.uplo + project!(parent(A), parent(B), A.uplo) + elseif A.uplo == 'L' + project!(parent(A), tril(B), A.uplo) + else + project!(parent(A), triu(B), A.uplo) + end + + return A +end + +function sparse_from_project(P::SparseProjectTo{T, I}) where {T, I} + m, n = map(length, P.axes) + return SparseMatrixCSC(m, n, P.colptr, P.rowval, zeros(T, length(P.rowval))) +end + +function sparse_from_project(P::HermSparseProjectTo) + return Hermitian(sparse_from_project(P.parent), P.uplo) +end + +function sparse_from_project(P::SymSparseProjectTo) + return Symmetric(sparse_from_project(P.parent), P.uplo) +end + +function checkpatternsym(n, Acolptr::Vector{IA}, Bcolptr::Vector{IB}, Arowval::AbstractVector, Browval::AbstractVector, uplo::Char) where {IA, IB} + for j in 1:n + pa = Acolptr[j] + pb = Bcolptr[j] + pastop = Acolptr[j + 1] + pbstop = Bcolptr[j + 1] + + while pa < pastop && pb < pbstop + ia = Arowval[pa] + ib = Browval[pb] + + if (uplo == 'L' && ia < j) || (uplo == 'U' && ia > j) + pa += one(IA) + elseif (uplo == 'L' && ib < j) || (uplo == 'U' && ib > j) + pb += one(IB) + elseif ia == ib + pa += one(IA) + pb += one(IB) + else + return false + end + end + + while pa < pastop + ia = Arowval[pa] + + if (uplo == 'L' && ia >= j) || (uplo == 'U' && ia <= j) + return false + end + + pa += one(IA) + end + + while pb < pbstop + ib = Browval[pb] + + if (uplo == 'L' && ib >= j) || (uplo == 'U' && ib <= j) + return false + end + + pb += one(IB) + end + end + + return true +end + +function checkpatternsym(P, dX) + return false +end + +function checkpatternsym(P::Union{HermSparseProjectTo{T, I}, SymSparseProjectTo{T, I}}, dX::HermOrSymSparse{T, I}) where {T, I} + dXP = parent(dX) + return Symbol(dX.uplo) == P.uplo && checkpatternsym(size(dXP, 2), P.parent.colptr, dXP.colptr, P.parent.rowval, dXP.rowval, dX.uplo) +end + +function (P::HermSparseProjectTo{T, I})(dX::HermSparse) where {T, I} + if checkpatternsym(P, dX) + return dX + else + return project!(sparse_from_project(P), dX) + end +end + +function (P::SymSparseProjectTo{T, I})(dX::SymSparse) where {T, I} + if checkpatternsym(P, dX) + return dX + else + return project!(sparse_from_project(P), dX) + end +end + +function (P::HermSparseProjectTo{T, I})(dX::SymSparse{T, I}) where {T <: Real, I} + if checkpatternsym(P, dX) + return Hermitian(parent(dX), P.uplo) + else + return project!(sparse_from_project(P), dX) + end +end + +function (P::SymSparseProjectTo{T, I})(dX::HermSparse{T, I}) where {T <: Real, I} + if checkpatternsym(P, dX) + return Symmetric(parent(dX), P.uplo) + else + return project!(sparse_from_project(P), dX) + end +end + end # module diff --git a/test/projection.jl b/test/projection.jl index f0aaaa859..9cd294dbd 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,6 +1,7 @@ using ChainRulesCore, Test using LinearAlgebra, SparseArrays using OffsetArrays, StaticArrays, BenchmarkTools +using Random: rand! # Like ForwardDiff.jl's Dual struct Dual{T<:Real} <: Real @@ -355,6 +356,62 @@ struct NoSuperType end @test_throws DimensionMismatch pm(ones(Int, 5, 20)) end + @testset "SparseArrays: Hermitian/Symmetric" begin + n = 100 + + function rand_sparse(SymHerm, T, n, uplo; density=0.3) + A = sprand(T, n, n, density) + if uplo == :U + return SymHerm(triu(A), uplo) + else + return SymHerm(tril(A), uplo) + end + end + + function rand_tangent(A, uplo=Symbol(A.uplo)) + dA = similar(A) + rand!(nonzeros(parent(dA))) + return typeof(A).name.wrapper(parent(dA), uplo) + end + + function nzmatch(A, B) + I, J, _ = findnz(parent(A)) + return all(A[i, j] == B[i, j] for (i, j) in zip(I, J)) + end + + @testset "$(SymHerm){$T}, uplo=:$uplo" for + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + + A = rand_sparse(SymHerm, T, n, uplo) + P = ProjectTo(A) + + # Same pattern + dA = rand_tangent(A) + @test P(dA) == dA + + # Different uplo + other = uplo == :U ? :L : :U + dA2 = rand_tangent(A, other) + @test P(dA2) isa SymHerm{T, <:SparseMatrixCSC} + @test nzmatch(P(dA2), dA2) + + # Different pattern + B = rand_sparse(SymHerm, T, n, uplo; density=0.5) + @test P(B) isa SymHerm{T, <:SparseMatrixCSC} + @test nzmatch(P(B), B) + end + + @testset "Cross-type (real)" begin + AH = rand_sparse(Hermitian, Float64, n, :U) + AS = rand_sparse(Symmetric, Float64, n, :U) + + @test ProjectTo(AH)(rand_tangent(AS)) isa Hermitian{Float64, <:SparseMatrixCSC} + @test ProjectTo(AS)(rand_tangent(AH)) isa Symmetric{Float64, <:SparseMatrixCSC} + end + end + ##### ##### `OffsetArrays` #####