diff --git a/train.py b/train.py
index 0cc3f31..aabf4f1 100644
--- a/train.py
+++ b/train.py
@@ -79,7 +79,6 @@ def train(hyp):
     # Create model
     model = Model(opt.cfg).to(device)
     assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
-    model.names = data_dict['names']
 
     # Image sizes
     gs = int(max(model.stride))  # grid size (max stride)
@@ -178,6 +177,7 @@ def train(hyp):
     model.hyp = hyp  # attach hyperparameters to model
     model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+    model.names = data_dict['names']
 
     # Class frequency
     labels = np.concatenate(dataset.labels, 0)
@@ -294,7 +294,7 @@ def train(hyp):
                                          batch_size=batch_size,
                                          imgsz=imgsz_test,
                                          save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                         model=ema.ema,
+                                         model=ema.ema.module if hasattr(model, 'module') else ema.ema,
                                          single_cls=opt.single_cls,
                                          dataloader=testloader)
 
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index e069792..a62adc9 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -54,6 +54,11 @@ def time_synchronized():
     return time.time()
 
 
+def is_parallel(model):
+    # Returns True if model is of type DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
 def initialize_weights(model):
     for m in model.modules():
         t = type(m)
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):
 
     try:  # FLOPS
         from thop import profile
-        macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
-        fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
+        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 GFLOPS (= 64x64 GFLOPS * (640/64)^2)
     except:
         fs = ''
 
@@ -185,7 +190,7 @@ class ModelEMA:
         self.updates += 1
         d = self.decay(self.updates)
         with torch.no_grad():
-            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+            if is_parallel(model):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
             else:
                 msd, esd = model.state_dict(), self.ema.state_dict()
@@ -196,7 +201,8 @@ class ModelEMA:
             v += (1. - d) * msd[k].detach()
 
     def update_attr(self, model):
-        # Assign attributes (which may change during training)
-        for k in model.__dict__.keys():
-            if not k.startswith('_'):
-                setattr(self.ema, k, getattr(model, k))
+        # Update class attributes on the unwrapped EMA model
+        ema = self.ema.module if is_parallel(model) else self.ema
+        for k, v in model.__dict__.items():
+            if not k.startswith('_') and k != 'module':
+                setattr(ema, k, v)
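
Illustration (appended for review, not part of the patch): under DP/DDP the actual network lives behind the wrapper's `module` attribute, so attributes written onto the wrapper are not visible once the model is unwrapped, which the `model=ema.ema.module ...` argument in the test() call above now does. The reworked update_attr therefore writes attributes onto the unwrapped EMA module and skips copying the `module` back-reference. Below is a minimal runnable sketch of that behaviour; the toy nn.Linear and the 'names' values are stand-ins for illustration, not the repo's Model class.

    from copy import deepcopy

    import torch
    import torch.nn as nn


    def is_parallel(model):
        # Returns True if model is of type DP or DDP
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


    # Toy stand-in for the repo's Model; DataParallel falls back to CPU when no GPU is present
    model = nn.parallel.DataParallel(nn.Linear(4, 2))
    model.names = ['person', 'car']  # attribute attached at the wrapper level, as train.py does

    ema = deepcopy(model)  # stand-in for ModelEMA.ema, which deep-copies the model

    # New update_attr logic: unwrap the parallel wrapper and skip 'module' so
    # attributes like .names land on the inner module that test-time code uses
    dst = ema.module if is_parallel(model) else ema
    for k, v in model.__dict__.items():
        if not k.startswith('_') and k != 'module':
            setattr(dst, k, v)

    print(dst.names)                # ['person', 'car']
    print(dst is not model.module)  # True: the EMA weights stay independent of the live model

Without the unwrapping, the attributes would sit only on the DP/DDP wrapper, and the unwrapped `ema.ema.module` handed to test() would expose EMA weights but none of the attached attributes such as names, hyp, or gr.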