From 25e51bcec723eb0ff094824a0f89ac726a5ee701 Mon Sep 17 00:00:00 2001
From: Alex Stoken
Date: Tue, 16 Jun 2020 15:50:27 -0500
Subject: [PATCH] add util function to get most recent last.pt file

added logic in train.py __main__ to handle resuming from a run
---
 train.py       | 15 ++++++++++++---
 utils/utils.py |  9 +++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index df5e1ed..25cf3d4 100644
--- a/train.py
+++ b/train.py
@@ -198,10 +198,10 @@ def train(hyp):
     model.names = data_dict['names']
 
     #save hyperparamter and training options in run folder
-    with open(os.path.join(log_dir, 'hyp.yaml', 'w')) as f:
+    with open(os.path.join(log_dir, 'hyp.yaml'), 'w') as f:
         yaml.dump(hyp, f)
 
-    with open(os.path.join(log_dir, 'opt.yaml', 'w')) as f:
+    with open(os.path.join(log_dir, 'opt.yaml'), 'w') as f:
         yaml.dump(opt, f)
 
     # Class frequency
@@ -294,7 +294,7 @@ def train(hyp):
 
             # Plot
             if ni < 3:
-                f = 'train_batch%g.jpg' % i  # filename
+                f = os.path.join(log_dir, 'train_batch%g.jpg' % i)  # filename
                 res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                 if tb_writer:
                     tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
@@ -385,6 +385,7 @@ if __name__ == '__main__':
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
+    parser.add_argument('--resume_from_run', type=str, default='', help='resume training from last.pt in this dir')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--notest', action='store_true', help='only test final epoch')
     parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
@@ -398,6 +399,14 @@ if __name__ == '__main__':
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
     parser.add_argument('--hyp', type=str, default='', help ='path to hyp yaml file')
     opt = parser.parse_args()
+
+    if opt.resume and not opt.resume_from_run:
+        last = get_latest_run()
+        print(f'WARNING: No run provided to resume from. Resuming from most recent run found at {last}')
+    else:
+        last = opt.resume_from_run
+        if os.path.isdir(last):  # a run directory was given: resume from the last.pt inside it
+            last = os.path.join(last, 'last.pt')
     opt.weights = last if opt.resume else opt.weights
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
diff --git a/utils/utils.py b/utils/utils.py
index 8ac73e3..56fb66b 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -36,6 +36,15 @@ def init_seeds(seed=0):
     np.random.seed(seed)
     torch_utils.init_seeds(seed=seed)
 
+
+def get_latest_run(search_dir='./runs/'):
+    # Return the path of the most recently modified 'last.pt' found in search_dir/*/
+    # (assumed to be the checkpoint to --resume from); raises FileNotFoundError if there is none
+    last_list = glob.glob(os.path.join(search_dir, '*', 'last.pt'))
+    if not last_list:
+        raise FileNotFoundError(f"no 'last.pt' found in any run directory under {search_dir}")
+    return max(last_list, key=os.path.getmtime)
+
 
 def check_git_status():
     # Suggest 'git pull' if repo is out of date