diff --git a/models/yolo.py b/models/yolo.py index 9179c85..e491e64 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -61,8 +61,9 @@ class Model(nn.Module): # Build strides, anchors m = self.model[-1] # Detect() - m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 64, 64))]) # forward + m.stride = torch.tensor([128 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 128, 128))]) # forward m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) self.stride = m.stride # Init weights, biases diff --git a/utils/utils.py b/utils/utils.py index 8457643..f1f5db5 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -58,7 +58,8 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): print('\nAnalyzing anchors... ', end='') m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) - wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh + scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale + wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh def metric(k): # compute metric r = wh[:, None] / k[None] @@ -77,12 +78,23 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors) m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss + check_anchor_order(m) print('New anchors saved to model. Update model *.yaml to use these anchors in the future.') else: print('Original anchors better than new anchors. Proceeding with original anchors.') print('') # newline +def check_anchor_order(m): + # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + a = m.anchor_grid.prod(-1).view(-1) # anchor area + da = a[-1] - a[0] # delta a + ds = m.stride[-1] - m.stride[0] # delta s + if da.sign() != ds.sign(): # same order + m.anchors[:] = m.anchors.flip(0) + m.anchor_grid[:] = m.anchor_grid.flip(0) + + def check_file(file): # Searches for file if not found locally if os.path.isfile(file):