From 12b0c046d534b18ea586bb0d273d868cf16002f8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 7 Jun 2020 13:42:33 -0700 Subject: [PATCH] model fusion and onnx export --- models/common.py | 3 +++ models/onnx_export.py | 15 ++++++++------- models/yolo.py | 9 +++++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/models/common.py b/models/common.py index b92eee3..5f5c502 100644 --- a/models/common.py +++ b/models/common.py @@ -20,6 +20,9 @@ class Conv(nn.Module): # standard convolution def forward(self, x): return self.act(self.bn(self.conv(x))) + def fuseforward(self, x): + return self.act(self.conv(x)) + class Bottleneck(nn.Module): def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion diff --git a/models/onnx_export.py b/models/onnx_export.py index f1bd18c..4591059 100644 --- a/models/onnx_export.py +++ b/models/onnx_export.py @@ -1,6 +1,6 @@ -# Exports a pytorch *.pt model to *.onnx format. Example usage (run from ./yolov5 directory): -# $ export PYTHONPATH="$PWD" -# $ python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1 +# Exports a pytorch *.pt model to *.onnx format +# Example usage (run from ./yolov5 directory): +# $ export PYTHONPATH="$PWD" && python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1 import argparse @@ -10,10 +10,11 @@ from models.common import * if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--weights', default='./weights/yolov5s.pt', help='weights path') - parser.add_argument('--img-size', default=640, help='inference size (pixels)') - parser.add_argument('--batch-size', default=1, help='batch size') + parser.add_argument('--weights', type=str, default='./weights/yolov5s.pt', help='weights path') + parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') + parser.add_argument('--batch-size', type=int, default=1, help='batch size') opt = parser.parse_args() + print(opt) # Parameters f = opt.weights.replace('.pt', '.onnx') # onnx filename @@ -23,7 +24,7 @@ if __name__ == '__main__': google_utils.attempt_download(opt.weights) model = torch.load(opt.weights)['model'] model.eval() - # model.fuse() # optionally fuse Conv2d + BatchNorm2d layers TODO + # model.fuse() # Export to onnx model.model[-1].export = True # set Detect() layer export=True diff --git a/models/yolo.py b/models/yolo.py index f731290..27960de 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -123,6 +123,15 @@ class Model(nn.Module): b = self.model[f].bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) print(('%g Conv2d.bias:' + '%10.3g' * 6) % (f, *b[:5].mean(1).tolist(), b[5:].mean())) + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + print('Fusing layers...') + for m in self.model.modules(): + if type(m) is Conv: + m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv + m.bn = None # remove batchnorm + m.forward = m.fuseforward # update forward + torch_utils.model_info(self) + def parse_model(md, ch): # model_dict, input_channels(3) print('\n%3s%15s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))