AutoAnchor bug fix # 117

pull/1/head
Glenn Jocher 5 years ago
parent 8b6f5826bc
commit 9fdb0fbacf

@ -200,6 +200,7 @@ def train(hyp):
tb_writer.add_histogram('classes', c, 0) tb_writer.add_histogram('classes', c, 0)
# Check anchors # Check anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
# Exponential moving average # Exponential moving average
@ -374,6 +375,7 @@ if __name__ == '__main__':
parser.add_argument('--resume', action='store_true', help='resume training from last.pt') parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch') parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')

@ -56,7 +56,7 @@ def check_img_size(img_size, s=32):
def check_anchors(dataset, model, thr=4.0, imgsz=640): def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary # Check anchor fit to data, recompute if necessary
print('\nAnalyzing anchors... ', end='') print('\nAnalyzing anchors... ', end='')
anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
@ -66,14 +66,17 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
best = x.max(1)[0] # best_x best = x.max(1)[0] # best_x
return (best > 1. / thr).float().mean() #  best possible recall return (best > 1. / thr).float().mean() #  best possible recall
bpr = metric(anchors.clone().cpu().view(-1, 2)) bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
print('Best Possible Recall (BPR) = %.4f' % bpr, end='') print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
if bpr < 0.99: # threshold to recompute if bpr < 0.99: # threshold to recompute
print('. Attempting to generate improved anchors, please wait...' % bpr) print('. Attempting to generate improved anchors, please wait...' % bpr)
new_anchors = kmean_anchors(dataset, n=anchors.numel() // 2, img_size=imgsz, thr=thr, gen=1000, verbose=False) na = m.anchor_grid.numel() // 2 # number of anchors
new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
new_bpr = metric(new_anchors.reshape(-1, 2)) new_bpr = metric(new_anchors.reshape(-1, 2))
if new_bpr > bpr: if new_bpr > bpr: # replace anchors
anchors[:] = torch.tensor(new_anchors).view_as(anchors).type_as(anchors) new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.') print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
else: else:
print('Original anchors better than new anchors. Proceeding with original anchors.') print('Original anchors better than new anchors. Proceeding with original anchors.')

Loading…
Cancel
Save