|
|
|
@ -56,7 +56,7 @@ def check_img_size(img_size, s=32):
|
|
|
|
|
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
|
|
|
# Check anchor fit to data, recompute if necessary
|
|
|
|
|
print('\nAnalyzing anchors... ', end='')
|
|
|
|
|
anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
|
|
|
|
|
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
|
|
|
|
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
|
|
|
|
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
|
|
|
|
|
|
|
|
|
@ -66,14 +66,17 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
|
|
|
best = x.max(1)[0] # best_x
|
|
|
|
|
return (best > 1. / thr).float().mean() # best possible recall
|
|
|
|
|
|
|
|
|
|
bpr = metric(anchors.clone().cpu().view(-1, 2))
|
|
|
|
|
bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
|
|
|
|
|
print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
|
|
|
|
|
if bpr < 0.99: # threshold to recompute
|
|
|
|
|
print('. Attempting to generate improved anchors, please wait...' % bpr)
|
|
|
|
|
new_anchors = kmean_anchors(dataset, n=anchors.numel() // 2, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
|
|
|
|
na = m.anchor_grid.numel() // 2 # number of anchors
|
|
|
|
|
new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
|
|
|
|
new_bpr = metric(new_anchors.reshape(-1, 2))
|
|
|
|
|
if new_bpr > bpr:
|
|
|
|
|
anchors[:] = torch.tensor(new_anchors).view_as(anchors).type_as(anchors)
|
|
|
|
|
if new_bpr > bpr: # replace anchors
|
|
|
|
|
new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
|
|
|
|
|
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
|
|
|
|
|
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
|
|
|
|
|
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
|
|
|
|
else:
|
|
|
|
|
print('Original anchors better than new anchors. Proceeding with original anchors.')
|
|
|
|
|