diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 3cbec8b..dd2e6e7 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -176,13 +176,13 @@ class ModelEMA: def __init__(self, model, decay=0.9999, device=''): # Create EMA - self.ema = deepcopy(model.module if is_parallel(model) else model).half() # FP16 EMA + self.ema = deepcopy(model.module if is_parallel(model) else model) # FP32 EMA self.ema.eval() self.updates = 0 # number of EMA updates self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs) self.device = device # perform ema on different device from model if set if device: - self.ema.to(device=device) + self.ema.to(device) for p in self.ema.parameters(): p.requires_grad_(False)