Skip to content

Commit 238c84b

Browse files
test: add tests for ngrc
1 parent 95654cc commit 238c84b

File tree

7 files changed

+124
-30
lines changed

7 files changed

+124
-30
lines changed
Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
using Test
2+
using Random
3+
using ReservoirComputing
4+
using Static
5+
using LinearAlgebra
6+
17
@testset "DelayESN" begin
28
rng = MersenneTwister(123)
39

@@ -10,12 +16,12 @@
1016
desn = DelayESN(in_dims, res_dims, out_dims;
1117
num_delays = num_delays, stride = 1)
1218

13-
@test desn isa ReservoirComputer
19+
@test desn isa DelayESN
1420

1521
reservoir = desn.reservoir
1622
@test reservoir isa StatefulLayer
1723

18-
mods = desn.state_modifiers
24+
mods = desn.states_modifiers
1925
@test mods isa Tuple
2026
@test !isempty(mods)
2127
@test first(mods) isa DelayLayer
@@ -31,28 +37,6 @@
3137
@test Int(ro.out_dims) == out_dims
3238
end
3339

34-
@testset "setup and forward pass shapes" begin
35-
in_dims = 4
36-
res_dims = 10
37-
out_dims = 3
38-
num_delays = 1
39-
40-
desn = DelayESN(in_dims, res_dims, out_dims;
41-
num_delays = num_delays, stride = 1)
42-
43-
ps, st = setup(rng, desn)
44-
45-
x = rand(rng, Float32, in_dims)
46-
y, st2 = desn(x, ps, st)
47-
@test size(y) == (out_dims,)
48-
49-
X = rand(rng, Float32, in_dims, 7)
50-
Y, st3 = desn(X, ps, st2)
51-
@test size(Y) == (out_dims, 7)
52-
53-
@test propertynames(st3) == propertynames(st2)
54-
end
55-
5640
@testset "num_delays changes readout input dim" begin
5741
in_dims = 2
5842
res_dims = 6
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Test
22
using Random
33
using ReservoirComputing
44
using Static
5+
using LinearAlgebra
56

67
const _I32 = (m, n) -> Matrix{Float32}(I, m, n)
78
const _Z32 = m -> zeros(Float32, m)
@@ -35,11 +36,11 @@ end
3536
state_modifiers = (),
3637
readout_activation = identity)
3738
ps, st = setup(rng, hesn)
38-
@test haskey(ps, :cell) && haskey(ps, :knowledge_model) &&
39+
@test haskey(ps, :reservoir) && haskey(ps, :knowledge_model) &&
3940
haskey(ps, :states_modifiers) && haskey(ps, :readout)
40-
@test size(ps.cell.input_matrix) == (res_dims, in_dims + km_dims)
41+
@test size(ps.reservoir.input_matrix) == (res_dims, in_dims + km_dims)
4142
@test size(ps.readout.weight) == (out_dims, res_dims + km_dims)
42-
@test haskey(st, :cell) && haskey(st, :knowledge_model) &&
43+
@test haskey(st, :reservoir) && haskey(st, :knowledge_model) &&
4344
haskey(st, :states_modifiers) && haskey(st, :readout)
4445
end
4546

test/models/test_ngrc.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
using Random
2+
using ReservoirComputing
3+
using LuxCore
4+
using Static
5+
6+
@testset "NGRC" begin
7+
rng = MersenneTwister(1234)
8+
9+
const_feature = x -> Float32[1.0]
10+
square_feature = x -> x .^ 2
11+
12+
@testset "constructor & composition" begin
13+
ngrc = NGRC(3, 2; num_delays = 1, stride = 2,
14+
features = (const_feature, square_feature), include_input = True(),
15+
init_delay = zeros32, readout_activation = tanh)
16+
17+
@test ngrc isa NGRC
18+
@test ngrc.reservoir isa DelayLayer
19+
@test ngrc.readout isa LinearReadout
20+
21+
dl = ngrc.reservoir
22+
@test dl.in_dims == 3
23+
@test dl.num_delays == 1
24+
@test dl.stride == 2
25+
26+
@test !isempty(ngrc.states_modifiers)
27+
first_mod = getfield(ngrc.states_modifiers, 1)
28+
@test first_mod isa NonlinearFeaturesLayer
29+
end
30+
31+
@testset "initialparameters & initialstates" begin
32+
ngrc = NGRC(3, 2; num_delays = 1, features = (square_feature,),
33+
include_input = True())
34+
35+
ps = initialparameters(rng, ngrc)
36+
st = initialstates(rng, ngrc)
37+
38+
@test hasproperty(ps, :reservoir)
39+
@test hasproperty(ps, :states_modifiers)
40+
@test hasproperty(ps, :readout)
41+
42+
@test hasproperty(st, :reservoir)
43+
@test hasproperty(st, :states_modifiers)
44+
@test hasproperty(st, :readout)
45+
46+
@test ps.readout.weight isa AbstractArray
47+
end
48+
49+
@testset "forward pass: vector input" begin
50+
ngrc = NGRC(3, 2; num_delays = 1, features = (square_feature,),
51+
include_input = True())
52+
53+
ps, st = setup(rng, ngrc)
54+
55+
x = rand(Float32, 3)
56+
y, st2 = ngrc(x, ps, st)
57+
58+
@test size(y) == (2,)
59+
@test propertynames(st2) == propertynames(st)
60+
end
61+
62+
@testset "forward pass: matrix input via collectstates" begin
63+
ngrc = NGRC(3, 2; num_delays = 1, features = (square_feature,),
64+
include_input = True())
65+
66+
ps, st = setup(rng, ngrc)
67+
68+
X = rand(Float32, 3, 10)
69+
states, st2 = collectstates(ngrc, X, ps, st)
70+
71+
@test size(states, 2) == size(X, 2)
72+
@test size(states, 1) == ngrc.readout.in_dims
73+
74+
@test propertynames(st2) == propertynames(st)
75+
end
76+
77+
@testset "simple 1D linear system learning" begin
78+
# x_{t+1} = a * x_t with a = 0.8
79+
a = 0.8f0
80+
T = 200
81+
x = zeros(Float32, T)
82+
x[1] = 1.0f0
83+
for t in 1:(T - 1)
84+
x[t + 1] = a * x[t]
85+
end
86+
87+
X_in = reshape(x[1:(end - 1)], 1, :)
88+
Y_out = reshape(x[2:end], 1, :)
89+
90+
ngrc = NGRC(1, 1; num_delays = 0, stride = 1, features = (),
91+
include_input = True(), ro_dims = 1)
92+
93+
ps, st = setup(rng, ngrc)
94+
95+
ps_tr, st_tr = train!(ngrc, X_in, Y_out, ps, st;
96+
train_method = StandardRidge(1e-6))
97+
98+
@test hasproperty(ps_tr, :readout)
99+
w = ps_tr.readout.weight
100+
@test size(w) == (1, 1)
101+
@test isapprox(w[1, 1], a; atol = 0.05)
102+
end
103+
end

test/runtests.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@ end
1313
end
1414

1515
@testset "Echo State Networks" begin
16-
@safetestset "ESN Initializers" include("esn/test_inits.jl")
17-
@safetestset "ESN model" include("esn/test_esn.jl")
18-
@safetestset "DeepESN model" include("esn/test_esn_deep.jl")
16+
@safetestset "ESN Initializers" include("test_inits.jl")
17+
@safetestset "ESN model" include("models/test_esn.jl")
18+
@safetestset "DeepESN model" include("models/test_esn_deep.jl")
19+
@safetestset "DelayESN model" include("models/test_esn_delay.jl")
20+
@safetestset "HybridESN model" include("models/test_esn_hybrid.jl")
21+
end
22+
23+
@testset "Next Generation Reservoir Computing" begin
24+
@safetestset "NGRC model" include("models/test_ngrc.jl")
1925
end

0 commit comments

Comments
 (0)