From 93684531c6e71547667ee19df6ddb94af3c8c80d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 6 Aug 2020 22:26:38 -0700 Subject: [PATCH] train.py --logdir argparser addition (#660) * train.py --logdir argparser addition * train.py --logdir argparser addition --- train.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index 3074e51..3dacb42 100644 --- a/train.py +++ b/train.py @@ -55,20 +55,20 @@ hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3) def train(hyp, opt, device, tb_writer=None): print(f'Hyperparameters {hyp}') - log_dir = tb_writer.log_dir if tb_writer else 'runs/evolve' # run directory - wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory + log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory + wdir = str(log_dir / 'weights') + os.sep # weights directory os.makedirs(wdir, exist_ok=True) last = wdir + 'last.pt' best = wdir + 'best.pt' - results_file = log_dir + os.sep + 'results.txt' + results_file = str(log_dir / 'results.txt') epochs, batch_size, total_batch_size, weights, rank = \ opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank # TODO: Use DDP logging. Only the first process is allowed to log. # Save run settings - with open(Path(log_dir) / 'hyp.yaml', 'w') as f: + with open(log_dir / 'hyp.yaml', 'w') as f: yaml.dump(hyp, f, sort_keys=False) - with open(Path(log_dir) / 'opt.yaml', 'w') as f: + with open(log_dir / 'opt.yaml', 'w') as f: yaml.dump(vars(opt), f, sort_keys=False) # Configure @@ -325,7 +325,7 @@ def train(hyp, opt, device, tb_writer=None): # Plot if ni < 3: - f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename + f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename result = plot_images(images=imgs, targets=targets, paths=paths, fname=f) if tb_writer and result is not None: tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) @@ -433,7 +433,8 @@ if __name__ == '__main__': parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') - parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') + parser.add_argument('--local-rank', type=int, default=-1, help='DDP parameter, do not modify') + parser.add_argument('--logdir', type=str, default='runs/', help='logging directory') opt = parser.parse_args() # Resume @@ -472,8 +473,8 @@ if __name__ == '__main__': if not opt.evolve: tb_writer = None if opt.global_rank in [-1, 0]: - print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/') - tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name)) + print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir) + tb_writer = SummaryWriter(log_dir=increment_dir(Path(opt.logdir) / 'exp', opt.name)) # runs/exp train(hyp, opt, device, tb_writer)