diff --git a/train.py b/train.py index 39d2392..4943336 100644 --- a/train.py +++ b/train.py @@ -396,7 +396,7 @@ if __name__ == '__main__': # Train if not opt.evolve: print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/') - tb_writer = SummaryWriter(comment=opt.name) + tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name)) if opt.hyp: # update hyps with open(opt.hyp) as f: hyp.update(yaml.load(f, Loader=yaml.FullLoader)) diff --git a/utils/utils.py b/utils/utils.py index 2de2917..486dca6 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -904,6 +904,16 @@ def output_to_target(output, width, height): return np.array(targets) +def increment_dir(dir, comment=''): + # Increments a directory runs/exp1 --> runs/exp2_comment + n = 0 # number + d = sorted(glob.glob(dir + '*')) # directories + if len(d): + d = d[-1].replace(dir, '') + n = int(d[:d.find('_')]) + 1 # increment + return dir + str(n) + ('_' + comment if comment else '') + + # Plotting functions --------------------------------------------------------------------------------------------------- def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy