remove fast, add merge

pull/1/head
Glenn Jocher 5 years ago
parent 24dd150fbd
commit 1f1917ef56

@ -19,7 +19,7 @@ def test(data,
verbose=False, verbose=False,
model=None, model=None,
dataloader=None, dataloader=None,
fast=False): merge=False):
# Initialize/load model and set device # Initialize/load model and set device
if model is None: if model is None:
training = False training = False
@ -65,7 +65,7 @@ def test(data,
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
fast |= conf_thres > 0.001 # enable fast mode merge = opt.merge # use Merge NMS
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
dataset = LoadImagesAndLabels(path, dataset = LoadImagesAndLabels(path,
imgsz, imgsz,
@ -109,7 +109,7 @@ def test(data,
# Run NMS # Run NMS
t = torch_utils.time_synchronized() t = torch_utils.time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, fast=fast) output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
t1 += torch_utils.time_synchronized() - t t1 += torch_utils.time_synchronized() - t
# Statistics per image # Statistics per image
@ -254,6 +254,7 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--merge', action='store_true', help='use Merge NMS')
parser.add_argument('--verbose', action='store_true', help='report mAP by class') parser.add_argument('--verbose', action='store_true', help='report mAP by class')
opt = parser.parse_args() opt = parser.parse_args()
opt.img_size = check_img_size(opt.img_size) opt.img_size = check_img_size(opt.img_size)

@ -305,8 +305,7 @@ def train(hyp):
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'), save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema, model=ema.ema,
single_cls=opt.single_cls, single_cls=opt.single_cls,
dataloader=testloader, dataloader=testloader)
fast=epoch < epochs / 2)
# Write # Write
with open(results_file, 'a') as f: with open(results_file, 'a') as f:

@ -527,7 +527,7 @@ def build_targets(p, targets, model):
return tcls, tbox, indices, anch return tcls, tbox, indices, anch
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False): def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
"""Performs Non-Maximum Suppression (NMS) on inference results """Performs Non-Maximum Suppression (NMS) on inference results
Returns: Returns:
@ -544,12 +544,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
max_det = 300 # maximum number of detections per image max_det = 300 # maximum number of detections per image
time_limit = 10.0 # seconds to quit after time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections redundant = True # require redundant detections
fast |= conf_thres > 0.001 # fast mode
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
if fast:
merge = False
else:
merge = True # merge for best mAP (adds 0.5ms/img)
t = time.time() t = time.time()
output = [None] * prediction.shape[0] output = [None] * prediction.shape[0]

Loading…
Cancel
Save