From ce369053581f4ce7855ea2300170614000682d8e Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 30 May 2020 00:12:45 -0700 Subject: [PATCH] updates --- test.py | 2 +- train.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/test.py b/test.py index 3b19447..d04b08b 100644 --- a/test.py +++ b/test.py @@ -256,7 +256,7 @@ if __name__ == '__main__': opt.augment) elif opt.task == 'study': # run over a range of settings and save/plot - for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.p5', 'yolov5x.pt', 'yolov3-spp.pt']: + for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.pt', 'yolov5x.pt', 'yolov3-spp.pt']: f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to x = list(range(256, 1024, 32)) # x axis y = [] # y axis diff --git a/train.py b/train.py index 5dd5bf5..c6c1f7f 100644 --- a/train.py +++ b/train.py @@ -108,30 +108,30 @@ def train(hyp): google_utils.attempt_download(weights) start_epoch, best_fitness = 0, 0.0 if weights.endswith('.pt'): # pytorch format - chkpt = torch.load(weights, map_location=device) + ckpt = torch.load(weights, map_location=device) # load checkpoint # load model try: - chkpt['model'] = \ - {k: v for k, v in chkpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()} - model.load_state_dict(chkpt['model'], strict=False) + ckpt['model'] = \ + {k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()} + model.load_state_dict(ckpt['model'], strict=False) except KeyError as e: s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \ % (opt.weights, opt.cfg, opt.weights) raise KeyError(s) from e # load optimizer - if chkpt['optimizer'] is not None: - optimizer.load_state_dict(chkpt['optimizer']) - best_fitness = chkpt['best_fitness'] + if ckpt['optimizer'] is not None: + optimizer.load_state_dict(ckpt['optimizer']) + best_fitness = ckpt['best_fitness'] # load results - if chkpt.get('training_results') is not None: + if ckpt.get('training_results') is not None: with open(results_file, 'w') as file: - file.write(chkpt['training_results']) # write results.txt + file.write(ckpt['training_results']) # write results.txt - start_epoch = chkpt['epoch'] + 1 - del chkpt + start_epoch = ckpt['epoch'] + 1 + del ckpt # Mixed precision training https://github.com/NVIDIA/apex if mixed_precision: @@ -324,17 +324,17 @@ def train(hyp): save = (not opt.nosave) or (final_epoch and not opt.evolve) if save: with open(results_file, 'r') as f: # create checkpoint - chkpt = {'epoch': epoch, + ckpt = {'epoch': epoch, 'best_fitness': best_fitness, 'training_results': f.read(), 'model': ema.ema.module if hasattr(model, 'module') else ema.ema, 'optimizer': None if final_epoch else optimizer.state_dict()} # Save last, best and delete - torch.save(chkpt, last) + torch.save(ckpt, last) if (best_fitness == fi) and not final_epoch: - torch.save(chkpt, best) - del chkpt + torch.save(ckpt, best) + del ckpt # end epoch ---------------------------------------------------------------------------------------------------- # end training