From af41083d52992d86da7cc74734b4cf14bdb56e04 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 3 Jul 2020 16:57:08 -0700
Subject: [PATCH] EMA FP16 fix #279

---
 utils/torch_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 3cbec8b..dd2e6e7 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -176,13 +176,13 @@ class ModelEMA:
 
     def __init__(self, model, decay=0.9999, device=''):
         # Create EMA
-        self.ema = deepcopy(model.module if is_parallel(model) else model).half()  # FP16 EMA
+        self.ema = deepcopy(model.module if is_parallel(model) else model)  # FP32 EMA
         self.ema.eval()
         self.updates = 0  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
         self.device = device  # perform ema on different device from model if set
         if device:
-            self.ema.to(device=device)
+            self.ema.to(device)
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
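Note (not part of the patch): a minimal sketch of why the FP32 change
matters. The decay value matches the default above; the toy weight values
are assumed for illustration, not taken from #279. With decay = 0.9999,
each EMA step adds roughly 1e-4 of the parameter gap, which falls below
FP16's ~1e-3 spacing near 1.0, so an FP16 EMA buffer can round every
update to zero and never move, while FP32 accumulates it correctly:

    import torch

    # Toy EMA update: ema = decay * ema + (1 - decay) * param
    decay = 0.9999
    param = torch.tensor(1.01)                      # hypothetical trained weight
    ema16 = torch.tensor(1.0, dtype=torch.float16)  # FP16 EMA buffer (pre-patch)
    ema32 = torch.tensor(1.0, dtype=torch.float32)  # FP32 EMA buffer (post-patch)

    for _ in range(1000):
        ema16 = decay * ema16 + (1 - decay) * param.half()
        ema32 = decay * ema32 + (1 - decay) * param

    print(ema16.item())  # 1.0 exactly: each FP16 increment rounds away
    print(ema32.item())  # ~1.00095: moving toward param as expected

The second change, .to(device) in place of .to(device=device), appears to
be stylistic only: torch.nn.Module.to accepts the device both positionally
and as a keyword.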