Skip to content

Commit fad8f85

Browse files
committed
test(yolo_v1): 前向计算时验证输入是否是4维
1 parent 4406158 commit fad8f85

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

py/lib/models/yolo_v1.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def __init__(self, S=7, B=2, C=20):
6767
self.C = C
6868

6969
def forward(self, x):
70+
"""
71+
:param x: [N, C, H, W]
72+
:return:
73+
"""
74+
assert len(x.shape) == 4
7075
N = x.shape[0]
7176

7277
x = self.features(x)
@@ -79,8 +84,8 @@ def forward(self, x):
7984

8085

8186
if __name__ == '__main__':
82-
data = torch.randn((1, 3, 448, 448))
83-
# data = torch.randn((1, 3, 224, 224))
87+
# data = torch.randn((1, 3, 448, 448))
88+
data = torch.randn((1, 3, 224, 224))
8489
model = YOLO_v1(7, 2, 3)
8590

8691
outputs = model(data)

0 commit comments

Comments
 (0)