From ce0c58f6781609ad11d99481999c9b9aa87cb91b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 3 Jul 2020 18:56:07 -0700 Subject: [PATCH] update compute_loss() --- utils/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index d9ffaac..85253a6 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -437,7 +437,8 @@ def compute_loss(p, targets, model): # predictions, targets, model BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) # per output - nt = 0 # targets + nt = 0 # number of targets + np = len(p) # number of outputs balance = [1.0, 1.0, 1.0] for i, pi in enumerate(p): # layer index, layer predictions b, a, gj, gi = indices[i] # image, anchor, gridy, gridx @@ -470,7 +471,7 @@ def compute_loss(p, targets, model): # predictions, targets, model lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss - s = 3 / (i + 1) # output count scaling + s = 3 / np # output count scaling lbox *= h['giou'] * s lobj *= h['obj'] * s lcls *= h['cls'] * s @@ -517,7 +518,6 @@ def build_targets(p, targets, model): j, k = ((gxy % 1. < g) & (gxy > 1.)).T a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0) offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g - elif style == 'rect4': j, k = ((gxy % 1. < g) & (gxy > 1.)).T l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T