-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_input.py
More file actions
63 lines (46 loc) · 1.64 KB
/
data_input.py
File metadata and controls
63 lines (46 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import tensorflow as tf
import hdf5storage as hdf
import param
num_labels = param.num_labels
def decode(serialized_example):
    """Parse one serialized TFRecord example into an (image, label) pair.

    Expects the record to hold a raw uint8 image under 'train/data' and an
    int64 class id under 'train/label'.

    :param serialized_example: scalar string tensor, one serialized Example.
    :return: (image, label) — uint8 tensor of shape [300, 300, 1] and an
        int32 scalar label.
    """
    parsed = tf.parse_single_example(
        serialized_example,
        features={
            'train/data': tf.FixedLenFeature([], tf.string),
            'train/label': tf.FixedLenFeature([], tf.int64),
        }
    )
    # Raw bytes -> flat uint8 tensor, then restore the spatial layout.
    # NOTE(review): 300x300x1 is hard-coded — assumes all records were
    # written at this resolution.
    image = tf.reshape(tf.decode_raw(parsed['train/data'], tf.uint8),
                       [300, 300, 1])
    label = tf.cast(parsed['train/label'], tf.int32)
    return image, label
def normalize(image, label):
    """Scale uint8 pixel values into [0, 1] as float32; label passes through."""
    scaled = tf.cast(image, dtype=tf.float32) * (1. / 255)
    return scaled, label
def reformat(image, label):
    """Turn an integer class id into a one-hot vector of depth num_labels."""
    one_hot = tf.one_hot(indices=label, depth=num_labels)
    return image, one_hot
def data_iterator(num_epochs, batch_size, tf_filename):
    """Build a one-shot iterator over the decoded, normalized, one-hot dataset.

    :param num_epochs: how many passes over the data the iterator yields.
    :param batch_size: number of examples per batch.
    :param tf_filename: path (or list of paths) to the TFRecord file(s).
    :return: a one-shot tf.data iterator producing (image, label) batches.
    """
    with tf.name_scope('data_input'):
        pipeline = (tf.data.TFRecordDataset(tf_filename)
                    .map(decode)
                    .map(normalize)
                    .map(reformat)
                    .repeat(num_epochs)
                    .batch(batch_size))
        return pipeline.make_one_shot_iterator()
def load_data(filename):
    """Load images and labels from a pair of MATLAB .mat files.

    :param filename: two-element sequence [image_path, label_path]. Each .mat
        file is expected to store its array under a key equal to the file
        name with the 4-character '.mat' extension stripped.
    :return: (images, labels) — images as a float32 ndarray of shape
        (N, H, W, 1), labels as a one-hot tensor of shape (N, num_labels).
    """
    images = hdf.loadmat(filename[0])
    # Key convention: variable name == file name minus the '.mat' suffix.
    images = images[filename[0][0:-4]]
    images = np.float32(np.reshape(images, [images.shape[0], images.shape[1], images.shape[2], 1]))
    labels = hdf.loadmat(filename[1])
    # Keep labels integral: tf.one_hot requires uint8/int32/int64 indices,
    # so the previous np.float32 cast here made the one_hot call fail.
    labels = np.reshape(labels[filename[1][0:-4]], [labels[filename[1][0:-4]].shape[0], ]).astype(np.int64)
    labels = tf.one_hot(labels, num_labels)
    return images, labels