Skip to content

Commit 7ee913f

Browse files
committed
perf(loss): 批量处理
1 parent d4dd4c6 commit 7ee913f

File tree

1 file changed

+168
-84
lines changed

1 file changed

+168
-84
lines changed

py/lib/models/multi_part_loss.py

Lines changed: 168 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
import numpy as np
1111
import torch
1212
import torch.nn as nn
13+
from torch.utils.data import DataLoader
14+
import torchvision.transforms as transforms
15+
16+
from models.location_dataset import LocationDataset
17+
from models.yolo_v1 import YOLO_v1
1318

1419

1520
class MultiPartLoss(nn.Module):
@@ -32,87 +37,135 @@ def forward(self, preds, targets):
3237
:param targets: (N, S*S, (B*5+C))
3338
:return:
3439
"""
35-
# ## 预测
36-
# # 提取每个网格的分类概率
37-
# pred_probs = preds[-1, :self.S * self.S * self.C].reshape(-1, self.S, self.S, self.C)
38-
# # 提取每个网格的置信度
39-
# pred_confidences = preds[-1, self.S * self.S * self.C: self.S * self.S * (self.B + self.C)] \
40-
# .reshape(-1, self.S, self.S, self.B)
41-
# # 提取每个网格的预测边界框坐标
42-
# pred_bboxs = preds[-1, self.S * self.S * (self.B + self.C): self.S * self.S * (self.B * 5 + self.C)] \
43-
# .reshape(-1, self.S, self.S, 4)
40+
N = preds.shape[0]
41+
## 预测
42+
# 提取每个网格的分类概率
43+
# [N, S*S, C] -> [N*S*S, C]
44+
pred_probs = preds[:, :, :self.C].reshape(-1, self.C)
45+
# 提取每个网格的置信度
46+
# [N, S*S, B] -> [N*S*S, B]
47+
pred_confidences = preds[:, :, self.C: (self.B + self.C)].reshape(-1, self.B)
48+
# 提取每个网格的预测边界框坐标
49+
# [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
50+
pred_bboxs = preds[:, :, (self.B + self.C): (self.B * 5 + self.C)] \
51+
.reshape(-1, self.B * 4) \
52+
.reshape(-1, self.B, 4)
53+
54+
## 目标
55+
# 提取每个网格的分类概率
56+
# [N, S*S, C] -> [N*S*S, C]
57+
target_probs = targets[:, :, :self.C].reshape(-1, self.C)
58+
# 提取每个网格的置信度
59+
# [N, S*S, B] -> [N*S*S, B]
60+
target_confidences = targets[:, :, self.C: (self.B + self.C)].reshape(-1, self.B)
61+
# 提取每个网格的边界框坐标
62+
# [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
63+
target_bboxs = targets[:, :, (self.B + self.C): (self.B * 5 + self.C)] \
64+
.reshape(-1, self.B * 4) \
65+
.reshape(-1, self.B, 4)
66+
67+
## 首先计算所有边界框的置信度损失(假定不存在obj)
68+
loss = self.noobj * self.sum_squared_error(pred_confidences, target_confidences)
69+
70+
# 选取每个网格中置信度最高的边界框
71+
top_idxs = torch.argmax(pred_confidences, dim=1)
72+
top_len = len(top_idxs)
73+
# 获取相应的置信度以及边界框
74+
top_pred_confidences = pred_confidences[range(top_len), top_idxs]
75+
top_pred_bboxs = pred_bboxs[range(top_len), top_idxs]
76+
77+
top_target_confidences = target_confidences[range(top_len), top_idxs]
78+
top_target_bboxs = target_bboxs[range(top_len), top_idxs]
79+
print(top_pred_confidences.shape)
80+
print(top_pred_bboxs.shape)
81+
82+
# 选取存在目标的网格
83+
obj_idxs = torch.sum(target_probs, dim=1) == 1
84+
print(obj_idxs)
85+
86+
obj_pred_confidences = top_pred_confidences[obj_idxs]
87+
obj_pred_bboxs = top_pred_bboxs[obj_idxs]
88+
obj_pred_probs = pred_probs[obj_idxs]
89+
90+
obj_target_confidences = top_target_confidences[obj_idxs]
91+
obj_target_bboxs = top_target_bboxs[obj_idxs]
92+
obj_target_probs = target_probs[obj_idxs]
93+
94+
## 计算目标边界框的置信度损失
95+
loss += (1 - self.noobj) * self.sum_squared_error(obj_pred_confidences, obj_target_confidences)
96+
## 计算分类概率损失
97+
loss += self.sum_squared_error(obj_pred_probs, obj_target_probs)
98+
## 计算边界框坐标损失
99+
loss += self.sum_squared_error(obj_pred_bboxs[:, :2], obj_target_bboxs[:, :2])
100+
loss += self.sum_squared_error(torch.sqrt(obj_pred_bboxs[:, 2:]), torch.sqrt(obj_target_bboxs[:, 2:]))
101+
102+
return loss / N
103+
104+
# N = preds.shape[0]
105+
# total_loss = 0.0
106+
# print(preds.shape)
107+
# print(targets.shape)
108+
# for pred, target in zip(preds, targets):
109+
# """
110+
# 逐个图像计算
111+
# pred: [S*S, (B*5+C)]
112+
# target: [S*S, (B*5+C)]
113+
# """
114+
# # 分类概率
115+
# pred_probs = pred[:, :self.C]
116+
# target_probs = target[:, :self.C]
117+
# # 置信度
118+
# pred_confidences = pred[:, self.C:(self.C + self.B)]
119+
# target_confidences = target[:, self.C:(self.C + self.B)]
120+
# # 边界框坐标
121+
# pred_bboxs = pred[:, (self.C + self.B):]
122+
# target_bboxs = target[:, (self.C + self.B):]
44123
#
45-
# ## 目标
46-
# # 每个网格的分类
47-
# target_probs = targets[-1, :self.S * self.S].reshape(-1, self.S, self.S)
48-
# # 置信度
49-
# target_confidences = targets[-1, self.S * self.S: self.S * self.S * 2].reshape(-1, self.S, self.S)
50-
# # 坐标
51-
# target_bboxs = targets[-1, self.S * self.S * 2:self.S * self.S * 6].reshape(-1, self.S, self.S, 4)
124+
# for i in range(self.S * self.S):
125+
# """
126+
# 逐个网格计算
127+
# """
128+
# pred_single_probs = pred_probs[i]
129+
# target_single_probs = target_probs[i]
52130
#
53-
# # 图像中哪些网格包含了目标(根据分类判断)
54-
# objs = torch.where(target_probs != -1)
55-
# # 哪些不包含目标
56-
# nobjs = torch.where(target_probs == -1)
131+
# pred_single_confidences = pred_confidences[i]
132+
# target_single_confidences = target_confidences[i]
57133
#
58-
# ## 首先计算包含了分类的
59-
60-
N = preds.shape[0]
61-
total_loss = 0.0
62-
for pred, target in zip(preds, targets):
63-
"""
64-
逐个图像计算
65-
pred: [S*S, (B*5+C)]
66-
target: [S*S, (B*5+C)]
67-
"""
68-
# 分类概率
69-
pred_probs = pred[:, :self.C]
70-
target_probs = target[:, :self.C]
71-
# 置信度
72-
pred_confidences = pred[:, self.C:(self.C + self.B)]
73-
target_confidences = target[:, self.C:(self.C + self.B)]
74-
# 边界框坐标
75-
pred_bboxs = pred[:, (self.C + self.B):]
76-
target_bboxs = target[:, (self.C + self.B):]
77-
78-
for i in range(self.S * self.S):
79-
"""
80-
逐个网格计算
81-
"""
82-
pred_single_probs = pred_probs[i]
83-
target_single_probs = target_probs[i]
84-
85-
pred_single_confidences = pred_confidences[i]
86-
target_single_confidences = target_confidences[i]
87-
88-
pred_single_bboxs = pred_bboxs[i]
89-
target_single_bboxs = target_bboxs[i]
90-
91-
# 是否存在置信度(如果存在,则target的置信度必然大于0)
92-
is_obj = target_single_confidences[0] > 0
93-
# 计算置信度损失 假定该网格不存在对象
94-
total_loss += self.noobj * self.sum_squared_error(pred_single_confidences, target_single_confidences)
95-
if is_obj:
96-
# 如果存在
97-
# 计算分类损失
98-
total_loss += self.sum_squared_error(pred_single_probs, target_single_probs)
99-
100-
# 计算所有预测边界框和标注边界框的IoU
101-
pred_single_bboxs = pred_single_bboxs.reshape(-1, 4)
102-
target_single_bboxs = target_single_bboxs.reshape(-1, 4)
103-
104-
scores = self.iou(pred_single_bboxs, target_single_bboxs)
105-
# 提取IoU最大的下标
106-
bbox_idx = torch.argmax(scores)
107-
# 计算置信度损失
108-
total_loss += (1 - self.noobj) * \
109-
self.sum_squared_error(pred_single_confidences[bbox_idx],
110-
target_single_confidences[bbox_idx])
111-
# 计算边界框损失
112-
total_loss += self.coord * self.bbox_loss(pred_single_bboxs[bbox_idx].reshape(-1, 4),
113-
target_single_bboxs[bbox_idx].reshape(-1, 4))
114-
115-
return total_loss / N
134+
# pred_single_bboxs = pred_bboxs[i]
135+
# target_single_bboxs = target_bboxs[i]
136+
#
137+
# # 是否存在置信度(如果存在,则target的置信度必然大于0)
138+
# is_obj = target_single_confidences[0] > 0
139+
# # 计算置信度损失 假定该网格不存在对象
140+
# total_loss += self.noobj * self.sum_squared_error(pred_single_confidences, target_single_confidences)
141+
# print(total_loss)
142+
# if is_obj:
143+
# print('i = %d' % (i))
144+
# # 如果存在
145+
# # 计算分类损失
146+
# total_loss += self.sum_squared_error(pred_single_probs, target_single_probs)
147+
# print(total_loss)
148+
#
149+
# # 计算所有预测边界框和标注边界框的IoU
150+
# pred_single_bboxs = pred_single_bboxs.reshape(-1, 4)
151+
# target_single_bboxs = target_single_bboxs.reshape(-1, 4)
152+
#
153+
# scores = self.iou(pred_single_bboxs, target_single_bboxs)
154+
# # 提取IoU最大的下标
155+
# bbox_idx = torch.argmax(scores)
156+
# # 计算置信度损失
157+
# total_loss += (1 - self.noobj) * \
158+
# self.sum_squared_error(pred_single_confidences[bbox_idx],
159+
# target_single_confidences[bbox_idx])
160+
# print(total_loss)
161+
# # 计算边界框损失
162+
# total_loss += self.coord * self.bbox_loss(pred_single_bboxs[bbox_idx].reshape(-1, 4),
163+
# target_single_bboxs[bbox_idx].reshape(-1, 4))
164+
# print(total_loss)
165+
#
166+
# print('done')
167+
#
168+
# return total_loss / N
116169

117170
def sum_squared_error(self, preds, targets):
118171
return torch.sum((preds - targets) ** 2)
@@ -155,11 +208,42 @@ def iou(self, pred_boxs, target_boxs):
155208
return torch.from_numpy(scores)
156209

157210

158-
if __name__ == '__main__':
159-
criterion = MultiPartLoss(S=7, B=2, C=3)
211+
def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
212+
transform = transforms.Compose([
213+
transforms.ToPILImage(),
214+
transforms.Resize((448, 448)),
215+
transforms.ToTensor(),
216+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
217+
])
218+
219+
data_set = LocationDataset(data_root_dir, cate_list, transform=transform, S=S, B=B, C=C)
220+
data_loader = DataLoader(data_set, batch_size=1, num_workers=8)
160221

161-
preds = torch.arange(637).reshape(1, 7 * 7, 13) * 0.01
162-
targets = torch.ones((1, 7 * 7, 13)) * 0.01
222+
return data_loader
163223

164-
loss = criterion.forward(preds, targets)
165-
print(loss)
224+
225+
if __name__ == '__main__':
226+
S = 7
227+
B = 2
228+
C = 3
229+
cate_list = ['cucumber', 'eggplant', 'mushroom']
230+
231+
criterion = MultiPartLoss(S=7, B=2, C=3)
232+
# preds = torch.arange(637).reshape(1, 7 * 7, 13) * 0.01
233+
# targets = torch.ones((1, 7 * 7, 13)) * 0.01
234+
# loss = criterion(preds, targets)
235+
# print(loss)
236+
data_loader = load_data('../../data/location_dataset', cate_list, S=S, B=B, C=C)
237+
model = YOLO_v1(S=S, B=B, C=C)
238+
239+
for inputs, labels in data_loader:
240+
inputs = inputs
241+
labels = labels
242+
print(inputs.shape)
243+
print(labels.shape)
244+
245+
with torch.set_grad_enabled(False):
246+
outputs = model(inputs)
247+
loss = criterion(outputs, labels)
248+
print(loss)
249+
exit(0)

0 commit comments

Comments
 (0)