update yolo.py TTA flexibility and extensibility (#506)

* update yolo.py TTA flexibility and extensibility

* Update scale_img()
pull/1/head
Glenn Jocher 5 years ago committed by GitHub
parent 4b5f4806bc
commit 1d17b9af0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -82,18 +82,19 @@ class Model(nn.Module):
def forward(self, x, augment=False, profile=False):
if augment:
img_size = x.shape[-2:] # height, width
s = [0.83, 0.67] # scales
y = []
for i, xi in enumerate((x,
torch_utils.scale_img(x.flip(3), s[0]), # flip-lr and scale
torch_utils.scale_img(x, s[1]), # scale
)):
# cv2.imwrite('img%g.jpg' % i, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])
y.append(self.forward_once(xi)[0])
y[1][..., :4] /= s[0] # scale
y[1][..., 0] = img_size[1] - y[1][..., 0] # flip lr
y[2][..., :4] /= s[1] # scale
s = [1, 0.83, 0.67] # scales
f = [None, 3, None] # flips (2-ud, 3-lr)
y = [] # outputs
for si, fi in zip(s, f):
xi = torch_utils.scale_img(x.flip(fi) if fi else x, si)
yi = self.forward_once(xi)[0] # forward
# cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
yi[..., :4] /= si # de-scale
if fi is 2:
yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
elif fi is 3:
yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train
else:
return self.forward_once(x, profile) # single-scale inference, train

@ -164,13 +164,16 @@ def load_classifier(name='resnet101', n=2):
def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
# scales img(bs,3,y,x) by ratio
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
if not same_shape: # pad/crop img
gs = 32 # (pixels) grid size
h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
if ratio == 1.0:
return img
else:
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
if not same_shape: # pad/crop img
gs = 32 # (pixels) grid size
h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
def copy_attr(a, b, include=(), exclude=()):

Loading…
Cancel
Save