-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_train.py
More file actions
86 lines (67 loc) · 3.06 KB
/
main_train.py
File metadata and controls
86 lines (67 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# import python modules
import tensorflow as tf
import networkx as nx
import pickle as pkl
import numpy as np
import argparse
import os
import scipy
# import custom modules
from models.gcn import *
from models.gin import *
from models.sgc import *
from models.gat import *
from models.gcn2 import *
from models.gsage import *
from models.faconv import *
from generator import generator
# Model keys accepted by --mtype; each one has a matching branch in build_model().
MODEL_TYPES = ('gcn', 'sgc', 'gin', 'gat', 'gsage', 'gcnii', 'fagcn')


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``lr``, ``hidden``, ``gpu``,
        ``idx``, ``mtype``, ``data_path`` and ``save_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float,
                        help='learning rate given to the Adam optimizer')
    parser.add_argument('--hidden', type=int,
                        help='value passed to the model as num_layers')
    # --gpu, --mtype and --save_path are required: leaving any of them at the
    # implicit None default cannot work (None is not a list index, not a known
    # model key, and not a usable path), and for --save_path the failure would
    # only surface after training has already finished.
    parser.add_argument('--gpu', type=int, required=True,
                        help='index into the list of physical GPUs')
    parser.add_argument('--idx', type=int,
                        help='run index used in the results file name')
    parser.add_argument('--mtype', type=str, required=True, choices=MODEL_TYPES,
                        help='which model to train')
    parser.add_argument('--data_path', type=str,
                        help='path handed to the data generators')
    parser.add_argument('--save_path', type=str, required=True,
                        help='directory the results pickle is written to')
    return parser.parse_args(argv)


def build_model(mtype, hidden, optimizer, loss):
    """Instantiate the model selected by ``mtype``.

    Args:
        mtype: One of ``MODEL_TYPES``.
        hidden: Forwarded to the model constructor as ``num_layers``.
        optimizer: Optimizer object forwarded to the model constructor.
        loss: Loss object forwarded to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``mtype`` is not a supported model key.
    """
    # Keyword arguments every model constructor receives.
    shared = dict(num_layers=hidden, n_classes=1, bias=True, dropout=0,
                  optimizer=optimizer, loss=loss)
    # GCN, SGC and GIN are additionally given g_dropout; the others are not.
    if mtype == 'gcn':
        return GCN(g_dropout=0, **shared)
    if mtype == 'sgc':
        return SGC(g_dropout=0, **shared)
    if mtype == 'gin':
        return GIN(g_dropout=0, **shared)
    if mtype == 'gat':
        return GAT(**shared)
    if mtype == 'gsage':
        return GSAGE(**shared)
    if mtype == 'gcnii':
        return GCNII(**shared)
    if mtype == 'fagcn':
        return FAGCN(**shared)
    raise NotImplementedError(
        f'unsupported model type {mtype!r}; expected one of {", ".join(MODEL_TYPES)}'
    )


def main(argv=None):
    """Train the selected model, evaluate it and pickle the results.

    Writes ``<save_path>/<mtype>_results_<idx>.pkl`` containing the tuple
    ``(performance, args)``.

    Args:
        argv: Optional list of argument strings; see ``parse_args``.

    Raises:
        IndexError: If ``--gpu`` does not index an available physical GPU.
    """
    args = parse_args(argv)

    # Create the output directory up front so a bad path fails before the
    # training run instead of after it. makedirs(exist_ok=True) also handles
    # nested paths and concurrent runs that share the same save_path.
    os.makedirs(args.save_path, exist_ok=True)

    physical_devices = tf.config.list_physical_devices('GPU')
    try:
        device = physical_devices[args.gpu]
    except IndexError:
        raise IndexError(
            f'--gpu {args.gpu} is out of range: '
            f'{len(physical_devices)} physical GPU(s) found'
        ) from None
    tf.config.set_visible_devices(device, 'GPU')  # mld4 devices range from 0 - 7

    train_gen = generator(split='train', data_path=args.data_path)
    train_dset = train_gen.tf_generator(train_gen.data_generator)
    valid_gen = generator(split='valid', data_path=args.data_path)
    valid_dset = valid_gen.tf_generator(valid_gen.data_generator)
    test_gen = generator(split='test', data_path=args.data_path)
    test_dset = test_gen.tf_generator(test_gen.data_generator)

    optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
    loss = tf.keras.losses.BinaryCrossentropy()
    model = build_model(args.mtype, args.hidden, optimizer, loss)

    model.train(train_dset, valid_dset, epochs=100, early_stopping=5)
    # The fifth return value (predictions) is not stored in the results file.
    test_auc, test_err, aucs, val_auc, _ = model.evaluate(valid_dset, test_dset)
    performance = {'test_auc': test_auc, 'test_err': test_err, 'val_auc': val_auc, 'aucs': aucs}

    results_file = os.path.join(args.save_path, f'{args.mtype}_results_{args.idx}.pkl')
    with open(results_file, 'wb') as file:
        pkl.dump((performance, args), file)


if __name__ == '__main__':
    main()