-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataHandler.py
More file actions
59 lines (54 loc) · 2 KB
/
DataHandler.py
File metadata and controls
59 lines (54 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import tensorflow as tf
from PIL import Image
class DataHandler:
    """Clean an image-folder dataset, then load it as train/validation datasets.

    On construction the handler deletes unreadable image files from the class
    sub-folders listed in ``class_names`` and then builds ``train_ds`` and
    ``val_ds`` with ``tf.keras.preprocessing.image_dataset_from_directory``
    from the directory ``os.path.join(dataset_path, dataset_name)``.

    Args:
        dataset_path: Directory that contains the dataset folder.
        dataset_name: Name of the dataset folder inside ``dataset_path``.
        dims: Forwarded to TF as ``image_size`` (TF expects ``(height, width)``).
        batch_size: Forwarded to TF as ``batch_size``.
        class_names: Sub-folder names scanned by ``remove_invalid``.
            NOTE(review): this list is not forwarded to TF, which infers the
            classes from every sub-folder of the dataset directory.
        val_split: Fraction of the data held out for validation.
        seed: Seed passed to both loader calls so the training/validation
            split is consistent between them.
    """

    # Filename suffixes (matched case-insensitively) that remove_invalid checks.
    # Subclasses may extend this, e.g. (".jpg", ".jpeg").
    image_extensions = (".jpg",)

    def __init__(self, dataset_path, dataset_name, dims, batch_size, class_names, val_split=0.2, seed=1337):
        self.dataset_path = dataset_path
        self.dataset_name = dataset_name
        self.dims = dims
        self.batch_size = batch_size
        self.class_names = class_names
        self.val_split = val_split
        self.seed = seed
        # Populated by init_dataset().
        self.train_ds = None
        self.val_ds = None
        # Side effect: deletes invalid files on disk and loads the datasets.
        self.init_dataset()

    def filter_in(self, types, fpath):
        """Delete ``fpath`` if Pillow cannot open and verify it.

        Args:
            types: Unused; accepted only for backward compatibility.
            fpath: Path of the image file to check.

        Returns:
            1 if the file was deleted, otherwise 0 (so results can be summed).
        """
        try:
            # The ``with`` closes the file handle before any deletion below;
            # removing a file that is still open fails on Windows.
            with Image.open(fpath) as img:
                # _getexif() is a private Pillow API, presumably called to
                # surface corrupt EXIF data. It is not defined for every image
                # format; an AttributeError is treated as an invalid file.
                img._getexif()
                img.verify()
        except Exception:
            # Pillow raises many different exception types for damaged files.
            # Catching Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate instead of deleting the file being checked.
            os.remove(fpath)
            return 1
        return 0

    def remove_invalid(self):
        """Delete unreadable images from each class folder of the dataset.

        Only files whose names end with one of ``image_extensions``
        (case-insensitive) are checked, via ``filter_in``.

        Returns:
            The number of files deleted (the count is also printed).
        """
        # os.path.join (not "+") so this is the same directory init_dataset
        # loads, whether or not dataset_path ends with a path separator.
        dataset_dir = os.path.join(self.dataset_path, self.dataset_name)
        num_skipped = 0
        for folder_name in self.class_names:
            folder_path = os.path.join(dataset_dir, folder_name)
            print(f"folder_path:\t{folder_path}")
            for fname in os.listdir(folder_path):
                if fname.lower().endswith(self.image_extensions):
                    fpath = os.path.join(folder_path, fname)
                    num_skipped += self.filter_in(["jpg"], fpath)
        print(f"Deleted {num_skipped} images")
        return num_skipped

    def init_dataset(self):
        """Remove invalid images, then build ``train_ds`` and ``val_ds``."""
        self.remove_invalid()
        self.train_ds = self._load_subset("training")
        self.val_ds = self._load_subset("validation")

    def _load_subset(self, subset):
        """Load one ``validation_split`` subset ("training" or "validation")."""
        return tf.keras.preprocessing.image_dataset_from_directory(
            os.path.join(self.dataset_path, self.dataset_name),
            validation_split=self.val_split,
            subset=subset,
            # Both subsets use the same seed so the split is consistent.
            seed=self.seed,
            image_size=self.dims,
            batch_size=self.batch_size,
        )