-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathtinyimagenet_dataset.py
More file actions
115 lines (96 loc) · 4.28 KB
/
tinyimagenet_dataset.py
File metadata and controls
115 lines (96 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import shutil
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import verify_str_arg
from torchvision.datasets.utils import download_and_extract_archive
# Per-channel normalization shared by the train and validation pipelines
# built in this module.
# NOTE(review): these mean/std values look like the commonly used ImageNet
# RGB statistics -- confirm they are the intended ones for TinyImageNet.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
def TinyImageNet_dataset(stage="train", download=True, root='datasets/'):
    """Return a ``TinyImageNetDataset`` wired with the transforms for `stage`.

    Args:
        stage: ``"train"`` selects the training split with the augmentation
            pipeline; any other value selects the validation split with a
            plain center crop.
        download: Forwarded to ``TinyImageNetDataset``.
        root: Directory handed to ``TinyImageNetDataset`` as its root.

    Returns:
        The constructed ``TinyImageNetDataset``.
    """
    training = stage == "train"

    if training:
        # Augmentation pipeline: geometric and photometric jitter, then
        # tensor conversion and normalization.
        steps = [
            transforms.RandomAffine(15, None, (0.9, 1.1)),
            transforms.RandomResizedCrop(56, scale=(0.25, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.3, 0.3, 0.2, 0.1),
            transforms.GaussianBlur(5, (0.1, 0.5)),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        # Evaluation pipeline: crop the center to the same 56-pixel size the
        # training crops use, then convert and normalize.
        steps = [
            transforms.CenterCrop(56),
            transforms.ToTensor(),
            normalize,
        ]

    return TinyImageNetDataset(
        root,
        split="train" if training else "val",
        download=download,
        transform=transforms.Compose(steps),
    )
class TinyImageNetDataset(ImageFolder):
    """
    Dataset for TinyImageNet-200

    An ``ImageFolder`` rooted at ``<root>/tiny-imagenet-200/<split>``. With
    ``download=True`` the archive is fetched and extracted when the split
    folder is missing, and the validation folder is reorganized into one
    sub-folder per label so that ``ImageFolder`` can read it.

    credits: https://gist.github.com/lromor/bcfc69dcf31b2f3244358aea10b7a11b
    """
    # Folder created inside the data root when the archive is extracted.
    base_folder = 'tiny-imagenet-200'
    # MD5 checksum passed to the download helper to verify the archive.
    zip_md5 = '90528d7ca1a48142e341f4ef8d21d0de'
    # Accepted values for the ``split`` argument.
    splits = ('train', 'val')
    filename = 'tiny-imagenet-200.zip'
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'

    def __init__(self, root, split='train', download=False, **kwargs):
        """Locate (and optionally download) the dataset, then index a split.

        Args:
            root: Directory that contains, or will receive, the
                ``tiny-imagenet-200`` folder. ``~`` is expanded.
            split: One of ``splits``.
            download: When true, call ``download()`` before checking that
                the split folder exists.
            **kwargs: Forwarded to ``ImageFolder.__init__``.

        Raises:
            RuntimeError: If the split folder does not exist after the
                optional download step.
        """
        # Kept under a separate name: ImageFolder receives the split folder
        # as its own root argument below.
        self.data_root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", self.splits)

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        super().__init__(self.split_folder, **kwargs)

    @property
    def dataset_folder(self):
        """Path of the extracted dataset: ``<data_root>/<base_folder>``."""
        return os.path.join(self.data_root, self.base_folder)

    @property
    def split_folder(self):
        """Path of the selected split inside ``dataset_folder``."""
        return os.path.join(self.dataset_folder, self.split)

    def _check_exists(self):
        """Return whether the selected split folder exists on disk."""
        return os.path.exists(self.split_folder)

    def extra_repr(self):
        """Return the split description line used in the dataset repr."""
        return "Split: {split}".format(**self.__dict__)

    def download(self):
        """Fetch and extract the archive if needed and fix the val layout.

        When the split folder already exists nothing is downloaded. The
        validation folder is still checked in that case: if its raw
        ``images`` sub-folder is present (an earlier run was interrupted, or
        the archive was extracted by hand), the reorganization is run so
        that ``ImageFolder`` does not pick up ``images`` as the only class.
        """
        val_folder = os.path.join(self.dataset_folder, 'val')

        if self._check_exists():
            # 'images' is the default images_folder name used by
            # normalize_tin_val_folder_structure.
            if os.path.isdir(os.path.join(val_folder, 'images')):
                normalize_tin_val_folder_structure(val_folder)
            return

        download_and_extract_archive(
            self.url, self.data_root, filename=self.filename,
            remove_finished=True, md5=self.zip_md5)

        assert 'val' in self.splits
        normalize_tin_val_folder_structure(val_folder)
def normalize_tin_val_folder_structure(path,
                                       images_folder='images',
                                       annotations_file='val_annotations.txt'):
    """Reorganize a validation folder into one sub-folder per label.

    Each line of the annotations file is split on whitespace; the first
    field is an image file name and the second its label. Every listed
    image is moved from ``<path>/<images_folder>`` to ``<path>/<label>``.
    Afterwards the (now empty) images folder and the annotations file are
    removed.

    The function can be re-run safely: if neither the images folder nor the
    annotations file is left it returns immediately, and images that were
    already moved by an earlier, interrupted run are skipped.

    Args:
        path: Validation folder to reorganize in place.
        images_folder: Name of the sub-folder holding the images.
        annotations_file: Name of the annotations file inside `path`.

    Raises:
        RuntimeError: If `path` is empty, or if images remain in the images
            folder after processing all annotations (nothing is deleted in
            that case).
        FileNotFoundError: If the images folder exists but the annotations
            file does not.
    """
    images_folder = os.path.join(path, images_folder)
    annotations_file = os.path.join(path, annotations_file)

    # Both markers gone: the folder was already reorganized.
    if not os.path.exists(images_folder) \
            and not os.path.exists(annotations_file):
        if not os.listdir(path):
            raise RuntimeError('Validation folder is empty.')
        return

    # os.sync is only available on Unix; skip the flush elsewhere.
    sync = getattr(os, 'sync', None)

    # Parse the annotations
    with open(annotations_file) as f:
        for line in f:
            values = line.split()
            if len(values) < 2:
                # Blank or malformed line: no image/label pair to act on.
                continue
            img, label = values[0], values[1]
            label_folder = os.path.join(path, label)
            os.makedirs(label_folder, exist_ok=True)
            try:
                shutil.move(os.path.join(images_folder, img),
                            os.path.join(label_folder, img))
            except FileNotFoundError:
                # Already moved by a previous run (or never present).
                continue

    if sync is not None:
        sync()

    # The images folder may already be gone if a previous run stopped
    # between removing it and removing the annotations file.
    if os.path.isdir(images_folder):
        leftovers = os.listdir(images_folder)
        if leftovers:
            # Explicit check instead of assert, so that running with -O can
            # never delete images that have no annotation.
            raise RuntimeError(
                '{} file(s) in {} are not covered by the annotations; '
                'refusing to delete them.'.format(len(leftovers),
                                                  images_folder))
        shutil.rmtree(images_folder)

    os.remove(annotations_file)

    if sync is not None:
        sync()