From 305c6a028ad2e7195ae22d881472c1e8652b57cb Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 2 Aug 2020 10:30:43 -0700 Subject: [PATCH] compute_loss() leaf variable update --- utils/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index b71ed35..a8be00d 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -439,7 +439,7 @@ class BCEBlurWithLogitsLoss(nn.Module): def compute_loss(p, targets, model): # predictions, targets, model device = targets.device - lcls, lbox, lobj = torch.zeros(3, 1, device=device) + lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device) tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets h = model.hyp # hyperparameters @@ -482,13 +482,13 @@ def compute_loss(p, targets, model): # predictions, targets, model if model.nc > 1: # cls loss (only if multiple classes) t = torch.full_like(ps[:, 5:], cn, device=device) # targets t[range(n), tcls[i]] = cp - lcls = lcls + BCEcls(ps[:, 5:], t) # BCE + lcls += BCEcls(ps[:, 5:], t) # BCE # Append targets to text file # with open('targets.txt', 'a') as file: # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] - lobj = lobj + BCEobj(pi[..., 4], tobj) * balance[i] # obj loss + lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss s = 3 / np # output count scaling lbox *= h['giou'] * s