From 333f678b374e7677070ce037ddbe8a655563e8f6 Mon Sep 17 00:00:00 2001
From: Alex Stoken
Date: Tue, 16 Jun 2020 16:36:20 -0500
Subject: [PATCH] add update default hyp dict with provided yaml

---
 train.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/train.py b/train.py
index dd0f029..cfc0059 100644
--- a/train.py
+++ b/train.py
@@ -42,17 +42,6 @@ hyp = {'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
 # Don't need to be printing every time
 #print(hyp)
 
-# Overwrite hyp with hyp*.txt (optional)
-if f:
-    print('Using %s' % f[0])
-    for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
-        hyp[k] = v
-
-# Print focal loss if gamma > 0
-if hyp['fl_gamma']:
-    print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
-
-
 def train(hyp):
     #write all results to the tb log_dir, so all data from one run is together
     log_dir = tb_writer.log_dir
@@ -410,7 +399,7 @@ if __name__ == '__main__':
             print(f'WARNING: No run provided to resume from. Resuming from most recent run found at {last}')
     else:
         last = ''
-    
+
     # if resuming, check for hyp file
     if last:
         last_hyp = last.replace('last.pt', 'hyp.yaml')
@@ -430,7 +419,16 @@ if __name__ == '__main__':
     # Train
     if not opt.evolve:
         tb_writer = SummaryWriter(comment=opt.name)
+
+        #updates hyp defaults from hyp.yaml
+        if opt.hyp: hyp.update(opt.hyp)
+
+        # Print focal loss if gamma > 0
+        if hyp['fl_gamma']:
+            print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
+
+        print(f'Beginning training with {hyp}\n\n')
         print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
+
        train(hyp)
 
     # Evolve hyperparameters (optional)
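
Note: the old override path read positional values from a hyp*.txt file with np.loadtxt and zipped them against the dict's keys, so every value had to be present and in insertion order; the new path merges a YAML mapping into the defaults by key, so a partial file overrides only the keys it names. Below is a minimal sketch of the resulting flow, assuming opt.hyp holds a path passed on the command line and is parsed with yaml.safe_load before the update; the --hyp flag wiring and the abbreviated defaults dict are illustrative, not the exact train.py code.

import argparse
import yaml

# Abbreviated stand-in for the full defaults dict defined in train.py
hyp = {'lr0': 0.01,      # initial learning rate (SGD=1E-2, Adam=1E-3)
       'fl_gamma': 0.0}  # focal loss gamma (0 means FocalLoss is off)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--hyp', type=str, default='',
                        help='path to hyp.yaml with overrides (illustrative flag)')
    opt = parser.parse_args()

    # Parse the YAML into a dict so hyp.update() can merge it by key;
    # the patch assumes opt.hyp is already dict-like at this point
    if opt.hyp:
        with open(opt.hyp) as f:
            opt.hyp = yaml.safe_load(f)  # e.g. {'lr0': 0.02, 'fl_gamma': 1.5}
        hyp.update(opt.hyp)  # YAML keys win; unmentioned keys keep defaults

    # Moved after the merge so it reports the final, post-override value
    if hyp['fl_gamma']:
        print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])

    print(f'Beginning training with {hyp}\n\n')

Relocating the fl_gamma notice into the training branch, after the merge, also means it now reflects any YAML override rather than only the hard-coded default.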