-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_dgr_frequency.py
More file actions
71 lines (51 loc) · 2.4 KB
/
_dgr_frequency.py
File metadata and controls
71 lines (51 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from machinable import Interface, get
from matplotlib import pyplot as plt
class DgrateFrequency(Interface):
    """Plot model output traces side by side with their power spectra."""

    def launch(self):
        """Fetch the ``interface.sopt_dgrate`` experiment and plot its model.

        Draws a 5x2 grid of axes: the left column shows the HIPP, BC, MC, GC
        and PP signals over ``pars["range_t"]``; the right column shows what
        ``compute_PSD`` returns for each signal, titled with the frequency at
        the returned peak index.  Returns ``None`` without plotting when
        ``future()`` yields a falsy value (presumably the experiment result is
        not available yet -- confirm against machinable's semantics).
        """
        request = {"dopt_params": {"n_epochs": 10}}
        experiment = get("interface.sopt_dgrate", request).future()
        if not experiment:
            return

        # NOTE(review): this value is never used below; the lookup is kept
        # because get_best() is opaque here and may validate or raise.
        selected_solution = experiment.get_best()["x"][1]  # noqa: F841

        network_model = experiment.get_model()
        output = network_model.run()
        params = network_model.pars

        gc_trace = output["g"]
        bc_trace = output["b"]
        mc_trace = output["m"]
        hipp_trace = output["h"]

        fig, axs = plt.subplots(5, 2, figsize=(4, 5))

        # One entry per grid row, top to bottom: (axis label, signal).
        rows = (
            ("HIPP", hipp_trace),
            ("BC", bc_trace),
            ("MC", mc_trace),
            ("GC", gc_trace),
            ("PP", params["PP"]),
        )
        bottom = len(rows) - 1

        # Left column: each signal plotted against the time vector.
        for row, (label, signal) in enumerate(rows):
            trace_ax = axs[row, 0]
            trace_ax.plot(params["range_t"], signal, color="0.5", label=label)
            trace_ax.set_ylabel(label)
        axs[bottom, 0].set_xlabel("Time (ms)")

        # Right column: spectrum of each signal, in the same row order.
        title_template = "PSD (peak: %.3g Hz)"
        for row, (_, signal) in enumerate(rows):
            freqs, psd, peak_index = network_model.compute_PSD(signal)
            psd_ax = axs[row, 1]
            psd_ax.plot(freqs, psd, linewidth=3)
            psd_ax.set_title(title_template % freqs[peak_index])
            if row == 0:
                # Only the top spectrum panel carries the y-axis label.
                psd_ax.set_ylabel("Power Spectral Density (dB/Hz)")
        axs[bottom, 1].set_xlabel("Frequency (Hz)")

        fig.tight_layout()
        fig.align_ylabels()
        plt.show()