diff --git a/models/yolo.py b/models/yolo.py index 16638ed..47a35e2 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -82,18 +82,19 @@ class Model(nn.Module): def forward(self, x, augment=False, profile=False): if augment: img_size = x.shape[-2:] # height, width - s = [0.83, 0.67] # scales - y = [] - for i, xi in enumerate((x, - torch_utils.scale_img(x.flip(3), s[0]), # flip-lr and scale - torch_utils.scale_img(x, s[1]), # scale - )): - # cv2.imwrite('img%g.jpg' % i, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) - y.append(self.forward_once(xi)[0]) - - y[1][..., :4] /= s[0] # scale - y[1][..., 0] = img_size[1] - y[1][..., 0] # flip lr - y[2][..., :4] /= s[1] # scale + s = [1, 0.83, 0.67] # scales + f = [None, 3, None] # flips (2-ud, 3-lr) + y = [] # outputs + for si, fi in zip(s, f): + xi = torch_utils.scale_img(x.flip(fi) if fi else x, si) + yi = self.forward_once(xi)[0] # forward + # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save + yi[..., :4] /= si # de-scale + if fi is 2: + yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud + elif fi is 3: + yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr + y.append(yi) return torch.cat(y, 1), None # augmented inference, train else: return self.forward_once(x, profile) # single-scale inference, train diff --git a/utils/torch_utils.py b/utils/torch_utils.py index ce584f8..51ef984 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -164,13 +164,16 @@ def load_classifier(name='resnet101', n=2): def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio # scales img(bs,3,y,x) by ratio - h, w = img.shape[2:] - s = (int(h * ratio), int(w * ratio)) # new size - img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize - if not same_shape: # pad/crop img - gs = 32 # (pixels) grid size - h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)] - return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean + if ratio == 1.0: + return img + else: + h, w = img.shape[2:] + s = (int(h * ratio), int(w * ratio)) # new size + img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize + if not same_shape: # pad/crop img + gs = 32 # (pixels) grid size + h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)] + return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean def copy_attr(a, b, include=(), exclude=()):