From b569ed6d6b8c209dbc92d9a1b21dfc5ceea51fdd Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 19 Jul 2020 22:12:55 -0700
Subject: [PATCH] pretrained model loading bug fix (#450)

Signed-off-by: Glenn Jocher
---
 train.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/train.py b/train.py
index ad82b85..1b901ee 100644
--- a/train.py
+++ b/train.py
@@ -1,13 +1,12 @@
 import argparse
 
-import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
-from torch.utils.tensorboard import SummaryWriter
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
 
 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
@@ -61,7 +60,7 @@ def train(hyp, tb_writer, opt, device):
         yaml.dump(vars(opt), f, sort_keys=False)
 
     epochs = opt.epochs  # 300
-    batch_size = opt.batch_size  # batch size per process.
+    batch_size = opt.batch_size  # batch size per process.
     total_batch_size = opt.total_batch_size
     weights = opt.weights  # initial training weights
     local_rank = opt.local_rank
@@ -70,7 +69,7 @@ def train(hyp, tb_writer, opt, device):
     # Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
 
     # Configure
-    init_seeds(2+local_rank)
+    init_seeds(2 + local_rank)
     with open(opt.data) as f:
         data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
     train_path = data_dict['train']
@@ -131,7 +130,8 @@ def train(hyp, tb_writer, opt, device):
 
         # load model
         try:
-            ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() if k in model.state_dict()}
+            ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
+                             if k in model.state_dict() and model.state_dict()[k].shape == v.shape}
             model.load_state_dict(ckpt['model'], strict=False)
         except KeyError as e:
             s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
@@ -187,7 +187,8 @@ def train(hyp, tb_writer, opt, device):
 
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
-                                            cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, world_size=opt.world_size)
+                                            cache=opt.cache_images, rect=opt.rect, local_rank=local_rank,
+                                            world_size=opt.world_size)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
@@ -195,8 +196,8 @@ def train(hyp, tb_writer, opt, device):
     # Testloader
     if local_rank in [-1, 0]:
         # local_rank is set to -1. Because only the first process is expected to do evaluation.
-        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
-                                       cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
+        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
+                                       cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
 
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -242,7 +243,8 @@ def train(hyp, tb_writer, opt, device):
             if local_rank in [-1, 0]:
                 w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
                 image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
-                dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx
+                dataset.indices = random.choices(range(dataset.n), weights=image_weights,
+                                                 k=dataset.n)  # rand weighted idx
             # Broadcast.
             if local_rank != -1:
                 indices = torch.zeros([dataset.n], dtype=torch.int)
@@ -402,7 +404,7 @@ def train(hyp, tb_writer, opt, device):
             plot_results()  # save as results.png
         print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
 
-    dist.destroy_process_group() if local_rank not in [-1,0] else None
+    dist.destroy_process_group() if local_rank not in [-1, 0] else None
 
     torch.cuda.empty_cache()
     return results
@@ -431,7 +433,8 @@ if __name__ == '__main__':
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
     parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
     # Parameter For DDP.
-    parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
+    parser.add_argument('--local_rank', type=int, default=-1,
+                        help="Extra parameter for DDP implementation. Don't use it manually.")
     opt = parser.parse_args()
 
     last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
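Note: the core change in this patch filters the pretrained checkpoint down to
parameters whose names and shapes both match the freshly built model before
calling load_state_dict(strict=False), so a checkpoint trained with a
different class count no longer aborts training. A minimal standalone sketch
of that loading pattern follows; the checkpoint path 'yolov5s.pt', the config
'models/yolov5s.yaml', nc=10, and the device choice are placeholder
assumptions for illustration, not values taken from this patch.

    import torch

    from models.yolo import Model  # model definition from this repo

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model('models/yolov5s.yaml', nc=10).to(device)  # new model; placeholder cfg and class count
    ckpt = torch.load('yolov5s.pt', map_location=device)    # pretrained checkpoint; placeholder path

    # keep only weights whose key exists in the new model AND whose shape matches,
    # so layers sized for a different class count (e.g. the detection head) are skipped
    state_dict = ckpt['model'].float().state_dict()
    state_dict = {k: v for k, v in state_dict.items()
                  if k in model.state_dict() and model.state_dict()[k].shape == v.shape}

    model.load_state_dict(state_dict, strict=False)  # strict=False tolerates the filtered-out keys
    print('Transferred %g/%g items from checkpoint' % (len(state_dict), len(model.state_dict())))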