diff --git a/train.py b/train.py index e70b53a..e067148 100644 --- a/train.py +++ b/train.py @@ -42,7 +42,6 @@ def train(hyp, opt, device, tb_writer=None): epochs, batch_size, total_batch_size, weights, rank = \ opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank - # TODO: Use DDP logging. Only the first process is allowed to log. # Save run settings with open(log_dir / 'hyp.yaml', 'w') as f: yaml.dump(hyp, f, sort_keys=False) @@ -130,6 +129,8 @@ def train(hyp, opt, device, tb_writer=None): # Epochs start_epoch = ckpt['epoch'] + 1 + if opt.resume: + assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs) if epochs < start_epoch: logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' % (weights, ckpt['epoch'], epochs)) @@ -158,19 +159,19 @@ def train(hyp, opt, device, tb_writer=None): model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank)) # Trainloader - dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True, - cache=opt.cache_images, rect=opt.rect, rank=rank, + dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, + hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank, world_size=opt.world_size, workers=opt.workers) mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class nb = len(dataloader) # number of batches + ema.updates = start_epoch * nb // accumulate # set EMA updates assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) # Testloader if rank in [-1, 0]: - # local_rank is set to -1. Because only the first process is expected to do evaluation. - testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False, - cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size, - workers=opt.workers)[0] + testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, + hyp=hyp, augment=False, cache=opt.cache_images, rect=True, rank=-1, + world_size=opt.world_size, workers=opt.workers)[0] # only runs on process 0 # Model parameters hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset @@ -283,7 +284,7 @@ def train(hyp, opt, device, tb_writer=None): scaler.step(optimizer) # optimizer.step scaler.update() optimizer.zero_grad() - if ema is not None: + if ema: ema.update(model) # Print @@ -305,12 +306,13 @@ def train(hyp, opt, device, tb_writer=None): # end batch ------------------------------------------------------------------------------------------------ # Scheduler + lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard scheduler.step() # DDP process 0 or single-GPU if rank in [-1, 0]: # mAP - if ema is not None: + if ema: ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride']) final_epoch = epoch + 1 == epochs if not opt.notest or final_epoch: # Calculate mAP @@ -330,10 +332,11 @@ def train(hyp, opt, device, tb_writer=None): # Tensorboard if tb_writer: - tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', + tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', - 'val/giou_loss', 'val/obj_loss', 'val/cls_loss'] - for x, tag in zip(list(mloss[:-1]) + list(results), tags): + 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss + 'x/lr0', 'x/lr1', 'x/lr2'] # params + for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags): tb_writer.add_scalar(tag, x, epoch) # Update best mAP @@ -389,8 +392,7 @@ if __name__ == '__main__': parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes') parser.add_argument('--rect', action='store_true', help='rectangular training') - parser.add_argument('--resume', nargs='?', const='get_last', default=False, - help='resume from given path/last.pt, or most recent run if blank') + parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--notest', action='store_true', help='only test final epoch') parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check') @@ -413,21 +415,24 @@ if __name__ == '__main__': opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 set_logging(opt.global_rank) - - # Resume - if opt.resume: - last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run - if last and not opt.weights: - logger.info(f'Resuming training from {last}') - opt.weights = last if opt.resume and not opt.weights else opt.weights if opt.global_rank in [-1, 0]: check_git_status() - opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml') - opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files - assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' + # Resume + if opt.resume: # resume an interrupted run + ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path + assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' + with open(Path(ckpt).parent.parent / 'opt.yaml') as f: + opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace + opt.cfg, opt.weights, opt.resume = '', ckpt, True + logger.info('Resuming training from %s' % ckpt) + + else: + opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml') + opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files + assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' + opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) - opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) device = select_device(opt.device, batch_size=opt.batch_size) # DDP mode diff --git a/utils/general.py b/utils/general.py index 4da75d0..52cb3ad 100755 --- a/utils/general.py +++ b/utils/general.py @@ -61,7 +61,7 @@ def init_seeds(seed=0): def get_latest_run(search_dir='./runs'): # Return path to most recent 'last.pt' in /runs (i.e. to --resume from) last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) - return max(last_list, key=os.path.getctime) + return max(last_list, key=os.path.getctime) if last_list else '' def check_git_status():