diff --git a/utils/utils.py b/utils/utils.py index 8a4aae2..e7e0ed3 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -599,7 +599,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, # Detections matrix nx6 (xyxy, conf, cls) if multi_label: - i, j = (x[:, 5:] > conf_thres).nonzero().t() + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).t() x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) else: # best class only conf, j = x[:, 5:].max(1, keepdim=True)