AutoAnchor bug fix #72

Glenn Jocher 5 years ago
parent 8fa3724072
commit 8b26e89006

@@ -4,7 +4,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
-import yaml
 from torch.utils.tensorboard import SummaryWriter

 import test  # import test.py to get mAP after each epoch
@@ -200,7 +199,7 @@ def train(hyp):
         tb_writer.add_histogram('classes', c, 0)

     # Check anchors
-    check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'])
+    check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz)

     # Exponential moving average
     ema = torch_utils.ModelEMA(model)

@@ -52,15 +52,17 @@ def check_img_size(img_size, s=32):
     return make_divisible(img_size, s)  # nearest gs-multiple


-def check_best_possible_recall(dataset, anchors, thr):
+def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640):
     # Check best possible recall of dataset with current anchors
-    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(dataset.shapes, dataset.labels)])).float()  # wh
+    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
+    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float()  # wh
     ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None]  # ratio
     m = torch.max(ratio, 1. / ratio).max(2)[0]  # max ratio
     bpr = (m.min(1)[0] < thr).float().mean()  # best possible recall
     mr = (m < thr).float().mean()  # match ratio
-    print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
+    print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
     print(('                  ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
     assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
                       'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
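
For context, the fix scales label width-height to the training image size before comparing against the anchors, so best possible recall (BPR) is measured at the resolution the model actually trains on rather than at the raw image shapes. Below is a minimal, self-contained sketch of that computation on synthetic data; the anchors, image shapes, and labels are made-up illustrative values, not the repository's defaults.

import numpy as np
import torch

imgsz, thr = 640, 4.0  # training image size and anchor-multiple threshold (hyp['anchor_t'])

# Synthetic inputs (illustrative only): anchor (w, h) pairs in pixels,
# original image (w, h) shapes, and normalized (cls, x, y, w, h) labels.
anchors = torch.tensor([[10., 13.], [30., 61.], [156., 198.]])
shapes = np.array([[1280., 720.], [640., 480.]])
labels = [np.array([[0, 0.5, 0.5, 0.10, 0.20]]),
          np.array([[1, 0.3, 0.4, 0.05, 0.05],
                    [2, 0.7, 0.6, 0.40, 0.30]])]

# The corrected step: rescale each image so its long side equals imgsz,
# then express label wh in those training-resolution pixels.
shapes = imgsz * shapes / shapes.max(1, keepdims=True)
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, labels)])).float()

ratio = wh[:, None] / anchors.view(-1, 2)[None]  # label-to-anchor wh ratios
m = torch.max(ratio, 1. / ratio).max(2)[0]       # worst-case ratio per label-anchor pair
bpr = (m.min(1)[0] < thr).float().mean()         # best possible recall
mr = (m < thr).float().mean()                    # anchor match ratio
print('BPR %.3f, match ratio %.3f' % (bpr, mr))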
