From de191655e49fd5ed7b9af5b9ffaf99b4d63f9c92 Mon Sep 17 00:00:00 2001 From: Alex Stoken Date: Wed, 24 Jun 2020 17:21:54 -0500 Subject: [PATCH] Fix yaml saving (don't sort keys), reorder --opt keys, bug fix hyp dict accessor --- train.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index a0b358e..654a6bf 100644 --- a/train.py +++ b/train.py @@ -91,7 +91,7 @@ def train(hyp): else: pg0.append(v) # all else - if hyp.optimizer =='adam': + if hyp['optimizer'] =='adam': optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) #use default beta2, adjust beta1 for Adam momentum per momentum adjustments in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR else: optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True) @@ -190,10 +190,10 @@ def train(hyp): #save hyperparamter and training options in run folder with open(os.path.join(log_dir, 'hyp.yaml'), 'w') as f: - yaml.dump(hyp, f) + yaml.dump(hyp, f, sort_keys=False) with open(os.path.join(log_dir, 'opt.yaml'), 'w') as f: - yaml.dump(vars(opt), f) + yaml.dump(vars(opt), f, sort_keys=False) # Class frequency labels = np.concatenate(dataset.labels, 0) @@ -370,10 +370,11 @@ def train(hyp): if __name__ == '__main__': check_git_status() parser = argparse.ArgumentParser() - parser.add_argument('--epochs', type=int, default=300) - parser.add_argument('--batch-size', type=int, default=16) parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model cfg path[*.yaml]') parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data cfg path [*.yaml]') + parser.add_argument('--hyp', type=str, default='',help='hyp cfg path [*.yaml].') + parser.add_argument('--epochs', type=int, default=300) + parser.add_argument('--batch-size', type=int, default=16) parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes. Assumes square imgs.') parser.add_argument('--rect', action='store_true', help='rectangular training') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') @@ -386,7 +387,7 @@ if __name__ == '__main__': parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') - parser.add_argument('--hyp', type=str, default='', help ='hyp cfg path [*.yaml].') + opt = parser.parse_args() opt.cfg = check_file(opt.cfg) # check file