diff --git a/utils/utils.py b/utils/utils.py index 30a45b2..184e8dc 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -493,7 +493,8 @@ def compute_loss(p, targets, model): # predictions, targets, model s = 3 / np # output count scaling lbox *= h['giou'] * s lobj *= h['obj'] * s * (1.4 if np == 4 else 1.) - lcls *= h['cls'] * s + if model.nc > 1: + lcls *= h['cls'] * s bs = tobj.shape[0] # batch size loss = lbox + lobj + lcls