diff --git a/ext/LinearOperatorFFTWExt/DCTOp.jl b/ext/LinearOperatorFFTWExt/DCTOp.jl index 7f2e321..13f8212 100644 --- a/ext/LinearOperatorFFTWExt/DCTOp.jl +++ b/ext/LinearOperatorFFTWExt/DCTOp.jl @@ -29,7 +29,7 @@ returns a `DCTOpImpl <: AbstractLinearOperator` which performs a DCT on a given * `shape::Tuple` - size of the array to transform * `dcttype` - type of DCT (currently `2` and `4` are supported) """ -function LinearOperatorCollection.DCTOp(T::Type; shape::Tuple, S = Array{T}, dcttype=2) +function LinearOperatorCollection.DCTOp(T::Type; shape::Tuple, S = Vector{T}, dcttype=2) tmp=similar(S(undef, 0), shape...) if dcttype == 2 diff --git a/ext/LinearOperatorFFTWExt/DSTOp.jl b/ext/LinearOperatorFFTWExt/DSTOp.jl index 14fb431..fccd948 100644 --- a/ext/LinearOperatorFFTWExt/DSTOp.jl +++ b/ext/LinearOperatorFFTWExt/DSTOp.jl @@ -28,7 +28,7 @@ returns a `LinearOperator` which performs a DST on a given input array. * `T::Type` - type of the array to transform * `shape::Tuple` - size of the array to transform """ -function LinearOperatorCollection.DSTOp(T::Type; shape::Tuple, S = Array{T}) +function LinearOperatorCollection.DSTOp(T::Type; shape::Tuple, S = Vector{T}) tmp=similar(S(undef, 0), shape...) plan = FFTW.plan_r2r!(tmp,FFTW.RODFT10) diff --git a/ext/LinearOperatorFFTWExt/FFTOp.jl b/ext/LinearOperatorFFTWExt/FFTOp.jl index 96946fc..840bc77 100644 --- a/ext/LinearOperatorFFTWExt/FFTOp.jl +++ b/ext/LinearOperatorFFTWExt/FFTOp.jl @@ -34,7 +34,7 @@ returns an operator which performs an FFT on Arrays of type T * (`S = Vector{T}`) - type of temporary vector, change to use on GPU * (`kwargs...`) - keyword arguments given to fft plan """ -function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D +function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Vector{Complex{real(T)}}, kwargs...) where D tmpVec = similar(S(undef, 0), shape...) plan = plan_fft!(tmpVec; kwargs...) diff --git a/ext/LinearOperatorNFFTExt/NFFTOp.jl b/ext/LinearOperatorNFFTExt/NFFTOp.jl index e382b8e..97f3917 100644 --- a/ext/LinearOperatorNFFTExt/NFFTOp.jl +++ b/ext/LinearOperatorNFFTExt/NFFTOp.jl @@ -101,7 +101,7 @@ end LinearOperators.storage_type(::NFFTToeplitzNormalOp{T, vecT}) where {T, vecT} = vecT -function LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1::matT, xL2::matT) where {T, D, matT <: AbstractArray{T, D}} +function LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1::matT, xL2::matT, S = Vector{T}) where {T, D, matT <: AbstractArray{T, D}} function produ!(y, shape, fftplan, ifftplan, λ, xL1, xL2, x) xL1 .= 0 @@ -120,12 +120,12 @@ function LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftpl , (res,x) -> produ!(res, shape, fftplan, ifftplan, λ, xL1, xL2, x) , nothing , nothing - , 0, 0, 0, T[], T[] + , 0, 0, 0, S(undef, 0), S(undef, 0) , shape, W, fftplan, ifftplan, λ, xL1, xL2) end # TODO: use vecT for toeplitz op -function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothing; kwargs...) where {T} +function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothing; S = LinearOperators.storage_type(nfft), kwargs...) where {T} shape = size_in(nfft.plan) tmpVec = similar(nfft.Mv, (2 .* shape)...) @@ -156,7 +156,7 @@ function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothin xL1 = tmpVec xL2 = similar(xL1) - return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2) + return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2, S) end function LinearOperatorCollection.normalOperator(S::NFFTOpImpl{T}, W = nothing; copyOpsFn = copy, kwargs...) where T diff --git a/ext/LinearOperatorNonuniformFFTsExt/LinearOperatorNonuniformFFTsExt.jl b/ext/LinearOperatorNonuniformFFTsExt/LinearOperatorNonuniformFFTsExt.jl index b5f491e..cd33f6b 100644 --- a/ext/LinearOperatorNonuniformFFTsExt/LinearOperatorNonuniformFFTsExt.jl +++ b/ext/LinearOperatorNonuniformFFTsExt/LinearOperatorNonuniformFFTsExt.jl @@ -2,7 +2,7 @@ module LinearOperatorNonuniformFFTsExt using LinearOperatorCollection, AbstractNFFTs, NonuniformFFTs, NonuniformFFTs.Kernels, FFTW -function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T, P}, W=nothing; kwargs...) where {T, P <: NonuniformFFTs.NFFTPlan} +function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T, P}, W=nothing; S = LinearOperators.storage_type(nfft), kwargs...) where {T, P <: NonuniformFFTs.NFFTPlan} shape = size_in(nfft.plan) tmpVec = similar(nfft.Mv, (2 .* shape)...) @@ -39,7 +39,7 @@ function LinearOperatorCollection.NFFTToeplitzNormalOp(nfft::NFFTOp{T, P}, W=not xL1 = tmpVec xL2 = similar(xL1) - return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2) + return LinearOperatorCollection.NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2, S) end diff --git a/test/testOperators.jl b/test/testOperators.jl index 86bdff1..bd15ca8 100644 --- a/test/testOperators.jl +++ b/test/testOperators.jl @@ -1,3 +1,10 @@ +function check_storage(op, arrayType = AbstractArray) + S = LinearOperators.storage_type(op) + @test isconcretetype(S) + @test S <: AbstractVector + @test S <: arrayType +end + function testDCT1d(N=32;arrayType = Array) Random.seed!(1235) x = zeros(ComplexF64, N^2) @@ -5,6 +12,7 @@ function testDCT1d(N=32;arrayType = Array) x .+= rand()*cos.(rand(1:N^2)*collect(1:N^2)) .+ 1im*rand()*cos.(rand(1:N^2)*collect(1:N^2)) end D1 = DCTOp(ComplexF64, shape=(N^2,), dcttype=2) + check_storage(D1, arrayType) D2 = sqrt(2/N^2)*[cos(pi/(N^2)*j*(k+0.5)) for j=0:N^2-1,k=0:N^2-1] D2[1,:] .*= 1/sqrt(2) D3 = DCTOp(ComplexF64, shape=(N^2,), dcttype=4) @@ -34,6 +42,7 @@ function testFFT1d(N=32,shift=true;arrayType = Array) end xop = arrayType(x) D1 = FFTOp(ComplexF64, shape=(N^2,), shift=shift, S = typeof(ComplexF64.(xop))) + check_storage(D1, arrayType) D2 = 1.0/N*[exp(-2*pi*im*j*k/N^2) for j=0:N^2-1,k=0:N^2-1] y1 = Array(D1*xop) @@ -61,6 +70,7 @@ function testFFT2d(N=32,shift=true;arrayType = Array) end xop = arrayType(x) D1 = FFTOp(ComplexF64, shape=(N,N), shift=shift, S = typeof(ComplexF64.(xop))) + check_storage(D1, arrayType) idx = CartesianIndices((N,N))[collect(1:N^2)] D2 = 1.0/N*[ exp(-2*pi*im*((idx[j][1]-1)*(idx[k][1]-1)+(idx[j][2]-1)*(idx[k][2]-1))/N) for j=1:N^2, k=1:N^2 ] @@ -89,6 +99,7 @@ function testWeighting(N=512;arrayType = Array) x1 = rand(N) weights = rand(N) W = WeightingOp(arrayType(weights)) + check_storage(W, arrayType) y1 = W*arrayType(x1) y = weights .* x1 @@ -106,6 +117,7 @@ function testGradOp1d(N=512;arrayType = Array) x = rand(N) xop = arrayType(x) G = GradientOp(eltype(x); shape=size(x), S = typeof(xop)) + check_storage(G, arrayType) G0 = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:] y = Array(G*xop) @@ -122,6 +134,7 @@ function testGradOp2d(N=64;arrayType = Array) x = repeat(1:N,1,N) xop = arrayType(vec(x)) G = GradientOp(eltype(x); shape=size(x), S = typeof(xop)) + check_storage(G, arrayType) G_1d = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:] y = Array(G*xop) @@ -141,7 +154,9 @@ function testDirectionalGradOp(N=64;arrayType = Array) x = rand(ComplexF64,N,N) xop = arrayType(vec(x)) G1 = GradientOp(eltype(x); shape=size(x), dims=1, S = typeof(xop)) + check_storage(G1, arrayType) G2 = GradientOp(eltype(x); shape=size(x), dims=2, S = typeof(xop)) + check_storage(G2, arrayType) G_1d = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:] y1 = Array(G1*xop) @@ -175,6 +190,7 @@ function testSampling(N=64;arrayType = Array) # index-based sampling idx = shuffle(collect(1:N^2)[1:N*div(N,2)]) SOp = SamplingOp(ComplexF64, pattern=idx, shape=(N,N), S = typeof(xop)) + check_storage(SOp, arrayType) y = Array(SOp*xop) x2 = Array(adjoint(SOp)*arrayType(y)) # mask-based sampling @@ -195,6 +211,7 @@ function testWavelet(M=64,N=60;arrayType = Array) x = rand(M,N) xop = arrayType(vec(x)) WOp = WaveletOp(Float64, shape=(M,N), S = typeof(xop)) + check_storage(WOp, arrayType) # TODO comparison against wavelet? x_wavelet = Array(WOp*xop) x_reco = reshape( adjoint(WOp)*x_wavelet, M, N) @@ -219,6 +236,7 @@ function testNFFT2d(N=16;arrayType = Array) xop = arrayType(vec(x)) nodes = [(idx[d] - N÷2 - 1)./N for d=1:2, idx in vec(CartesianIndices((N,N)))] F_nfft = NFFTOp(ComplexF64; shape=(N,N), nodes, S = typeof(xop)) + check_storage(F_nfft, arrayType) # test against FourierOperators y = vec( ifftshift(reshape(F*vec(fftshift(x)),N,N)) ) @@ -233,6 +251,7 @@ function testNFFT2d(N=16;arrayType = Array) # test AHA w/o Toeplitz F_nfft.toeplitz = false AHA = normalOperator(F_nfft) + check_storage(AHA, arrayType) y_AHA_nfft = Array(AHA * xop) y_AHA = F' * F * vec(x) @test y_AHA ≈ y_AHA_nfft rtol = 1e-2 @@ -240,6 +259,7 @@ function testNFFT2d(N=16;arrayType = Array) # test AHA with Toeplitz F_nfft.toeplitz = true AHA = normalOperator(F_nfft) + check_storage(AHA, arrayType) y_AHA_nfft_1 = Array(AHA * xop) y_AHA_nfft_2 = Array(adjoint(F_nfft) * F_nfft * xop) y_AHA = F' * F * vec(x) @@ -274,6 +294,7 @@ function testNFFT3d(N=12;arrayType = Array) xop = arrayType(vec(x)) nodes = [(idx[d] - N÷2 - 1)./N for d=1:3, idx in vec(CartesianIndices((N,N,N)))] F_nfft = NFFTOp(ComplexF64; shape=(N,N,N), nodes=nodes, S = typeof(xop)) + check_storage(F_nfft, arrayType) # test agains FourierOperators y = vec( ifftshift(reshape(F*vec(fftshift(x)),N,N,N)) ) @@ -288,6 +309,7 @@ function testNFFT3d(N=12;arrayType = Array) # test AHA w/o Toeplitz F_nfft.toeplitz = false AHA = normalOperator(F_nfft) + check_storage(AHA, arrayType) y_AHA_nfft = Array(AHA * xop) y_AHA = F' * F * vec(x) @test y_AHA ≈ y_AHA_nfft rtol = 1e-2 @@ -295,6 +317,7 @@ function testNFFT3d(N=12;arrayType = Array) # test AHA with Toeplitz F_nfft.toeplitz = true AHA = normalOperator(F_nfft) + check_storage(AHA, arrayType) y_AHA_nfft_1 = Array(AHA * xop) y_AHA_nfft_2 = Array(adjoint(F_nfft) * F_nfft * xop) y_AHA = F' * F * vec(x) @@ -315,9 +338,13 @@ function testDiagOp(N=32,K=2;arrayType = Array,scheduler = DynamicScheduler()) blocks = [block for k = 1:K] op1 = DiagOp(blocks; scheduler = scheduler) + check_storage(op1, arrayType) op2 = DiagOp(blocks...; scheduler = scheduler) + check_storage(op2, arrayType) op3 = DiagOp(block, K; scheduler = scheduler) + check_storage(op3, arrayType) op4 = DiagOp(@view blocks[1:K]; scheduler = scheduler) + check_storage(op4, arrayType) # Operations @@ -406,6 +433,7 @@ function testRadonOp(N=32;arrayType = Array) angles = collect(range(0, pi, 100)) op = RadonOp(eltype(x); shape = (N, N), angles, geometry = geom, S = typeof(xop)) + check_storage(op, arrayType) y = Array(radon(x, angles; geometry = geom)) y1 = reshape(Array(op * xop), size(y)...)