diff --git a/src/shared.jl b/src/shared.jl index 3ee8aaa..71098b0 100644 --- a/src/shared.jl +++ b/src/shared.jl @@ -30,7 +30,11 @@ function Base.show(io::IO, st::stFlux{Dim,Dep}) where {Dim,Dep} layers = st.mainChain.layers σ = st.settings[:σ] Nd = ndims(st) - nFilters = [length(layers[i].weight) - 1 for i = 1:3:(3*Dep)] + if Nd == 1 + nFilters = [length(layers[i].weight) - 1 for i = 1:3:(3*Dep)] + else # 2D case + nFilters = [size(layers[i].weight)[end] - 1 for i = 1:3:(3*Dep)] + end batchSize = getBatchSize(layers[1]) print(io, "stFlux{Nd=$(Nd), m=$(Dep), filters=$(nFilters), σ = " * "$(σ), batchSize = $(batchSize), normalize = $(st.normalize)}") diff --git a/src/transform.jl b/src/transform.jl index ee2d64b..64ce56c 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -101,7 +101,6 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2, end poolBy = makeTuple(m, poolBy) # Defined in ~/shared.jl argsToEach = processArgs(m + 1, kwargs) # also in ~/shared.jl - listOfSizes = [(inputSize..., ntuple(i -> 1, max(i - 1, 0))...) for i = 0:m] interstitial = Array{Any,1}(undef, 3 * (m + 1) - 2) #= `interstitial` is an array of 3 functions per layer: @@ -111,7 +110,13 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2, # first transform interstitial[3*i-2] = dispatchLayer(listOfSizes[i], Val(Nd); σ=identity, argsToEach[i]...) - nFilters = length(interstitial[3*i-2].weight) + #################### TO UPDATE ###################### + if Nd == 1 + nFilters = length(interstitial[3*i-2].weight) + else + nFilters = size(interstitial[3 * i - 2].weight)[end] + end + #################### TO UPDATE ###################### pooledSize = poolSize(poolBy[i], listOfSizes[i][1:Nd]) # in ~/pool.jl #= then throw away the averaging (we'll pick it up in the actual transform) @@ -145,7 +150,6 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2, chacha = Chain(interstitial...) # Chain is from `Flux.jl` outputSizes = ([(map(poolSize, outputPool[ii], x[1:Nd])..., x[(Nd+1):end]...) for (ii, x) in enumerate(listOfSizes)]...,) - # record the settings used pretty kludgy settings = Dict(:outputPool => outputPool, :poolBy => poolBy, :σ => σ, :flatten => flatten, (argsToEach...)...) diff --git a/src/utilities.jl b/src/utilities.jl index 852d566..22ac36d 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -235,9 +235,11 @@ function size(st::stFlux) l = st.mainChain[1] if typeof(l.fftPlan) <: Tuple sz = l.fftPlan[2].sz + es = originalSize(sz[1:ndims(l.weight[1])], l.bc) else sz = l.fftPlan.sz + es = originalSize(sz[1:ndims(l.weight) - 1], l.bc) end - es = originalSize(sz[1:ndims(l.weight[1])], l.bc) + return es end diff --git a/test/fluxtests.jl b/test/fluxtests.jl index c1b7820..e986d77 100644 --- a/test/fluxtests.jl +++ b/test/fluxtests.jl @@ -99,6 +99,49 @@ @test res1[1:32*3, 1] ≈ reshape(res[0][:, :, 1], (32 * 3,)) end + nFilters = [1, 12, 12, 12] + @testset "2D basics" begin + n_init_channels=2 + batch_size = 2 + init = 10 .+ randn(64, 64, n_init_channels, batch_size); + sst = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,)) + res = sst(init) + @test length(res.output) == 2 + 1 # same + @test size(res.output[1]) == (32, 32, n_init_channels, batch_size) + @test minimum(abs.(res.output[1])) > 0 + @test size(res.output[2]) == (22, 22, n_init_channels * nFilters[2], 2) + @test minimum(abs.(res.output[2])) > 0 + @test size(res.output[3]) == (14, 14, nFilters[3], n_init_channels * nFilters[2], 2) + @test minimum(abs.(res.output[3])) > 0 + totalSize = 32^2 * n_init_channels + 22^2 * n_init_channels * nFilters[2] + 14^2 * nFilters[3] * n_init_channels * nFilters[2] + smooshed = ScatteringTransform.flatten(res) + @test size(smooshed) == (totalSize, 2) + sst1 = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,), flatten=true) + res1 = sst1(init) + @test res1 isa Array{Float32,2} + @test size(res1) == (totalSize, 2) + @test res1[1:32^2*n_init_channels, 1] ≈ reshape(res[0][:, :, :, 1], (32^2 * n_init_channels,)) + + sst = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,)) + res = sst(init) + # @test + @test length(res.output) == 2 + 1 # same + @test size(res.output[1]) == (32, 32, n_init_channels, batch_size) + @test minimum(abs.(res.output[1])) > 0 + @test size(res.output[2]) == (22, 22, n_init_channels * nFilters[2], 2) + @test minimum(abs.(res.output[2])) > 0 + @test size(res.output[3]) == (14, 14, nFilters[3], n_init_channels * nFilters[2], 2) + @test minimum(abs.(res.output[3])) > 0 + totalSize = 32^2 * n_init_channels + 22^2 * n_init_channels * nFilters[2] + 14^2 * nFilters[3] * n_init_channels * nFilters[2] + smooshed = ScatteringTransform.flatten(res) + @test size(smooshed) == (totalSize, 2) + sst1 = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,), flatten=true) + res1 = sst1(init) + @test res1 isa Array{Float32,2} + @test size(res1) == (totalSize, 2) + @test res1[1:32^2*n_init_channels, 1] ≈ reshape(res[0][:, :, :, 1], (32^2 * n_init_channels,)) + end + nFilters = [1, 10, 9] @testset "1D integer pooling" begin stEx = stFlux((131, 1, 1), 2, poolBy=3)