Skip to content

Commit 3a38a54

Browse files
committed
perf(detect): 指定device
1 parent 92e0095 commit 3a38a54

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

py/detector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def load_data(img_path, xml_path):
5959
return img, data_dict
6060

6161

62-
def load_model():
63-
model_path = './models/checkpoint_yolo_v1_24.pth'
62+
def load_model(device):
63+
model_path = './models/checkpoint_yolo_v1_49.pth'
6464
model = YOLO_v1(S=7, B=2, C=3)
6565
model.load_state_dict(torch.load(model_path))
6666
model.eval()
@@ -114,7 +114,7 @@ def deform_bboxs(pred_bboxs, data_dict):
114114
# device = "cpu"
115115

116116
img, data_dict = load_data('../imgs/cucumber_9.jpg', '../imgs/cucumber_9.xml')
117-
model = load_model()
117+
model = load_model(device)
118118
# 计算
119119
outputs = model.forward(img.to(device)).cpu().squeeze(0)
120120
print(outputs.shape)
@@ -143,5 +143,5 @@ def deform_bboxs(pred_bboxs, data_dict):
143143
pred_bboxs = deform_bboxs(pred_cate_bboxs, data_dict)
144144
# 在原图绘制标注边界框和预测边界框
145145
dst = draw.plot_bboxs(data_dict['src'], data_dict['bndboxs'], data_dict['name_list'], pred_bboxs, pred_cates, pred_cate_probs)
146-
cv2.imwrite('./detect.png', dst)
146+
# cv2.imwrite('./detect.png', dst)
147147
draw.show(dst)

0 commit comments

Comments
 (0)