From 1aa2b679333657cc20a702dabb1b5de3315cf577 Mon Sep 17 00:00:00 2001 From: yxNONG <62932917+yxNONG@users.noreply.github.com> Date: Thu, 2 Jul 2020 13:51:52 +0800 Subject: [PATCH] Update train.py --- train.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/train.py b/train.py index d933a5d..3b7c9a5 100644 --- a/train.py +++ b/train.py @@ -147,15 +147,6 @@ def train(hyp): # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822 # plot_lr_scheduler(optimizer, scheduler, epochs) - # Initialize distributed training - if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available(): - dist.init_process_group(backend='nccl', # distributed backend - init_method='tcp://127.0.0.1:9999', # init method - world_size=1, # number of nodes - rank=0) # node rank - model = torch.nn.parallel.DistributedDataParallel(model) - # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html - # Trainloader dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect) @@ -173,6 +164,15 @@ def train(hyp): model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.names = data_dict['names'] + + # Initialize distributed training + if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available(): + dist.init_process_group(backend='nccl', # distributed backend + init_method='tcp://127.0.0.1:9999', # init method + world_size=1, # number of nodes + rank=0) # node rank + model = torch.nn.parallel.DistributedDataParallel(model) + # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html # Class frequency labels = np.concatenate(dataset.labels, 0) @@ -289,7 +289,7 @@ def train(hyp): batch_size=batch_size, imgsz=imgsz_test, save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'), - model=ema.ema, + model=ema.ema.module if hasattr(model, 'module') else ema.ema, single_cls=opt.single_cls, dataloader=testloader) @@ -315,14 +315,6 @@ def train(hyp): # Save model save = (not opt.nosave) or (final_epoch and not opt.evolve) if save: - if hasattr(model, 'module'): - # Duplicate Model parameters for Multi-GPU save - ema.ema.module.nc = model.nc # attach number of classes to model - ema.ema.module.hyp = model.hyp # attach hyperparameters to model - ema.ema.module.gr = model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) - ema.ema.module.class_weights = model.class_weights # attach class weights - ema.ema.module.names = data_dict['names'] - with open(results_file, 'r') as f: # create checkpoint ckpt = {'epoch': epoch, 'best_fitness': best_fitness,