From 8b26e890064a464c253af60480b7c16f3fc75d17 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 16 Jun 2020 00:53:34 -0700 Subject: [PATCH] AutoAnchor bug fix #72 --- train.py | 3 +-- utils/utils.py | 10 ++++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 6b3219a..1e2d55a 100644 --- a/train.py +++ b/train.py @@ -4,7 +4,6 @@ import torch.distributed as dist import torch.nn.functional as F import torch.optim as optim import torch.optim.lr_scheduler as lr_scheduler -import yaml from torch.utils.tensorboard import SummaryWriter import test # import test.py to get mAP after each epoch @@ -200,7 +199,7 @@ def train(hyp): tb_writer.add_histogram('classes', c, 0) # Check anchors - check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t']) + check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz) # Exponential moving average ema = torch_utils.ModelEMA(model) diff --git a/utils/utils.py b/utils/utils.py index 9e1c4b1..95d1198 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -52,15 +52,17 @@ def check_img_size(img_size, s=32): return make_divisible(img_size, s) # nearest gs-multiple -def check_best_possible_recall(dataset, anchors, thr): +def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640): # Check best possible recall of dataset with current anchors - wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(dataset.shapes, dataset.labels)])).float() # wh + shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) + wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio bpr = (m.min(1)[0] < thr).float().mean() # best possible recall mr = (m < thr).float().mean() # match ratio - print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall')) - print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr)) + print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall')) + print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr)) + assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \ 'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr