improved model.yaml source tracking

pull/1/head
Glenn Jocher 5 years ago
parent c80b249e67
commit 02445d176d

@ -128,7 +128,7 @@ def detect(save_img=False):
if save_txt or save_img: if save_txt or save_img:
print('Results saved to %s' % os.getcwd() + os.sep + out) print('Results saved to %s' % os.getcwd() + os.sep + out)
if platform == 'darwin': # MacOS if platform == 'darwin' and not opt.update: # MacOS
os.system('open ' + save_path) os.system('open ' + save_path)
print('Done. (%.3fs)' % (time.time() - t0)) print('Done. (%.3fs)' % (time.time() - t0))

@ -1,4 +1,5 @@
import argparse import argparse
from copy import deepcopy
from models.experimental import * from models.experimental import *
@ -43,20 +44,21 @@ class Detect(nn.Module):
class Model(nn.Module): class Model(nn.Module):
def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
super(Model, self).__init__() super(Model, self).__init__()
if type(model_cfg) is dict: if isinstance(cfg, dict):
self.md = model_cfg # model dict self.yaml = cfg # model dict
else: # is *.yaml else: # is *.yaml
import yaml # for torch hub import yaml # for torch hub
with open(model_cfg) as f: self.yaml_file = Path(cfg).name
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict with open(cfg) as f:
self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
# Define model # Define model
if nc and nc != self.md['nc']: if nc and nc != self.yaml['nc']:
print('Overriding %s nc=%g with nc=%g' % (model_cfg, self.md['nc'], nc)) print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
self.md['nc'] = nc # override yaml value self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
# Build strides, anchors # Build strides, anchors
@ -148,17 +150,21 @@ class Model(nn.Module):
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
m.bn = None # remove batchnorm m.bn = None # remove batchnorm
m.forward = m.fuseforward # update forward m.forward = m.fuseforward # update forward
torch_utils.model_info(self) self.info()
return self return self
def parse_model(md, ch): # model_dict, input_channels(3) def info(self): # print model information
torch_utils.model_info(self)
def parse_model(d, ch): # model_dict, input_channels(3)
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
anchors, nc, gd, gw = md['anchors'], md['nc'], md['depth_multiple'], md['width_multiple'] anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
na = (len(anchors[0]) // 2) # number of anchors na = (len(anchors[0]) // 2) # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5) no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(md['backbone'] + md['head']): # from, number, module, args for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m = eval(m) if isinstance(m, str) else m # eval strings m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args): for j, a in enumerate(args):
try: try:

Loading…
Cancel
Save