diff --git a/utils/utils.py b/utils/utils.py index 659c939..23223ed 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -494,7 +494,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c continue # Compute conf - x[..., 5:] *= x[..., 4:5] # conf = obj_conf * cls_conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf # Box (center x, center y, width, height) to (x1, y1, x2, y2) box = xywh2xyxy(x[:, :4]) @@ -502,10 +502,10 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c # Detections matrix nx6 (xyxy, conf, cls) if multi_label: i, j = (x[:, 5:] > conf_thres).nonzero().t() - x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) else: # best class only - conf, j = x[:, 5:].max(1) - x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf > conf_thres] + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] # Filter by class if classes: @@ -524,8 +524,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c # x = x[x[:, 4].argsort(descending=True)] # Batched NMS - c = x[:, 5] * 0 if agnostic else x[:, 5] # classes - boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) if i.shape[0] > max_det: # limit detections i = i[:max_det]