diff --git a/utils/utils.py b/utils/utils.py index 7d8d7a9..9d412bc 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -84,15 +84,17 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): r = wh[:, None] / k[None] x = torch.min(r, 1. / r).min(2)[0] # ratio metric best = x.max(1)[0] # best_x - return (best > 1. / thr).float().mean() # best possible recall + aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold + bpr = (best > 1. / thr).float().mean() # best possible recall + return bpr, aat - bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2)) - print('Best Possible Recall (BPR) = %.4f' % bpr, end='') - if bpr < 0.99: # threshold to recompute + bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2)) + print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='') + if bpr < 0.98: # threshold to recompute print('. Attempting to generate improved anchors, please wait...' % bpr) na = m.anchor_grid.numel() // 2 # number of anchors new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) - new_bpr = metric(new_anchors.reshape(-1, 2)) + new_bpr = metric(new_anchors.reshape(-1, 2))[0] if new_bpr > bpr: # replace anchors new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors) m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference