-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmake_hparams.py
More file actions
159 lines (121 loc) · 4.48 KB
/
make_hparams.py
File metadata and controls
159 lines (121 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import copy
import glob
import os
import re

import numpy as np
import yaml
def append_config_ckpt(config): # will mutate config
    """Fill in hyperparameters that are encoded in a checkpoint file name.

    The last '/'-separated component of ``config['checkpoint_pth']`` is
    scanned for ``<key>=<value>`` fragments, where ``<key>`` is one of the
    names listed in the regular expression below.  Each value is decoded as:

    * empty text            -> ``[]``
    * ``"a_b_c"``           -> list of numbers
    * all digits            -> ``int``
    * anything else         -> ``float``

    ``gnn_channels``, ``class_decoder_dims`` and ``formula_dims`` are always
    stored as lists, so a scalar value for them is wrapped in a list.

    Keys already present in ``config`` are never overwritten.

    Args:
        config: dict of hyperparameters; modified in place.

    Returns:
        The same ``config`` object.  If it has no ``'checkpoint_pth'`` entry
        a message is printed and it is returned unchanged.

    Raises:
        ValueError: if a value fragment is neither empty nor numeric
            (propagated from ``float``).
    """
    if 'checkpoint_pth' not in config:
        print('no chekcpoint found')
        return config

    ckpt_name = config['checkpoint_pth'].split('/')[-1]
    # The first two '_'-separated tokens of the file name are discarded
    # before looking for key=value fragments.
    encoded = "_".join(ckpt_name.split('_')[2:])

    # The capture group makes re.split keep the matched "key=" delimiters, so
    # the result alternates: <leading text>, key=, value, key=, value, ...
    # Anything before the first recognised key is the leading text and is
    # dropped by the [1:] slice.
    pieces = re.split(r'(class_decoder_dims=|final_embedding_dim=|gnn_channels=|gnn_hidden_dim=|formula_dims=|formula_nheads=|formula_num_layers=|formula_ff=)', encoded)
    pieces = pieces[1:]

    # Remove the "_" that separated a value from the next key.  endswith() is
    # used instead of indexing [-1] because a piece can be the empty string
    # (a name that ends in "key="), and indexing would raise IndexError.
    pieces = [piece[:-1] if piece.endswith("_") else piece for piece in pieces]

    for raw_key, raw_value in zip(pieces[::2], pieces[1::2]):
        key = raw_key[:-1]  # strip the trailing "="

        if raw_value == "":
            value = []
        elif "_" in raw_value:
            value = [
                int(token) if token.isdigit() else float(token)
                for token in raw_value.split("_")
            ]
        elif raw_value.isdigit():
            value = int(raw_value)
        else:
            value = float(raw_value)

        # These hyperparameters are list-valued even when only one number
        # was encoded in the name.
        if key in ('gnn_channels', 'class_decoder_dims', 'formula_dims'):
            if not isinstance(value, list):
                value = [value]

        # Explicitly supplied settings win over what the file name says.
        if key not in config:
            config[key] = value

    return config
# Base parameters: every generated config starts as a copy of these.
with open('config_base/params_transformer_rank.yaml') as params_file:
    ori_params = yaml.load(params_file, Loader=yaml.FullLoader)

# Search space: maps each hyperparameter name to its list of candidate values.
with open('config_base/hparams_transformer_rank.yaml') as hparams_file:
    config_array = yaml.load(hparams_file, Loader=yaml.FullLoader)

# An optional 'load_checkpoint_folder' entry is replaced by a
# 'checkpoint_pth' search dimension holding every path found directly
# inside that folder.
if 'load_checkpoint_folder' in config_array:
    checkpoint_folder = config_array.pop('load_checkpoint_folder')
    config_array['checkpoint_pth'] = glob.glob('%s/*'%checkpoint_folder)

# The string literal below is never assigned or used; it looks like an
# earlier inline version of the search space and is kept for reference.
"""
config_array = {
"lr": [1e-3, 1e-4, 1e-5],
"class_decoder_dims": [
[],
[256],
[256, 256],
[256, 256, 256],
[128],
[128, 128],
[128, 128, 128]
],
"final_embedding_dim": [512, 256, 128, 64],
"formula_dims": [
[256],
[256, 256],
[256, 256, 256],
[128],
[128, 128],
[128, 128, 128]
],
#"gnn_channels": [
# [128],
# [128, 128],
# [128, 128, 128],
# [256],
# [256, 256],
# [256, 256, 256]
#],
#"gnn_hidden_dim": [512, 256, 128]
#'checkpoint_pth': [curr_path for curr_path in glob.glob('results/base_gnn_hparams_test/*') if curr_path.split('_')[-1] != 'test']
}
"""
# Upper bound on how many configurations to generate.
num_tests = 24

# Simulate a grid search, but only attempt a random subset of the grid.
# Every grid point is identified by one integer in [0, total_possible),
# which is decoded below as a mixed-radix number with one digit per key.
keys = list(config_array)
total_possible = int(np.prod([len(config_array[key]) for key in keys]))
print('max available:', total_possible)

# Sampling without replacement cannot return more points than the grid
# contains, so a small grid is simply enumerated in full (in random order)
# instead of making np.random.choice raise ValueError.
num_samples = min(num_tests, total_possible)
sample_idx = np.random.choice(total_possible, num_samples, replace=False)

all_configs = []
for curr_idx in sample_idx:
    curr_config = {}
    j = curr_idx
    for key in keys:
        # Peel off this key's "digit": the remainder selects the option,
        # the quotient carries on to the next key.
        j, i = divmod(j, len(config_array[key]))
        curr_config[key] = config_array[key][i]
    # Add any hyperparameters encoded in the chosen checkpoint's file name.
    curr_config = append_config_ckpt(curr_config)
    print(curr_config)
    all_configs.append(curr_config)
# One output folder per entry, each paired with the device id at the same
# position in `gpus`.
config_folders = ['config4', 'config5', 'config6', 'config7']
gpus = [3,3,4,5]
prefix = "hparam_transformer"
assert len(config_folders) == len(gpus)
num_folders = len(config_folders)

# Spread the sampled configurations across the folders; np.array_split
# accepts a count that does not divide evenly.
splits = np.array_split(all_configs, num_folders)

for i, (curr_folder_name, curr_gpu) in enumerate(zip(config_folders, gpus)):
    # Make sure the destination exists so open(..., 'w') below cannot fail
    # with FileNotFoundError on a fresh checkout.
    os.makedirs(curr_folder_name, exist_ok=True)

    for j, curr_config in enumerate(splits[i]):
        # Build the run name as prefix_key=value_key=value...; list values
        # are flattened to "a_b_c" and the checkpoint path is left out.
        name_parts = [prefix]
        for key, value in curr_config.items():
            if key == 'checkpoint_pth':
                continue
            if isinstance(value, list):
                value = "_".join(str(e) for e in value)
            name_parts.append("%s=%s"%(key, value))
        param_name = "_".join(name_parts)

        # Start from a private copy of the base parameters, then apply the
        # run-specific settings on top.
        curr_param = copy.deepcopy(ori_params)
        curr_param['run_name'] = param_name
        curr_param['devices'] = [curr_gpu]
        curr_param.update(curr_config)

        with open('%s/%d_%s.yaml'%(curr_folder_name, j, param_name), 'w') as f:
            yaml.dump(curr_param, f)