diff --git a/utils/torch_utils.py b/utils/torch_utils.py index dd2e6e7..d697a06 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -201,5 +201,5 @@ class ModelEMA: def update_attr(self, model): # Update EMA attributes for k, v in model.__dict__.items(): - if not k.startswith('_') and k != 'module': + if not k.startswith('_') and k not in ["module", "process_group", "reducer"]: setattr(self.ema, k, v)