From 02445d176d3a8756fd5f2800b66137beeac9e573 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 13 Jul 2020 14:35:47 -0700 Subject: [PATCH] improved model.yaml source tracking --- detect.py | 2 +- models/yolo.py | 32 +++++++++++++++++++------------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/detect.py b/detect.py index 5c9577b..bfe4604 100644 --- a/detect.py +++ b/detect.py @@ -128,7 +128,7 @@ def detect(save_img=False): if save_txt or save_img: print('Results saved to %s' % os.getcwd() + os.sep + out) - if platform == 'darwin': # MacOS + if platform == 'darwin' and not opt.update: # MacOS os.system('open ' + save_path) print('Done. (%.3fs)' % (time.time() - t0)) diff --git a/models/yolo.py b/models/yolo.py index 3fd87a3..3a34a28 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -1,4 +1,5 @@ import argparse +from copy import deepcopy from models.experimental import * @@ -43,20 +44,21 @@ class Detect(nn.Module): class Model(nn.Module): - def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes + def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes super(Model, self).__init__() - if type(model_cfg) is dict: - self.md = model_cfg # model dict + if isinstance(cfg, dict): + self.yaml = cfg # model dict else: # is *.yaml import yaml # for torch hub - with open(model_cfg) as f: - self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict + self.yaml_file = Path(cfg).name + with open(cfg) as f: + self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict # Define model - if nc and nc != self.md['nc']: - print('Overriding %s nc=%g with nc=%g' % (model_cfg, self.md['nc'], nc)) - self.md['nc'] = nc # override yaml value - self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out + if nc and nc != self.yaml['nc']: + print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc)) + self.yaml['nc'] = nc # override yaml value + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) # Build strides, anchors @@ -148,17 +150,21 @@ class Model(nn.Module): m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv m.bn = None # remove batchnorm m.forward = m.fuseforward # update forward - torch_utils.model_info(self) + self.info() return self -def parse_model(md, ch): # model_dict, input_channels(3) + def info(self): # print model information + torch_utils.model_info(self) + + +def parse_model(d, ch): # model_dict, input_channels(3) print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) - anchors, nc, gd, gw = md['anchors'], md['nc'], md['depth_multiple'], md['width_multiple'] + anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] na = (len(anchors[0]) // 2) # number of anchors no = na * (nc + 5) # number of outputs = anchors * (classes + 5) layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out - for i, (f, n, m, args) in enumerate(md['backbone'] + md['head']): # from, number, module, args + for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args m = eval(m) if isinstance(m, str) else m # eval strings for j, a in enumerate(args): try: