|
23 | 23 | B = 2 |
24 | 24 | C = 3 |
25 | 25 |
|
| 26 | +cate_list = ['cucumber', 'eggplant', 'mushroom'] |
26 | 27 |
|
27 | | -def load_data(data_root_dir, S=7, B=2, C=20): |
| 28 | + |
| 29 | +def load_data(data_root_dir, cate_list, S=7, B=2, C=20): |
28 | 30 | transform = transforms.Compose([ |
29 | 31 | transforms.ToPILImage(), |
30 | 32 | transforms.Resize((448, 448)), |
31 | 33 | transforms.ToTensor(), |
32 | 34 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
33 | 35 | ]) |
34 | 36 |
|
35 | | - data_set = LocationDataset(data_root_dir, transform=transform, S=S, B=B, C=C) |
| 37 | + data_set = LocationDataset(data_root_dir, cate_list, transform=transform, S=S, B=B, C=C) |
36 | 38 | data_loader = DataLoader(data_set, batch_size=8, shuffle=True, num_workers=8) |
37 | 39 |
|
38 | 40 | return data_loader |
@@ -108,7 +110,7 @@ def train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epoc |
108 | 110 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
109 | 111 | # device = "cpu" |
110 | 112 |
|
111 | | - data_loader = load_data('../data/training_images', S=S, B=B, C=C) |
| 113 | + data_loader = load_data('../data/training_images', cate_list, S=S, B=B, C=C) |
112 | 114 | # print(len(data_loader)) |
113 | 115 |
|
114 | 116 | model = YOLO_v1(S=S, B=B, C=C) |
|
0 commit comments