Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/LinearOperatorFFTWExt/DCTOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ returns a `DCTOpImpl <: AbstractLinearOperator` which performs a DCT on a given
* `shape::Tuple` - size of the array to transform
* `dcttype` - type of DCT (currently `2` and `4` are supported)
"""
function LinearOperatorCollection.DCTOp(T::Type; shape::Tuple, S = Array{T}, dcttype=2)
function LinearOperatorCollection.DCTOp(T::Type; shape::Tuple, S = Vector{T}, dcttype=2)

tmp=similar(S(undef, 0), shape...)
if dcttype == 2
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearOperatorFFTWExt/DSTOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ returns a `LinearOperator` which performs a DST on a given input array.
* `T::Type` - type of the array to transform
* `shape::Tuple` - size of the array to transform
"""
function LinearOperatorCollection.DSTOp(T::Type; shape::Tuple, S = Array{T})
function LinearOperatorCollection.DSTOp(T::Type; shape::Tuple, S = Vector{T})
tmp=similar(S(undef, 0), shape...)

plan = FFTW.plan_r2r!(tmp,FFTW.RODFT10)
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearOperatorFFTWExt/FFTOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ returns an operator which performs an FFT on Arrays of type T
* (`S = Vector{T}`) - type of temporary vector, change to use on GPU
* (`kwargs...`) - keyword arguments given to fft plan
"""
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Vector{Complex{real(T)}}, kwargs...) where D

tmpVec = similar(S(undef, 0), shape...)
plan = plan_fft!(tmpVec; kwargs...)
Expand Down
8 changes: 4 additions & 4 deletions ext/LinearOperatorNFFTExt/NFFTOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end

LinearOperators.storage_type(::NFFTToeplitzNormalOp{T, vecT}) where {T, vecT} = vecT

function LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1::matT, xL2::matT) where {T, D, matT <: AbstractArray{T, D}}
function LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1::matT, xL2::matT, S = Vector{T}) where {T, D, matT <: AbstractArray{T, D}}

function produ!(y, shape, fftplan, ifftplan, λ, xL1, xL2, x)
xL1 .= 0
Expand All @@ -120,12 +120,12 @@ function LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftpl
, (res,x) -> produ!(res, shape, fftplan, ifftplan, λ, xL1, xL2, x)
, nothing
, nothing
, 0, 0, 0, T[], T[]
, 0, 0, 0, S(undef, 0), S(undef, 0)
, shape, W, fftplan, ifftplan, λ, xL1, xL2)
end

# TODO: use vecT for toeplitz op
function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothing; kwargs...) where {T}
function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothing; S = LinearOperators.storage_type(nfft), kwargs...) where {T}
shape = size_in(nfft.plan)

tmpVec = similar(nfft.Mv, (2 .* shape)...)
Expand Down Expand Up @@ -156,7 +156,7 @@ function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothin
xL1 = tmpVec
xL2 = similar(xL1)

return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2)
return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2, S)
end

function LinearOperatorCollection.normalOperator(S::NFFTOpImpl{T}, W = nothing; copyOpsFn = copy, kwargs...) where T
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module LinearOperatorNonuniformFFTsExt

using LinearOperatorCollection, AbstractNFFTs, NonuniformFFTs, NonuniformFFTs.Kernels, FFTW

function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T, P}, W=nothing; kwargs...) where {T, P <: NonuniformFFTs.NFFTPlan}
function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T, P}, W=nothing; S = LinearOperators.storage_type(nfft), kwargs...) where {T, P <: NonuniformFFTs.NFFTPlan}
shape = size_in(nfft.plan)

tmpVec = similar(nfft.Mv, (2 .* shape)...)
Expand Down Expand Up @@ -39,7 +39,7 @@ function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T, P}, W=not
xL1 = tmpVec
xL2 = similar(xL1)

return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2)
return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2, S)
end


Expand Down
28 changes: 28 additions & 0 deletions test/testOperators.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
function check_storage(op, arrayType = AbstractArray)
S = LinearOperators.storage_type(op)
@test isconcretetype(S)
@test S <: AbstractVector
@test S <: arrayType
end

function testDCT1d(N=32;arrayType = Array)
Random.seed!(1235)
x = zeros(ComplexF64, N^2)
for i=1:5
x .+= rand()*cos.(rand(1:N^2)*collect(1:N^2)) .+ 1im*rand()*cos.(rand(1:N^2)*collect(1:N^2))
end
D1 = DCTOp(ComplexF64, shape=(N^2,), dcttype=2)
check_storage(D1, arrayType)
D2 = sqrt(2/N^2)*[cos(pi/(N^2)*j*(k+0.5)) for j=0:N^2-1,k=0:N^2-1]
D2[1,:] .*= 1/sqrt(2)
D3 = DCTOp(ComplexF64, shape=(N^2,), dcttype=4)
Expand Down Expand Up @@ -34,6 +42,7 @@ function testFFT1d(N=32,shift=true;arrayType = Array)
end
xop = arrayType(x)
D1 = FFTOp(ComplexF64, shape=(N^2,), shift=shift, S = typeof(ComplexF64.(xop)))
check_storage(D1, arrayType)
D2 = 1.0/N*[exp(-2*pi*im*j*k/N^2) for j=0:N^2-1,k=0:N^2-1]

y1 = Array(D1*xop)
Expand Down Expand Up @@ -61,6 +70,7 @@ function testFFT2d(N=32,shift=true;arrayType = Array)
end
xop = arrayType(x)
D1 = FFTOp(ComplexF64, shape=(N,N), shift=shift, S = typeof(ComplexF64.(xop)))
check_storage(D1, arrayType)

idx = CartesianIndices((N,N))[collect(1:N^2)]
D2 = 1.0/N*[ exp(-2*pi*im*((idx[j][1]-1)*(idx[k][1]-1)+(idx[j][2]-1)*(idx[k][2]-1))/N) for j=1:N^2, k=1:N^2 ]
Expand Down Expand Up @@ -89,6 +99,7 @@ function testWeighting(N=512;arrayType = Array)
x1 = rand(N)
weights = rand(N)
W = WeightingOp(arrayType(weights))
check_storage(W, arrayType)
y1 = W*arrayType(x1)
y = weights .* x1

Expand All @@ -106,6 +117,7 @@ function testGradOp1d(N=512;arrayType = Array)
x = rand(N)
xop = arrayType(x)
G = GradientOp(eltype(x); shape=size(x), S = typeof(xop))
check_storage(G, arrayType)
G0 = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:]

y = Array(G*xop)
Expand All @@ -122,6 +134,7 @@ function testGradOp2d(N=64;arrayType = Array)
x = repeat(1:N,1,N)
xop = arrayType(vec(x))
G = GradientOp(eltype(x); shape=size(x), S = typeof(xop))
check_storage(G, arrayType)
G_1d = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:]

y = Array(G*xop)
Expand All @@ -141,7 +154,9 @@ function testDirectionalGradOp(N=64;arrayType = Array)
x = rand(ComplexF64,N,N)
xop = arrayType(vec(x))
G1 = GradientOp(eltype(x); shape=size(x), dims=1, S = typeof(xop))
check_storage(G1, arrayType)
G2 = GradientOp(eltype(x); shape=size(x), dims=2, S = typeof(xop))
check_storage(G2, arrayType)
G_1d = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:]

y1 = Array(G1*xop)
Expand Down Expand Up @@ -175,6 +190,7 @@ function testSampling(N=64;arrayType = Array)
# index-based sampling
idx = shuffle(collect(1:N^2)[1:N*div(N,2)])
SOp = SamplingOp(ComplexF64, pattern=idx, shape=(N,N), S = typeof(xop))
check_storage(SOp, arrayType)
y = Array(SOp*xop)
x2 = Array(adjoint(SOp)*arrayType(y))
# mask-based sampling
Expand All @@ -195,6 +211,7 @@ function testWavelet(M=64,N=60;arrayType = Array)
x = rand(M,N)
xop = arrayType(vec(x))
WOp = WaveletOp(Float64, shape=(M,N), S = typeof(xop))
check_storage(WOp, arrayType)
# TODO comparison against wavelet?
x_wavelet = Array(WOp*xop)
x_reco = reshape( adjoint(WOp)*x_wavelet, M, N)
Expand All @@ -219,6 +236,7 @@ function testNFFT2d(N=16;arrayType = Array)
xop = arrayType(vec(x))
nodes = [(idx[d] - N÷2 - 1)./N for d=1:2, idx in vec(CartesianIndices((N,N)))]
F_nfft = NFFTOp(ComplexF64; shape=(N,N), nodes, S = typeof(xop))
check_storage(F_nfft, arrayType)

# test against FourierOperators
y = vec( ifftshift(reshape(F*vec(fftshift(x)),N,N)) )
Expand All @@ -233,13 +251,15 @@ function testNFFT2d(N=16;arrayType = Array)
# test AHA w/o Toeplitz
F_nfft.toeplitz = false
AHA = normalOperator(F_nfft)
check_storage(AHA, arrayType)
y_AHA_nfft = Array(AHA * xop)
y_AHA = F' * F * vec(x)
@test y_AHA ≈ y_AHA_nfft rtol = 1e-2

# test AHA with Toeplitz
F_nfft.toeplitz = true
AHA = normalOperator(F_nfft)
check_storage(AHA, arrayType)
y_AHA_nfft_1 = Array(AHA * xop)
y_AHA_nfft_2 = Array(adjoint(F_nfft) * F_nfft * xop)
y_AHA = F' * F * vec(x)
Expand Down Expand Up @@ -274,6 +294,7 @@ function testNFFT3d(N=12;arrayType = Array)
xop = arrayType(vec(x))
nodes = [(idx[d] - N÷2 - 1)./N for d=1:3, idx in vec(CartesianIndices((N,N,N)))]
F_nfft = NFFTOp(ComplexF64; shape=(N,N,N), nodes=nodes, S = typeof(xop))
check_storage(F_nfft, arrayType)

# test agains FourierOperators
y = vec( ifftshift(reshape(F*vec(fftshift(x)),N,N,N)) )
Expand All @@ -288,13 +309,15 @@ function testNFFT3d(N=12;arrayType = Array)
# test AHA w/o Toeplitz
F_nfft.toeplitz = false
AHA = normalOperator(F_nfft)
check_storage(AHA, arrayType)
y_AHA_nfft = Array(AHA * xop)
y_AHA = F' * F * vec(x)
@test y_AHA ≈ y_AHA_nfft rtol = 1e-2

# test AHA with Toeplitz
F_nfft.toeplitz = true
AHA = normalOperator(F_nfft)
check_storage(AHA, arrayType)
y_AHA_nfft_1 = Array(AHA * xop)
y_AHA_nfft_2 = Array(adjoint(F_nfft) * F_nfft * xop)
y_AHA = F' * F * vec(x)
Expand All @@ -315,9 +338,13 @@ function testDiagOp(N=32,K=2;arrayType = Array,scheduler = DynamicScheduler())

blocks = [block for k = 1:K]
op1 = DiagOp(blocks; scheduler = scheduler)
check_storage(op1, arrayType)
op2 = DiagOp(blocks...; scheduler = scheduler)
check_storage(op2, arrayType)
op3 = DiagOp(block, K; scheduler = scheduler)
check_storage(op3, arrayType)
op4 = DiagOp(@view blocks[1:K]; scheduler = scheduler)
check_storage(op4, arrayType)


# Operations
Expand Down Expand Up @@ -406,6 +433,7 @@ function testRadonOp(N=32;arrayType = Array)
angles = collect(range(0, pi, 100))

op = RadonOp(eltype(x); shape = (N, N), angles, geometry = geom, S = typeof(xop))
check_storage(op, arrayType)

y = Array(radon(x, angles; geometry = geom))
y1 = reshape(Array(op * xop), size(y)...)
Expand Down
Loading