train.py --logdir argparser addition (#660)

* train.py --logdir argparser addition

* train.py --logdir argparser addition
pull/1/head
Glenn Jocher 5 years ago committed by GitHub
parent 886b9841c8
commit 93684531c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -55,20 +55,20 @@ hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
def train(hyp, opt, device, tb_writer=None):
print(f'Hyperparameters {hyp}')
log_dir = tb_writer.log_dir if tb_writer else 'runs/evolve' # run directory
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
wdir = str(log_dir / 'weights') + os.sep # weights directory
os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = log_dir + os.sep + 'results.txt'
results_file = str(log_dir / 'results.txt')
epochs, batch_size, total_batch_size, weights, rank = \
opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
# TODO: Use DDP logging. Only the first process is allowed to log.
# Save run settings
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
with open(log_dir / 'hyp.yaml', 'w') as f:
yaml.dump(hyp, f, sort_keys=False)
with open(Path(log_dir) / 'opt.yaml', 'w') as f:
with open(log_dir / 'opt.yaml', 'w') as f:
yaml.dump(vars(opt), f, sort_keys=False)
# Configure
@ -325,7 +325,7 @@ def train(hyp, opt, device, tb_writer=None):
# Plot
if ni < 3:
f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename
f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer and result is not None:
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
@ -433,7 +433,8 @@ if __name__ == '__main__':
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--local-rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
opt = parser.parse_args()
# Resume
@ -472,8 +473,8 @@ if __name__ == '__main__':
if not opt.evolve:
tb_writer = None
if opt.global_rank in [-1, 0]:
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
tb_writer = SummaryWriter(log_dir=increment_dir(Path(opt.logdir) / 'exp', opt.name)) # runs/exp
train(hyp, opt, device, tb_writer)

Loading…
Cancel
Save