From 04bdbe4104728dac15937ad06dbb9071ae3bebf9 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 5 Jul 2020 23:16:50 -0700 Subject: [PATCH] fuse update --- detect.py | 9 +++------ models/yolo.py | 4 ++-- test.py | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/detect.py b/detect.py index 268b5df..44cd64e 100644 --- a/detect.py +++ b/detect.py @@ -21,13 +21,10 @@ def detect(save_img=False): # Load model google_utils.attempt_download(weights) - model = torch.load(weights, map_location=device)['model'].float() # load to FP32 - # torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning - # model.fuse() - model.to(device).eval() - imgsz = check_img_size(imgsz, s=model.model[-1].stride.max()) # check img_size + model = torch.load(weights, map_location=device)['model'].float().eval() # load FP32 model + imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size if half: - model.half() # to FP16 + model.float() # to FP16 # Second-stage classifier classify = False diff --git a/models/yolo.py b/models/yolo.py index 9617f5b..3fd87a3 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -142,14 +142,14 @@ class Model(nn.Module): # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers - print('Fusing layers...') + print('Fusing layers... ', end='') for m in self.model.modules(): if type(m) is Conv: m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv m.bn = None # remove batchnorm m.forward = m.fuseforward # update forward torch_utils.model_info(self) - + return self def parse_model(md, ch): # model_dict, input_channels(3) print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) diff --git a/test.py b/test.py index 259d444..644f6b9 100644 --- a/test.py +++ b/test.py @@ -22,6 +22,7 @@ def test(data, # Initialize/load model and set device if model is None: training = False + merge = opt.merge # use Merge NMS device = torch_utils.select_device(opt.device, batch_size=batch_size) # Remove previous @@ -59,7 +60,6 @@ def test(data, # Dataloader if dataloader is None: # not training - merge = opt.merge # use Merge NMS img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images