From bb87276d8088e5ec51959c87da7edbe74bc42841 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 31 Jul 2020 14:34:13 -0700 Subject: [PATCH] update build_targets() (#589) Signed-off-by: Glenn Jocher --- utils/utils.py | 103 +++++++++++++++++++++++-------------------------- 1 file changed, 48 insertions(+), 55 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index e7e0ed3..30a45b2 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -308,7 +308,7 @@ def compute_ap(recall, precision): def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False): # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4 - box2 = box2.t() + box2 = box2.T # Get the coordinates of bounding boxes if x1y1x2y2: # x1, y1, x2, y2 = box1 @@ -347,7 +347,7 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False): v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) with torch.no_grad(): alpha = v / (1 - iou + v + 1e-16) - return iou - (rho2 / c2 + v * alpha ) # CIoU + return iou - (rho2 / c2 + v * alpha) # CIoU return iou @@ -369,8 +369,8 @@ def box_iou(box1, box2): # box = 4xn return (box[2] - box[0]) * (box[3] - box[1]) - area1 = box_area(box1.t()) - area2 = box_area(box2.t()) + area1 = box_area(box1.T) + area2 = box_area(box2.T) # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) @@ -439,70 +439,62 @@ class BCEBlurWithLogitsLoss(nn.Module): def compute_loss(p, targets, model): # predictions, targets, model device = targets.device - ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor - lcls, lbox, lobj = ft([0]).to(device), ft([0]).to(device), ft([0]).to(device) + lcls, lbox, lobj = torch.zeros(3, 1, device=device) tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets h = model.hyp # hyperparameters - red = 'mean' # Loss reduction (sum or mean) # Define criteria - BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red).to(device) - BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red).to(device) + BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device) + BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device) - # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 + # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 cp, cn = smooth_BCE(eps=0.0) - # focal loss + # Focal loss g = h['fl_gamma'] # focal loss gamma if g > 0: BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) - # per output + # Losses nt = 0 # number of targets np = len(p) # number of outputs balance = [4.0, 1.0, 0.4] if np == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6 for i, pi in enumerate(p): # layer index, layer predictions b, a, gj, gi = indices[i] # image, anchor, gridy, gridx - tobj = torch.zeros_like(pi[..., 0]).to(device) # target obj + tobj = torch.zeros_like(pi[..., 0], device=device) # target obj - nb = b.shape[0] # number of targets - if nb: - nt += nb # cumulative targets + n = b.shape[0] # number of targets + if n: + nt += n # cumulative targets ps = pi[b, a, gj, gi] # prediction subset corresponding to targets - # GIoU + # Regression pxy = ps[:, :2].sigmoid() * 2. - 0.5 pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box - giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target) - lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss + giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # giou(prediction, target) + lbox += (1.0 - giou).mean() # giou loss - # Obj + # Objectness tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio - # Class + # Classification if model.nc > 1: # cls loss (only if multiple classes) - t = torch.full_like(ps[:, 5:], cn).to(device) # targets - t[range(nb), tcls[i]] = cp - lcls += BCEcls(ps[:, 5:], t) # BCE + t = torch.full_like(ps[:, 5:], cn, device=device) # targets + t[range(n), tcls[i]] = cp + lcls = lcls + BCEcls(ps[:, 5:], t) # BCE # Append targets to text file # with open('targets.txt', 'a') as file: # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] - lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss + lobj = lobj + BCEobj(pi[..., 4], tobj) * balance[i] # obj loss s = 3 / np # output count scaling lbox *= h['giou'] * s lobj *= h['obj'] * s * (1.4 if np == 4 else 1.) lcls *= h['cls'] * s bs = tobj.shape[0] # batch size - if red == 'sum': - g = 3.0 # loss gain - lobj *= g / bs - if nt: - lcls *= g / nt / model.nc - lbox *= g / nt loss = lbox + lobj + lcls return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach() @@ -510,40 +502,40 @@ def compute_loss(p, targets, model): # predictions, targets, model def build_targets(p, targets, model): # Build targets for compute_loss(), input targets(image,class,x,y,w,h) - det = model.module.model[-1] if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) \ - else model.model[-1] # Detect() module + det = model.module.model[-1] if torch_utils.is_parallel(model) else model.model[-1] # Detect() module na, nt = det.na, targets.shape[0] # number of anchors, targets tcls, tbox, indices, anch = [], [], [], [] - gain = torch.ones(6, device=targets.device) # normalized to gridspace gain - off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets - at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt) + gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) + targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices + + g = 0.5 # bias + off = torch.tensor([[0, 0], + [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m + # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm + ], device=targets.device).float() * g # offsets - g = 0.5 # offset - style = 'rect4' for i in range(det.nl): anchors = det.anchors[i] - gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain + gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain # Match targets to anchors - a, t, offsets = [], targets * gain, 0 + t, offsets = targets * gain, 0 if nt: - r = t[None, :, 4:6] / anchors[:, None] # wh ratio + # Matches + r = t[:, :, 4:6] / anchors[:, None] # wh ratio j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare - # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2)) - a, t = at[j], t.repeat(na, 1, 1)[j] # filter + # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) + t = t[j] # filter - # overlaps + # Offsets gxy = t[:, 2:4] # grid xy - z = torch.zeros_like(gxy) - if style == 'rect2': - j, k = ((gxy % 1. < g) & (gxy > 1.)).T - a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0) - offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g - elif style == 'rect4': - j, k = ((gxy % 1. < g) & (gxy > 1.)).T - l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T - a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0) - offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g + gxi = gain[[2, 3]] - gxy # inverse + j, k = ((gxy % 1. < g) & (gxy > 1.)).T + l, m = ((gxi % 1. < g) & (gxi > 1.)).T + j = torch.stack((torch.ones_like(j), j, k, l, m)) + t = t.repeat((5, 1, 1))[j] + offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] # Define b, c = t[:, :2].long().T # image, class @@ -553,6 +545,7 @@ def build_targets(p, targets, model): gi, gj = gij.T # grid xy indices # Append + a = t[:, 6].long() # anchor indices indices.append((b, a, gj, gi)) # image, anchor, grid indices tbox.append(torch.cat((gxy - gij, gwh), 1)) # box anch.append(anchors[a]) # anchors @@ -599,7 +592,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, # Detections matrix nx6 (xyxy, conf, cls) if multi_label: - i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).t() + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) else: # best class only conf, j = x[:, 5:].max(1, keepdim=True)