forked from Sunnan191/EviSEC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathedge_cls_tasker.py
More file actions
94 lines (73 loc) · 2.77 KB
/
edge_cls_tasker.py
File metadata and controls
94 lines (73 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import taskers_utils as tu
import utils as u
class Edge_Cls_Tasker():
    """Build per-time-step samples for a temporal edge-classification task.

    Each sample (see ``get_sample``) holds a window of historical adjacency
    matrices (normalized and unnormalized), the node features and node masks
    for every step of the window, and the edge labels at the target step.

    The node-feature source is selected by flags on ``args``:

    * ``args.use_2_hot_node_feats`` -- features from ``tu.get_2_hot_deg_feats``;
      ``feats_per_node = max_deg_out + max_deg_in``.
    * ``args.use_1_hot_node_feats`` -- features from ``tu.get_1_hot_deg_feats``;
      ``feats_per_node = max_deg``.
    * neither -- the dataset's ``nodes_feats`` and its own
      ``prepare_node_feats``; ``feats_per_node = dataset.feats_per_node``.

    If both flags are set, the 2-hot branch is the one used for features.
    """

    def __init__(self, args, dataset, ood_mode=None):
        """
        Args:
            args: run configuration. Attributes read by this class:
                ``use_1_hot_node_feats``, ``use_2_hot_node_feats``,
                ``num_hist_steps`` and ``adj_mat_time_window``.
            dataset: source graph. Must expose ``max_time``, ``num_classes``,
                ``num_nodes`` and ``edges``. It must also expose
                ``feats_per_node`` unless ``args.use_1_hot_node_feats`` is
                set, and ``nodes_feats`` / ``prepare_node_feats`` when no
                degree-based features are requested.
            ood_mode: optional out-of-distribution setting. Only ``"FI"`` is
                acted on, and only together with 1-hot degree features (see
                ``build_get_node_feats``).
        """
        self.data = dataset
        # Taken unchanged from the dataset. (A note carried over from the
        # link-prediction variant said link prediction should stop one step
        # before; no such shift is applied here.)
        self.max_time = dataset.max_time
        self.args = args
        self.num_classes = dataset.num_classes
        # The degree-based builders below overwrite this value; in the 2-hot
        # case it is still read from the dataset first.
        if not args.use_1_hot_node_feats:
            self.feats_per_node = dataset.feats_per_node
        # Must run before prepare_node_feats is *called*, because it sets
        # self.feats_per_node for the degree-based feature modes.
        self.get_node_feats = self.build_get_node_feats(args, dataset, ood_mode)
        self.prepare_node_feats = self.build_prepare_node_feats(args, dataset, ood_mode)
        self.is_static = False

    def build_prepare_node_feats(self, args, dataset, ood_mode):
        """Return the callable that turns raw node features into model input.

        For degree-based (1-hot / 2-hot) features, the features are passed to
        ``u.sparse_prepare_tensor`` with size
        ``[dataset.num_nodes, self.feats_per_node]``. Otherwise the dataset's
        own ``prepare_node_feats`` is returned unchanged. ``ood_mode`` is
        currently unused here.
        """
        if args.use_2_hot_node_feats or args.use_1_hot_node_feats:
            def prepare_node_feats(node_feats):
                # self.feats_per_node is looked up at call time, i.e. after
                # build_get_node_feats has set it.
                return u.sparse_prepare_tensor(node_feats,
                                               torch_size=[dataset.num_nodes,
                                                           self.feats_per_node])
        else:
            prepare_node_feats = self.data.prepare_node_feats
        return prepare_node_feats

    def build_get_node_feats(self, args, dataset, ood_mode):
        """Return a ``get_node_feats(adj)`` callable and set ``feats_per_node``.

        * 2-hot: ``tu.get_2_hot_deg_feats(adj, max_deg_out, max_deg_in,
          num_nodes)`` with the maxima from ``tu.get_max_degs``.
        * 1-hot: ``tu.get_1_hot_deg_feats(adj, max_deg, num_nodes)``. When
          ``ood_mode == "FI"``, column 1 of the returned ``feats["idx"]``
          (presumably the degree-bucket / feature column -- confirm against
          ``taskers_utils``) is overwritten with ``torch.randint(1, max_deg,
          ...)``, i.e. values in ``[1, max_deg - 1]`` (``high`` is exclusive).
        * otherwise: ignores ``adj`` and returns ``dataset.nodes_feats``.
        """
        if args.use_2_hot_node_feats:
            max_deg_out, max_deg_in = tu.get_max_degs(args, dataset)
            self.feats_per_node = max_deg_out + max_deg_in

            def get_node_feats(adj):
                return tu.get_2_hot_deg_feats(adj,
                                              max_deg_out,
                                              max_deg_in,
                                              dataset.num_nodes)
        elif args.use_1_hot_node_feats:
            max_deg, _ = tu.get_max_degs(args, dataset)
            self.feats_per_node = max_deg

            def get_node_feats(adj):
                feats = tu.get_1_hot_deg_feats(adj, max_deg, dataset.num_nodes)
                if ood_mode == "FI":
                    # NOTE(review): bucket 0 is never drawn, and torch.randint
                    # raises when max_deg <= 1 -- confirm both are intended.
                    # The indices are rewritten in place.
                    idx = feats["idx"]
                    idx[:, 1] = torch.randint(1, max_deg, (idx.size(0),),
                                              dtype=torch.int64)
                return feats
        else:
            def get_node_feats(adj):
                return dataset.nodes_feats
        return get_node_feats

    def get_sample(self, idx, test):
        """Assemble the sample whose target time step is ``idx``.

        Uses the ``args.num_hist_steps + 1`` time steps from
        ``idx - args.num_hist_steps`` to ``idx`` inclusive. For each step, the
        sparse adjacency is built with ``tu.get_sp_adj`` (weighted, with
        ``time_window=args.adj_mat_time_window``). Node mask and node features
        are computed from that unnormalized adjacency, which is then
        normalized with ``tu.normalize_adj``.

        Args:
            idx: target time step; edge labels are taken at this step.
            test: unused.

        Returns:
            dict with keys ``'idx'``, ``'hist_adj_list'`` (normalized
            adjacencies), ``'hist_ndFeats_list'``, ``'label_sp'`` (result of
            ``tu.get_edge_labels`` at ``idx``), ``'node_mask_list'`` and
            ``'hist_adj_list_u'`` (unnormalized adjacencies). Every list is
            ordered from the oldest to the newest step.
        """
        hist_adj_list = []
        hist_ndFeats_list = []
        hist_mask_list = []
        hist_adj_list_unnormalized = []
        for i in range(idx - self.args.num_hist_steps, idx + 1):
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=i,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)
            node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)
            # Features are computed from the unnormalized adjacency.
            node_feats = self.get_node_feats(cur_adj)
            # Keep a reference before cur_adj is rebound to the normalized one.
            cur_adj_unnormalized = cur_adj
            cur_adj = tu.normalize_adj(adj=cur_adj, num_nodes=self.data.num_nodes)
            hist_adj_list.append(cur_adj)
            hist_ndFeats_list.append(node_feats)
            hist_mask_list.append(node_mask)
            hist_adj_list_unnormalized.append(cur_adj_unnormalized)
        label_adj = tu.get_edge_labels(edges=self.data.edges,
                                       time=idx)
        return {'idx': idx,
                'hist_adj_list': hist_adj_list,
                'hist_ndFeats_list': hist_ndFeats_list,
                'label_sp': label_adj,
                'node_mask_list': hist_mask_list,
                'hist_adj_list_u': hist_adj_list_unnormalized}