-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_combined.py
More file actions
148 lines (102 loc) · 5.72 KB
/
inference_combined.py
File metadata and controls
148 lines (102 loc) · 5.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
from Dataset.FastDataset import FastTransFunDataset
from Utils import load_ckp, pickle_load, pickle_save
import CONSTANTS
import math, os, time
import argparse
from models.model_ablation import TFun, TFun_submodel
from Dataset.FastDataset import TestDataset
# ---- CLI arguments -------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
# Fix: argparse's `type=bool` treats ANY non-empty string as True, so
# `--load_weights False` previously yielded True. Parse truthy strings
# explicitly; still backward compatible with `--load_weights True`.
parser.add_argument("--load_weights",
                    type=lambda s: str(s).strip().lower() in ('1', 'true', 'yes', 'y'),
                    default=False,
                    help='Load weights from saved model')
parser.add_argument('--label_features', type=str, default='linear',
                    help='Sub model to train')
args = parser.parse_args()

torch.manual_seed(args.seed)
args.cuda = not args.no_cuda and torch.cuda.is_available()
# NOTE(review): GPU index 1 is hard-coded — confirm this matches the target host.
if args.cuda:
    device = 'cuda:1'
else:
    device = 'cpu'

# load all test proteins (timepoint t3)
all_test = pickle_load(CONSTANTS.ROOT_DIR + "test/t3/test_proteins")

# Configuration grids swept by the inference loop below.
ontologies = ["cc", "mf", "bp"]
annotation_depths = ["LK", "NK"]  # presumably limited-/no-knowledge splits — confirm
sub_models = ['full']
label_features = ['gcn']
def write_output(results, terms, filepath, cutoff=0.001):
    """Write per-protein term predictions to a TSV file.

    Each output line is ``protein<TAB>term<TAB>score`` with the score
    formatted to three decimals. A protein's predictions are written in
    descending score order, and scores <= ``cutoff`` are dropped.

    Args:
        results: dict mapping protein id -> list of scores, one score per
            entry of ``terms`` (same order).
        terms: list of term identifiers.
        filepath: destination path for the TSV file (overwritten).
        cutoff: exclusive minimum score for a prediction to be written.

    Raises:
        ValueError: if a protein's score list length differs from ``terms``.
    """
    with open(filepath, 'w') as fp:
        for prt in results:
            scores = results[prt]
            # Explicit check instead of `assert`, which is silently
            # stripped when Python runs with -O.
            if len(terms) != len(scores):
                raise ValueError(
                    "score/term length mismatch for protein %s: %d vs %d"
                    % (prt, len(scores), len(terms)))
            ranked = sorted(zip(terms, scores), key=lambda ts: ts[1], reverse=True)
            for trm, score in ranked:
                if score > cutoff:
                    fp.write('%s\t%s\t%0.3f\n' % (prt, trm, score))
def get_term_indicies(ontology, submodel="full", label_feature="max"):
    """Load and partition GO-term indices for one ontology.

    Loads the pickled ``term_indicies`` map for ``ontology`` (keyed by a
    frequency threshold: 0 = all terms, 5 / 30 = increasingly frequent
    subsets) and splits it into full / frequent / rare groups as CUDA (or
    CPU) tensors on the module-level ``device``.

    Args:
        ontology: one of 'cc', 'mf', 'bp'.
        submodel: unused here; kept for interface compatibility with callers.
        label_feature: unused here; kept for interface compatibility.

    Returns:
        Tuple ``(full, freq, rare, rare_2)`` of index tensors; ``rare_2``
        is None except for 'bp', which gets an extra mid-frequency split.
    """
    _term_indicies = pickle_load(CONSTANTS.ROOT_DIR + "{}/term_indicies".format(ontology))
    if ontology == 'bp':
        full_term_indicies = _term_indicies[0]
        mid_term_indicies = _term_indicies[5]
        freq_term_indicies = _term_indicies[30]
        # Build the lookup sets ONCE. The original rebuilt `set(...)` inside
        # the comprehension condition, i.e. on every single iteration.
        mid_set = set(mid_term_indicies)
        freq_set = set(freq_term_indicies)
        rare_term_indicies_2 = torch.tensor(
            [i for i in full_term_indicies if i not in mid_set]).to(device)
        rare_term_indicies = torch.tensor(
            [i for i in mid_term_indicies if i not in freq_set]).to(device)
        full_term_indicies = torch.tensor(_term_indicies[0]).to(device)
        freq_term_indicies = torch.tensor(freq_term_indicies).to(device)
    else:
        full_term_indicies = _term_indicies[0]
        freq_term_indicies = _term_indicies[30]
        freq_set = set(freq_term_indicies)  # single set build, O(1) membership
        rare_term_indicies = torch.tensor(
            [i for i in full_term_indicies if i not in freq_set]).to(device)
        full_term_indicies = torch.tensor(full_term_indicies).to(device)
        freq_term_indicies = torch.tensor(freq_term_indicies).to(device)
        rare_term_indicies_2 = None
    return full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2
# Run inference for every (annotation depth, ontology, sub-model,
# label-feature) combination and dump ranked predictions to TSV files.
for annotation_depth in annotation_depths:
    for ontology in ontologies:
        data_pth = CONSTANTS.ROOT_DIR + "test/t3/dataset/{}_{}".format(annotation_depth, ontology)
        sorted_terms = pickle_load(CONSTANTS.ROOT_DIR + "/{}/sorted_terms".format(ontology))
        for sub_model in sub_models:
            tst_dataset = TestDataset(data_pth=data_pth, submodel=sub_model)
            tstloader = torch.utils.data.DataLoader(tst_dataset, batch_size=500, shuffle=False)
            full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 = \
                get_term_indicies(ontology=ontology, submodel=sub_model)
            kwargs = {
                'device': device,
                'ont': ontology,
                'full_indicies': full_term_indicies,
                'freq_indicies': freq_term_indicies,
                'rare_indicies': rare_term_indicies,
                'rare_indicies_2': rare_term_indicies_2,
                'sub_model': sub_model,
                'load_weights': True,
                'label_features': "",  # filled in per label_feature below
                'group': ""
            }
            for label_feature in label_features:
                print("Generating for {} {} {} {}".format(annotation_depth, ontology, sub_model, label_feature))
                kwargs['label_features'] = label_feature
                ckp_dir = CONSTANTS.ROOT_DIR + '{}/models/{}_{}_combined/'.format(ontology, sub_model, label_feature)
                model = TFun(**kwargs)
                # 'max'/'mean' label aggregations have no checkpoint to restore.
                if label_feature != 'max' and label_feature != 'mean':
                    model = load_ckp(checkpoint_dir=ckp_dir, model=model, best_model=False, model_only=True)
                model.to(device)
                model.eval()
                results = {}
                # Fix: inference previously ran WITHOUT no_grad(), so autograd
                # graphs were built for every forward pass — wasted memory.
                with torch.no_grad():
                    for data in tstloader:
                        _features, _proteins = data[:4], data[4]
                        output, _ = model(_features)
                        # Keep only the scores of the ontology's full term set.
                        output = torch.index_select(output, 1, full_term_indicies)
                        output = output.tolist()
                        for prot, scores in zip(_proteins, output):
                            results[prot] = scores
                terms = [sorted_terms[i] for i in full_term_indicies]
                filepath = 'evaluation/predictions/transfew/{}_{}_{}_combined_{}.tsv'.format(
                    annotation_depth, ontology, sub_model, label_feature)
                write_output(results, terms, filepath, cutoff=0.01)