diff --git a/train.py b/train.py
index be931ea..77e090f 100644
--- a/train.py
+++ b/train.py
@@ -163,6 +163,7 @@ def train(hyp):
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
+    nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)

     # Testloader
@@ -191,11 +192,10 @@ def train(hyp):
         check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

     # Exponential moving average
-    ema = torch_utils.ModelEMA(model)
+    ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)

     # Start training
     t0 = time.time()
-    nb = len(dataloader)  # number of batches
     nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 6baa9d5..59f9268 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -191,15 +191,11 @@ class ModelEMA:
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """

-    def __init__(self, model, decay=0.9999, device=''):
+    def __init__(self, model, decay=0.9999, updates=0):
         # Create EMA
-        self.ema = deepcopy(model.module if is_parallel(model) else model)  # FP32 EMA
-        self.ema.eval()
-        self.updates = 0  # number of EMA updates
+        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
-        self.device = device  # perform ema on different device from model if set
-        if device:
-            self.ema.to(device)
         for p in self.ema.parameters():
             p.requires_grad_(False)

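For context on the `torch_utils.py` change: `ModelEMA` ramps its decay with `decay * (1 - exp(-updates / 2000))`, so the EMA tracks the raw weights closely at first and only approaches the full `decay` after a few thousand updates. Seeding `updates=start_epoch * nb / accumulate` keeps a resumed run at the same point on that ramp instead of re-entering the warmup phase. A minimal sketch of the ramp follows; the standalone helper and the example numbers are illustrative only, not part of the patch:

```python
import math

def ema_decay(updates, base_decay=0.9999):
    # Decay ramp used above: small early on, approaching base_decay as updates grow.
    return base_decay * (1 - math.exp(-updates / 2000))

print(ema_decay(0))    # 0.0    -> fresh run: EMA follows the model almost exactly
print(ema_decay(100))  # ~0.049 -> still in the warmup ramp

# Resumed run, e.g. start_epoch=50, nb=500 batches/epoch, accumulate=4:
print(ema_decay(50 * 500 / 4))  # ~0.956 -> ramp continues where the previous run left off
```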