pull/1/head
Glenn Jocher 5 years ago
parent 1e84a23f38
commit ce36905358

@ -256,7 +256,7 @@ if __name__ == '__main__':
opt.augment) opt.augment)
elif opt.task == 'study': # run over a range of settings and save/plot elif opt.task == 'study': # run over a range of settings and save/plot
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.p5', 'yolov5x.pt', 'yolov3-spp.pt']: for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
x = list(range(256, 1024, 32)) # x axis x = list(range(256, 1024, 32)) # x axis
y = [] # y axis y = [] # y axis

@ -108,30 +108,30 @@ def train(hyp):
google_utils.attempt_download(weights) google_utils.attempt_download(weights)
start_epoch, best_fitness = 0, 0.0 start_epoch, best_fitness = 0, 0.0
if weights.endswith('.pt'): # pytorch format if weights.endswith('.pt'): # pytorch format
chkpt = torch.load(weights, map_location=device) ckpt = torch.load(weights, map_location=device) # load checkpoint
# load model # load model
try: try:
chkpt['model'] = \ ckpt['model'] = \
{k: v for k, v in chkpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()} {k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
model.load_state_dict(chkpt['model'], strict=False) model.load_state_dict(ckpt['model'], strict=False)
except KeyError as e: except KeyError as e:
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \ s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
% (opt.weights, opt.cfg, opt.weights) % (opt.weights, opt.cfg, opt.weights)
raise KeyError(s) from e raise KeyError(s) from e
# load optimizer # load optimizer
if chkpt['optimizer'] is not None: if ckpt['optimizer'] is not None:
optimizer.load_state_dict(chkpt['optimizer']) optimizer.load_state_dict(ckpt['optimizer'])
best_fitness = chkpt['best_fitness'] best_fitness = ckpt['best_fitness']
# load results # load results
if chkpt.get('training_results') is not None: if ckpt.get('training_results') is not None:
with open(results_file, 'w') as file: with open(results_file, 'w') as file:
file.write(chkpt['training_results']) # write results.txt file.write(ckpt['training_results']) # write results.txt
start_epoch = chkpt['epoch'] + 1 start_epoch = ckpt['epoch'] + 1
del chkpt del ckpt
# Mixed precision training https://github.com/NVIDIA/apex # Mixed precision training https://github.com/NVIDIA/apex
if mixed_precision: if mixed_precision:
@ -324,17 +324,17 @@ def train(hyp):
save = (not opt.nosave) or (final_epoch and not opt.evolve) save = (not opt.nosave) or (final_epoch and not opt.evolve)
if save: if save:
with open(results_file, 'r') as f: # create checkpoint with open(results_file, 'r') as f: # create checkpoint
chkpt = {'epoch': epoch, ckpt = {'epoch': epoch,
'best_fitness': best_fitness, 'best_fitness': best_fitness,
'training_results': f.read(), 'training_results': f.read(),
'model': ema.ema.module if hasattr(model, 'module') else ema.ema, 'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict()} 'optimizer': None if final_epoch else optimizer.state_dict()}
# Save last, best and delete # Save last, best and delete
torch.save(chkpt, last) torch.save(ckpt, last)
if (best_fitness == fi) and not final_epoch: if (best_fitness == fi) and not final_epoch:
torch.save(chkpt, best) torch.save(ckpt, best)
del chkpt del ckpt
# end epoch ---------------------------------------------------------------------------------------------------- # end epoch ----------------------------------------------------------------------------------------------------
# end training # end training

Loading…
Cancel
Save