forked from FighterLYL/GraphNeuralNetwork
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata.py
More file actions
140 lines (120 loc) · 5.85 KB
/
data.py
File metadata and controls
140 lines (120 loc) · 5.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import os.path as osp
import pickle
import numpy as np
import itertools
import scipy.sparse as sp
import urllib
from collections import namedtuple
Data = namedtuple('Data', ['x', 'y', 'adjacency_dict',
'train_mask', 'val_mask', 'test_mask'])
class CoraData(object):
download_url = "https://raw.githubusercontent.com/kimiyoung/planetoid/master/data"
filenames = ["ind.cora.{}".format(name) for name in
['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]
def __init__(self, data_root="cora", rebuild=False):
"""Cora数据,包括数据下载,处理,加载等功能
当数据的缓存文件存在时,将使用缓存文件,否则将下载、进行处理,并缓存到磁盘
处理之后的数据可以通过属性 .data 获得,它将返回一个数据对象,包括如下几部分:
* x: 节点的特征,维度为 2708 * 1433,类型为 np.ndarray
* y: 节点的标签,总共包括7个类别,类型为 np.ndarray
* adjacency_dict: 邻接信息,,类型为 dict
* train_mask: 训练集掩码向量,维度为 2708,当节点属于训练集时,相应位置为True,否则False
* val_mask: 验证集掩码向量,维度为 2708,当节点属于验证集时,相应位置为True,否则False
* test_mask: 测试集掩码向量,维度为 2708,当节点属于测试集时,相应位置为True,否则False
Args:
-------
data_root: string, optional
存放数据的目录,原始数据路径: {data_root}/raw
缓存数据路径: {data_root}/processed_cora.pkl
rebuild: boolean, optional
是否需要重新构建数据集,当设为True时,如果存在缓存数据也会重建数据
"""
self.data_root = data_root
save_file = osp.join(self.data_root, "processed_cora.pkl")
if osp.exists(save_file) and not rebuild:
print("Using Cached file: {}".format(save_file))
self._data = pickle.load(open(save_file, "rb"))
else:
self.maybe_download()
self._data = self.process_data()
with open(save_file, "wb") as f:
pickle.dump(self.data, f)
print("Cached file: {}".format(save_file))
@property
def data(self):
"""返回Data数据对象,包括x, y, adjacency, train_mask, val_mask, test_mask"""
return self._data
def process_data(self):
"""
处理数据,得到节点特征和标签,邻接矩阵,训练集、验证集以及测试集
引用自:https://github.com/rusty1s/pytorch_geometric
"""
print("Process data ...")
_, tx, allx, y, ty, ally, graph, test_index = [self.read_data(
osp.join(self.data_root, "raw", name)) for name in self.filenames]
train_index = np.arange(y.shape[0])
val_index = np.arange(y.shape[0], y.shape[0] + 500)
sorted_test_index = sorted(test_index)
x = np.concatenate((allx, tx), axis=0)
y = np.concatenate((ally, ty), axis=0).argmax(axis=1)
x[test_index] = x[sorted_test_index]
y[test_index] = y[sorted_test_index]
num_nodes = x.shape[0]
train_mask = np.zeros(num_nodes, dtype=np.bool)
val_mask = np.zeros(num_nodes, dtype=np.bool)
test_mask = np.zeros(num_nodes, dtype=np.bool)
train_mask[train_index] = True
val_mask[val_index] = True
test_mask[test_index] = True
adjacency_dict = graph
print("Node's feature shape: ", x.shape)
print("Node's label shape: ", y.shape)
print("Adjacency's shape: ", len(adjacency_dict))
print("Number of training nodes: ", train_mask.sum())
print("Number of validation nodes: ", val_mask.sum())
print("Number of test nodes: ", test_mask.sum())
return Data(x=x, y=y, adjacency_dict=adjacency_dict,
train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
def maybe_download(self):
save_path = os.path.join(self.data_root, "raw")
for name in self.filenames:
if not osp.exists(osp.join(save_path, name)):
self.download_data(
"{}/{}".format(self.download_url, name), save_path)
@staticmethod
def build_adjacency(adj_dict):
"""根据邻接表创建邻接矩阵"""
edge_index = []
num_nodes = len(adj_dict)
for src, dst in adj_dict.items():
edge_index.extend([src, v] for v in dst)
edge_index.extend([v, src] for v in dst)
# 去除重复的边
edge_index = list(k for k, _ in itertools.groupby(sorted(edge_index)))
edge_index = np.asarray(edge_index)
adjacency = sp.coo_matrix((np.ones(len(edge_index)),
(edge_index[:, 0], edge_index[:, 1])),
shape=(num_nodes, num_nodes), dtype="float32")
return adjacency
@staticmethod
def read_data(path):
"""使用不同的方式读取原始数据以进一步处理"""
name = osp.basename(path)
if name == "ind.cora.test.index":
out = np.genfromtxt(path, dtype="int64")
return out
else:
out = pickle.load(open(path, "rb"), encoding="latin1")
out = out.toarray() if hasattr(out, "toarray") else out
return out
@staticmethod
def download_data(url, save_path):
"""数据下载工具,当原始数据不存在时将会进行下载"""
if not os.path.exists(save_path):
os.makedirs(save_path)
data = urllib.request.urlopen(url)
filename = os.path.split(url)[-1]
with open(os.path.join(save_path, filename), 'wb') as f:
f.write(data.read())
return True