-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrnns.py
More file actions
54 lines (42 loc) · 1.85 KB
/
rnns.py
File metadata and controls
54 lines (42 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
class LSTMModel(nn.Module):
    """Stacked LSTM built from individual ``nn.LSTMCell`` layers.

    Cell 0 maps ``input_dim -> hidden_dim``; every later cell maps
    ``hidden_dim -> hidden_dim``. ``forward`` advances the whole stack by
    a single time step.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, n_nodes):
        super(LSTMModel, self).__init__()
        # Construct the input-facing cell first, then the hidden-to-hidden
        # cells, matching the original construction order so seeded weight
        # initialization is reproducible.
        self.cells = nn.ModuleList(
            [nn.LSTMCell(input_dim if idx == 0 else hidden_dim, hidden_dim)
             for idx in range(n_layers)]
        )
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.n_nodes = n_nodes

    def forward(self, inps, h0_list):
        """Run one step through every layer.

        Args:
            inps: input batch fed to the bottom cell.
            h0_list: per-layer ``(h, c)`` state tuples.

        Returns:
            List of updated ``(h, c)`` tuples, one per layer.
        """
        h, c = self.cells[0](inps, h0_list[0])
        states = [(h, c)]
        # Each upper layer consumes the hidden state of the layer below it.
        for cell, prev_state in zip(self.cells[1:], h0_list[1:]):
            h, c = cell(h, prev_state)
            states.append((h, c))
        return states

    def init__hidd(self, device):
        """Return all-ones initial ``(h, c)`` pairs, one per layer."""
        def ones():
            return torch.ones(self.n_nodes, self.hidden_dim).to(device)
        return [(ones(), ones()) for _ in range(self.n_layers)]
class GRUModel(nn.Module):
    """Stacked GRU assembled from per-layer ``nn.GRUCell`` modules.

    Cell 0 maps ``input_dim -> hidden_dim``; the remaining cells map
    ``hidden_dim -> hidden_dim``. ``forward`` advances the stack one step.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, n_nodes) -> None:
        super(GRUModel, self).__init__()
        # Input-facing cell is built first, then the hidden-to-hidden cells;
        # this matches the original order so seeded init is reproducible.
        self.cells = nn.ModuleList(
            [nn.GRUCell(input_dim if idx == 0 else hidden_dim, hidden_dim)
             for idx in range(n_layers)]
        )
        self.n_layers = n_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_nodes = n_nodes

    def forward(self, inps, h0_list):
        """Run one step through every layer.

        Args:
            inps: input batch fed to the bottom cell.
            h0_list: per-layer hidden-state tensors.

        Returns:
            List of updated hidden states, one per layer.
        """
        h = self.cells[0](inps, h0_list[0])
        states = [h]
        # Each upper layer consumes the hidden state of the layer below it.
        for cell, prev_state in zip(self.cells[1:], h0_list[1:]):
            h = cell(h, prev_state)
            states.append(h)
        return states

    def init__hidd(self, device):
        """Return all-ones initial hidden states, one per layer."""
        return [torch.ones(self.n_nodes, self.hidden_dim).to(device)
                for _ in range(self.n_layers)]