1010import numpy as np
1111import torch
1212import torch .nn as nn
13+ from torch .utils .data import DataLoader
14+ import torchvision .transforms as transforms
15+
16+ from models .location_dataset import LocationDataset
17+ from models .yolo_v1 import YOLO_v1
1318
1419
1520class MultiPartLoss (nn .Module ):
@@ -32,87 +37,135 @@ def forward(self, preds, targets):
3237 :param targets: (N, S*S, (B*5+C))
3338 :return:
3439 """
35- # ## 预测
36- # # 提取每个网格的分类概率
37- # pred_probs = preds[-1, :self.S * self.S * self.C].reshape(-1, self.S, self.S, self.C)
38- # # 提取每个网格的置信度
39- # pred_confidences = preds[-1, self.S * self.S * self.C: self.S * self.S * (self.B + self.C)] \
40- # .reshape(-1, self.S, self.S, self.B)
41- # # 提取每个网格的预测边界框坐标
42- # pred_bboxs = preds[-1, self.S * self.S * (self.B + self.C): self.S * self.S * (self.B * 5 + self.C)] \
43- # .reshape(-1, self.S, self.S, 4)
40+ N = preds .shape [0 ]
41+ ## 预测
42+ # 提取每个网格的分类概率
43+ # [N, S*S, C] -> [N*S*S, C]
44+ pred_probs = preds [:, :, :self .C ].reshape (- 1 , self .C )
45+ # 提取每个网格的置信度
46+ # [N, S*S, B] -> [N*S*S, B]
47+ pred_confidences = preds [:, :, self .C : (self .B + self .C )].reshape (- 1 , self .B )
48+ # 提取每个网格的预测边界框坐标
49+ # [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
50+ pred_bboxs = preds [:, :, (self .B + self .C ): (self .B * 5 + self .C )] \
51+ .reshape (- 1 , self .B * 4 ) \
52+ .reshape (- 1 , self .B , 4 )
53+
54+ ## 目标
55+ # 提取每个网格的分类概率
56+ # [N, S*S, C] -> [N*S*S, C]
57+ target_probs = targets [:, :, :self .C ].reshape (- 1 , self .C )
58+ # 提取每个网格的置信度
59+ # [N, S*S, B] -> [N*S*S, B]
60+ target_confidences = targets [:, :, self .C : (self .B + self .C )].reshape (- 1 , self .B )
61+ # 提取每个网格的边界框坐标
62+ # [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
63+ target_bboxs = targets [:, :, (self .B + self .C ): (self .B * 5 + self .C )] \
64+ .reshape (- 1 , self .B * 4 ) \
65+ .reshape (- 1 , self .B , 4 )
66+
67+ ## 首先计算所有边界框的置信度损失(假定不存在obj)
68+ loss = self .noobj * self .sum_squared_error (pred_confidences , target_confidences )
69+
70+ # 选取每个网格中置信度最高的边界框
71+ top_idxs = torch .argmax (pred_confidences , dim = 1 )
72+ top_len = len (top_idxs )
73+ # 获取相应的置信度以及边界框
74+ top_pred_confidences = pred_confidences [range (top_len ), top_idxs ]
75+ top_pred_bboxs = pred_bboxs [range (top_len ), top_idxs ]
76+
77+ top_target_confidences = target_confidences [range (top_len ), top_idxs ]
78+ top_target_bboxs = target_bboxs [range (top_len ), top_idxs ]
79+ print (top_pred_confidences .shape )
80+ print (top_pred_bboxs .shape )
81+
82+ # 选取存在目标的网格
83+ obj_idxs = torch .sum (target_probs , dim = 1 ) == 1
84+ print (obj_idxs )
85+
86+ obj_pred_confidences = top_pred_confidences [obj_idxs ]
87+ obj_pred_bboxs = top_pred_bboxs [obj_idxs ]
88+ obj_pred_probs = pred_probs [obj_idxs ]
89+
90+ obj_target_confidences = top_target_confidences [obj_idxs ]
91+ obj_target_bboxs = top_target_bboxs [obj_idxs ]
92+ obj_target_probs = target_probs [obj_idxs ]
93+
94+ ## 计算目标边界框的置信度损失
95+ loss += (1 - self .noobj ) * self .sum_squared_error (obj_pred_confidences , obj_target_confidences )
96+ ## 计算分类概率损失
97+ loss += self .sum_squared_error (obj_pred_probs , obj_target_probs )
98+ ## 计算边界框坐标损失
99+ loss += self .sum_squared_error (obj_pred_bboxs [:, :2 ], obj_target_bboxs [:, :2 ])
100+ loss += self .sum_squared_error (torch .sqrt (obj_pred_bboxs [:, 2 :]), torch .sqrt (obj_target_bboxs [:, 2 :]))
101+
102+ return loss / N
103+
104+ # N = preds.shape[0]
105+ # total_loss = 0.0
106+ # print(preds.shape)
107+ # print(targets.shape)
108+ # for pred, target in zip(preds, targets):
109+ # """
110+ # 逐个图像计算
111+ # pred: [S*S, (B*5+C)]
112+ # target: [S*S, (B*5+C)]
113+ # """
114+ # # 分类概率
115+ # pred_probs = pred[:, :self.C]
116+ # target_probs = target[:, :self.C]
117+ # # 置信度
118+ # pred_confidences = pred[:, self.C:(self.C + self.B)]
119+ # target_confidences = target[:, self.C:(self.C + self.B)]
120+ # # 边界框坐标
121+ # pred_bboxs = pred[:, (self.C + self.B):]
122+ # target_bboxs = target[:, (self.C + self.B):]
44123 #
45- # ## 目标
46- # # 每个网格的分类
47- # target_probs = targets[-1, :self.S * self.S].reshape(-1, self.S, self.S)
48- # # 置信度
49- # target_confidences = targets[-1, self.S * self.S: self.S * self.S * 2].reshape(-1, self.S, self.S)
50- # # 坐标
51- # target_bboxs = targets[-1, self.S * self.S * 2:self.S * self.S * 6].reshape(-1, self.S, self.S, 4)
124+ # for i in range(self.S * self.S):
125+ # """
126+ # 逐个网格计算
127+ # """
128+ # pred_single_probs = pred_probs[i]
129+ # target_single_probs = target_probs[i]
52130 #
53- # # 图像中哪些网格包含了目标(根据分类判断)
54- # objs = torch.where(target_probs != -1)
55- # # 哪些不包含目标
56- # nobjs = torch.where(target_probs == -1)
131+ # pred_single_confidences = pred_confidences[i]
132+ # target_single_confidences = target_confidences[i]
57133 #
58- # ## 首先计算包含了分类的
59-
60- N = preds .shape [0 ]
61- total_loss = 0.0
62- for pred , target in zip (preds , targets ):
63- """
64- 逐个图像计算
65- pred: [S*S, (B*5+C)]
66- target: [S*S, (B*5+C)]
67- """
68- # 分类概率
69- pred_probs = pred [:, :self .C ]
70- target_probs = target [:, :self .C ]
71- # 置信度
72- pred_confidences = pred [:, self .C :(self .C + self .B )]
73- target_confidences = target [:, self .C :(self .C + self .B )]
74- # 边界框坐标
75- pred_bboxs = pred [:, (self .C + self .B ):]
76- target_bboxs = target [:, (self .C + self .B ):]
77-
78- for i in range (self .S * self .S ):
79- """
80- 逐个网格计算
81- """
82- pred_single_probs = pred_probs [i ]
83- target_single_probs = target_probs [i ]
84-
85- pred_single_confidences = pred_confidences [i ]
86- target_single_confidences = target_confidences [i ]
87-
88- pred_single_bboxs = pred_bboxs [i ]
89- target_single_bboxs = target_bboxs [i ]
90-
91- # 是否存在置信度(如果存在,则target的置信度必然大于0)
92- is_obj = target_single_confidences [0 ] > 0
93- # 计算置信度损失 假定该网格不存在对象
94- total_loss += self .noobj * self .sum_squared_error (pred_single_confidences , target_single_confidences )
95- if is_obj :
96- # 如果存在
97- # 计算分类损失
98- total_loss += self .sum_squared_error (pred_single_probs , target_single_probs )
99-
100- # 计算所有预测边界框和标注边界框的IoU
101- pred_single_bboxs = pred_single_bboxs .reshape (- 1 , 4 )
102- target_single_bboxs = target_single_bboxs .reshape (- 1 , 4 )
103-
104- scores = self .iou (pred_single_bboxs , target_single_bboxs )
105- # 提取IoU最大的下标
106- bbox_idx = torch .argmax (scores )
107- # 计算置信度损失
108- total_loss += (1 - self .noobj ) * \
109- self .sum_squared_error (pred_single_confidences [bbox_idx ],
110- target_single_confidences [bbox_idx ])
111- # 计算边界框损失
112- total_loss += self .coord * self .bbox_loss (pred_single_bboxs [bbox_idx ].reshape (- 1 , 4 ),
113- target_single_bboxs [bbox_idx ].reshape (- 1 , 4 ))
114-
115- return total_loss / N
134+ # pred_single_bboxs = pred_bboxs[i]
135+ # target_single_bboxs = target_bboxs[i]
136+ #
137+ # # 是否存在置信度(如果存在,则target的置信度必然大于0)
138+ # is_obj = target_single_confidences[0] > 0
139+ # # 计算置信度损失 假定该网格不存在对象
140+ # total_loss += self.noobj * self.sum_squared_error(pred_single_confidences, target_single_confidences)
141+ # print(total_loss)
142+ # if is_obj:
143+ # print('i = %d' % (i))
144+ # # 如果存在
145+ # # 计算分类损失
146+ # total_loss += self.sum_squared_error(pred_single_probs, target_single_probs)
147+ # print(total_loss)
148+ #
149+ # # 计算所有预测边界框和标注边界框的IoU
150+ # pred_single_bboxs = pred_single_bboxs.reshape(-1, 4)
151+ # target_single_bboxs = target_single_bboxs.reshape(-1, 4)
152+ #
153+ # scores = self.iou(pred_single_bboxs, target_single_bboxs)
154+ # # 提取IoU最大的下标
155+ # bbox_idx = torch.argmax(scores)
156+ # # 计算置信度损失
157+ # total_loss += (1 - self.noobj) * \
158+ # self.sum_squared_error(pred_single_confidences[bbox_idx],
159+ # target_single_confidences[bbox_idx])
160+ # print(total_loss)
161+ # # 计算边界框损失
162+ # total_loss += self.coord * self.bbox_loss(pred_single_bboxs[bbox_idx].reshape(-1, 4),
163+ # target_single_bboxs[bbox_idx].reshape(-1, 4))
164+ # print(total_loss)
165+ #
166+ # print('done')
167+ #
168+ # return total_loss / N
116169
117170 def sum_squared_error (self , preds , targets ):
118171 return torch .sum ((preds - targets ) ** 2 )
@@ -155,11 +208,42 @@ def iou(self, pred_boxs, target_boxs):
155208 return torch .from_numpy (scores )
156209
157210
158- if __name__ == '__main__' :
159- criterion = MultiPartLoss (S = 7 , B = 2 , C = 3 )
211+ def load_data (data_root_dir , cate_list , S = 7 , B = 2 , C = 20 ):
212+ transform = transforms .Compose ([
213+ transforms .ToPILImage (),
214+ transforms .Resize ((448 , 448 )),
215+ transforms .ToTensor (),
216+ transforms .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))
217+ ])
218+
219+ data_set = LocationDataset (data_root_dir , cate_list , transform = transform , S = S , B = B , C = C )
220+ data_loader = DataLoader (data_set , batch_size = 1 , num_workers = 8 )
160221
161- preds = torch .arange (637 ).reshape (1 , 7 * 7 , 13 ) * 0.01
162- targets = torch .ones ((1 , 7 * 7 , 13 )) * 0.01
222+ return data_loader
163223
164- loss = criterion .forward (preds , targets )
165- print (loss )
224+
225+ if __name__ == '__main__' :
226+ S = 7
227+ B = 2
228+ C = 3
229+ cate_list = ['cucumber' , 'eggplant' , 'mushroom' ]
230+
231+ criterion = MultiPartLoss (S = 7 , B = 2 , C = 3 )
232+ # preds = torch.arange(637).reshape(1, 7 * 7, 13) * 0.01
233+ # targets = torch.ones((1, 7 * 7, 13)) * 0.01
234+ # loss = criterion(preds, targets)
235+ # print(loss)
236+ data_loader = load_data ('../../data/location_dataset' , cate_list , S = S , B = B , C = C )
237+ model = YOLO_v1 (S = S , B = B , C = C )
238+
239+ for inputs , labels in data_loader :
240+ inputs = inputs
241+ labels = labels
242+ print (inputs .shape )
243+ print (labels .shape )
244+
245+ with torch .set_grad_enabled (False ):
246+ outputs = model (inputs )
247+ loss = criterion (outputs , labels )
248+ print (loss )
249+ exit (0 )
0 commit comments