Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ function Base.show(io::IO, st::stFlux{Dim,Dep}) where {Dim,Dep}
layers = st.mainChain.layers
σ = st.settings[:σ]
Nd = ndims(st)
nFilters = [length(layers[i].weight) - 1 for i = 1:3:(3*Dep)]
if Nd == 1
nFilters = [length(layers[i].weight) - 1 for i = 1:3:(3*Dep)]
else # 2D case
nFilters = [size(layers[i].weight)[end] - 1 for i = 1:3:(3*Dep)]
end
batchSize = getBatchSize(layers[1])
print(io, "stFlux{Nd=$(Nd), m=$(Dep), filters=$(nFilters), σ = " *
"$(σ), batchSize = $(batchSize), normalize = $(st.normalize)}")
Expand Down
10 changes: 7 additions & 3 deletions src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2,
end
poolBy = makeTuple(m, poolBy) # Defined in ~/shared.jl
argsToEach = processArgs(m + 1, kwargs) # also in ~/shared.jl

listOfSizes = [(inputSize..., ntuple(i -> 1, max(i - 1, 0))...) for i = 0:m]
interstitial = Array{Any,1}(undef, 3 * (m + 1) - 2)
#= `interstitial` is an array of 3 functions per layer:
Expand All @@ -111,7 +110,13 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2,
# first transform
interstitial[3*i-2] = dispatchLayer(listOfSizes[i], Val(Nd); σ=identity,
argsToEach[i]...)
nFilters = length(interstitial[3*i-2].weight)
#################### TO UPDATE ######################
if Nd == 1
nFilters = length(interstitial[3*i-2].weight)
else
nFilters = size(interstitial[3 * i - 2].weight)[end]
end
#################### TO UPDATE ######################

pooledSize = poolSize(poolBy[i], listOfSizes[i][1:Nd]) # in ~/pool.jl
#= then throw away the averaging (we'll pick it up in the actual transform)
Expand Down Expand Up @@ -145,7 +150,6 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2,
chacha = Chain(interstitial...) # Chain is from `Flux.jl`
outputSizes = ([(map(poolSize, outputPool[ii], x[1:Nd])..., x[(Nd+1):end]...)
for (ii, x) in enumerate(listOfSizes)]...,)

# record the settings used pretty kludgy
settings = Dict(:outputPool => outputPool, :poolBy => poolBy, :σ => σ,
:flatten => flatten, (argsToEach...)...)
Expand Down
4 changes: 3 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,11 @@ function size(st::stFlux)
l = st.mainChain[1]
if typeof(l.fftPlan) <: Tuple
sz = l.fftPlan[2].sz
es = originalSize(sz[1:ndims(l.weight[1])], l.bc)
else
sz = l.fftPlan.sz
es = originalSize(sz[1:ndims(l.weight) - 1], l.bc)
end
es = originalSize(sz[1:ndims(l.weight[1])], l.bc)

return es
end
43 changes: 43 additions & 0 deletions test/fluxtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,49 @@
@test res1[1:32*3, 1] ≈ reshape(res[0][:, :, 1], (32 * 3,))
end

nFilters = [1, 12, 12, 12]
@testset "2D basics" begin
n_init_channels=2
batch_size = 2
init = 10 .+ randn(64, 64, n_init_channels, batch_size);
sst = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,))
res = sst(init)
@test length(res.output) == 2 + 1 # same
@test size(res.output[1]) == (32, 32, n_init_channels, batch_size)
@test minimum(abs.(res.output[1])) > 0
@test size(res.output[2]) == (22, 22, n_init_channels * nFilters[2], 2)
@test minimum(abs.(res.output[2])) > 0
@test size(res.output[3]) == (14, 14, nFilters[3], n_init_channels * nFilters[2], 2)
@test minimum(abs.(res.output[3])) > 0
totalSize = 32^2 * n_init_channels + 22^2 * n_init_channels * nFilters[2] + 14^2 * nFilters[3] * n_init_channels * nFilters[2]
smooshed = ScatteringTransform.flatten(res)
@test size(smooshed) == (totalSize, 2)
sst1 = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,), flatten=true)
res1 = sst1(init)
@test res1 isa Array{Float32,2}
@test size(res1) == (totalSize, 2)
@test res1[1:32^2*n_init_channels, 1] ≈ reshape(res[0][:, :, :, 1], (32^2 * n_init_channels,))

sst = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,))
res = sst(init)
# @test
@test length(res.output) == 2 + 1 # same
@test size(res.output[1]) == (32, 32, n_init_channels, batch_size)
@test minimum(abs.(res.output[1])) > 0
@test size(res.output[2]) == (22, 22, n_init_channels * nFilters[2], 2)
@test minimum(abs.(res.output[2])) > 0
@test size(res.output[3]) == (14, 14, nFilters[3], n_init_channels * nFilters[2], 2)
@test minimum(abs.(res.output[3])) > 0
totalSize = 32^2 * n_init_channels + 22^2 * n_init_channels * nFilters[2] + 14^2 * nFilters[3] * n_init_channels * nFilters[2]
smooshed = ScatteringTransform.flatten(res)
@test size(smooshed) == (totalSize, 2)
sst1 = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,), flatten=true)
res1 = sst1(init)
@test res1 isa Array{Float32,2}
@test size(res1) == (totalSize, 2)
@test res1[1:32^2*n_init_channels, 1] ≈ reshape(res[0][:, :, :, 1], (32^2 * n_init_channels,))
end

nFilters = [1, 10, 9]
@testset "1D integer pooling" begin
stEx = stFlux((131, 1, 1), 2, poolBy=3)
Expand Down
Loading