-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
52 lines (31 loc) · 1.34 KB
/
utils.py
File metadata and controls
52 lines (31 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torchvision.transforms as TF
from random import randint, choice
class TripletDataset(TensorDataset):
    """Wrap a labelled dataset so each item is an (anchor, positive, negative) triplet.

    The wrapped dataset is iterated once at construction time and every
    ``(sample, label)`` pair is kept in memory in ``self.dataset``. For each
    label, the indices of the samples carrying it are recorded in
    ``self.class_indexes`` so positives and negatives can be drawn quickly.

    Args:
        dataset: Iterable yielding ``(sample, label)`` pairs, where ``label``
            is usable as a list index in ``range(class_n)``.
        class_n: Number of distinct classes. Must be at least 2 for
            ``__getitem__`` to be able to draw a negative sample.
        transform: Callable applied to the anchor, positive and negative
            samples each time an item is requested.
    """

    def __init__(self, dataset, class_n, transform):
        # TensorDataset is initialised without tensors; storage is handled
        # by the attributes below instead.
        super().__init__()
        self.dataset = []
        self.class_n = class_n
        self.transform = transform
        # class_indexes[c] holds the positions in self.dataset whose label is c.
        self.class_indexes = [[] for _ in range(self.class_n)]
        for i, (x, y) in enumerate(dataset):
            self.class_indexes[y].append(i)
            self.dataset.append([x, y])

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return a triplet built around the sample at ``index``.

        Returns:
            A 6-tuple ``(anchor, anchor_label, positive, positive_label,
            negative, negative_label)``. The positive shares the anchor's
            label (and may be the anchor itself); the negative always has a
            different label. All three samples have ``self.transform``
            applied.

        Raises:
            IndexError: If ``index`` is out of range, or the randomly chosen
                negative class has no samples.
            ValueError: If ``class_n`` is less than 2 (no negative class
                exists to draw from).
        """
        anchor_x, anchor_y = self.dataset[index]

        pos_index = choice(self.class_indexes[anchor_y])
        # Shifting the label by 1..class_n-1 modulo class_n can never land
        # back on the anchor's own class.
        neg_label = (anchor_y - randint(1, self.class_n - 1)) % self.class_n
        neg_index = choice(self.class_indexes[neg_label])

        pos_x, pos_y = self.dataset[pos_index]
        neg_x, neg_y = self.dataset[neg_index]

        # Transform into fresh locals only: writing the results back into
        # self.dataset would corrupt the cache and make transforms compound
        # across repeated accesses.
        return (
            self.transform(anchor_x), anchor_y,
            self.transform(pos_x), pos_y,
            self.transform(neg_x), neg_y,
        )
if __name__ == '__main__':
    # Manual smoke test: fetch MNIST into ./mnist (downloading if needed),
    # wrap it with an empty transform pipeline, and print the first triplet.
    mnist_source = MNIST(root="./mnist", download=True)
    empty_pipeline = TF.Compose([])
    demo_triplets = TripletDataset(
        dataset=mnist_source,
        class_n=10,
        transform=empty_pipeline,
    )
    first_triplet = demo_triplets[0]
    print(first_triplet)