From 9fdb0fbacf65ad6d8778da592ef9b87038a3f4c7 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 17 Jun 2020 19:51:15 -0700
Subject: [PATCH] AutoAnchor bug fix # 117

---
 train.py       |  4 +++-
 utils/utils.py | 13 ++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/train.py b/train.py
index 7df99c0..5163bb1 100644
--- a/train.py
+++ b/train.py
@@ -200,7 +200,8 @@ def train(hyp):
         tb_writer.add_histogram('classes', c, 0)
 
     # Check anchors
-    check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
+    if not opt.noautoanchor:
+        check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
     # Exponential moving average
     ema = torch_utils.ModelEMA(model)
@@ -374,6 +375,7 @@ if __name__ == '__main__':
     parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--notest', action='store_true', help='only test final epoch')
+    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
     parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
     parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
diff --git a/utils/utils.py b/utils/utils.py
index 02aa6b6..47f5219 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -56,7 +56,7 @@ def check_img_size(img_size, s=32):
 def check_anchors(dataset, model, thr=4.0, imgsz=640):
     # Check anchor fit to data, recompute if necessary
     print('\nAnalyzing anchors... ', end='')
-    anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
+    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
     shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
     wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float()  # wh
 
@@ -66,14 +66,17 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
         best = x.max(1)[0]  # best_x
         return (best > 1. / thr).float().mean()  #  best possible recall
 
-    bpr = metric(anchors.clone().cpu().view(-1, 2))
+    bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
     print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
     if bpr < 0.99:  # threshold to recompute
         print('. Attempting to generate improved anchors, please wait...' % bpr)
-        new_anchors = kmean_anchors(dataset, n=anchors.numel() // 2, img_size=imgsz, thr=thr, gen=1000, verbose=False)
+        na = m.anchor_grid.numel() // 2  # number of anchors
+        new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
         new_bpr = metric(new_anchors.reshape(-1, 2))
-        if new_bpr > bpr:
-            anchors[:] = torch.tensor(new_anchors).view_as(anchors).type_as(anchors)
+        if new_bpr > bpr:  # replace anchors
+            new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
+            m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid)  # for inference
+            m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
             print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
         else:
             print('Original anchors better than new anchors. Proceeding with original anchors.')