diff --git a/test.py b/test.py index f3d0ec6..18f8e58 100644 --- a/test.py +++ b/test.py @@ -18,8 +18,7 @@ def test(data, verbose=False, model=None, dataloader=None, - fast=False, - save_dir='.', + save_dir='', merge=False): # Initialize/load model and set device @@ -29,7 +28,7 @@ def test(data, device = torch_utils.select_device(opt.device, batch_size=batch_size) # Remove previous - for f in glob.glob(f'{save_dir}/test_batch*.jpg'): + for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')): os.remove(f) # Load model @@ -163,10 +162,11 @@ def test(data, # Plot images if batch_i < 1: - f = os.path.join(save_dir, 'test_batch%g_gt.jpg' % batch_i) # filename - plot_images(img, targets, paths, f, names) # ground truth - f = os.path.join(save_dir,'test_batch%g_pred.jpg' % batch_i) - plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions + + f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i) # filename + plot_images(img, targets, paths, str(f), names) # ground truth + f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i) + plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy