|  |  |  | @ -173,22 +173,23 @@ def scale_img(img, ratio=1.0, same_shape=False):  # img(16,3,256,416), r=ratio | 
			
		
	
		
			
				
					|  |  |  |  |     return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | def copy_attr(a, b, include=(), exclude=()): | 
			
		
	
		
			
				
					|  |  |  |  |     # Copy attributes from b to a, options to only include [...] and to exclude [...] | 
			
		
	
		
			
				
					|  |  |  |  |     for k, v in b.__dict__.items(): | 
			
		
	
		
			
				
					|  |  |  |  |         if (len(include) and k not in include) or k.startswith('_') or k in exclude: | 
			
		
	
		
			
				
					|  |  |  |  |             continue | 
			
		
	
		
			
				
					|  |  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |  |             setattr(a, k, v) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | class ModelEMA: | 
			
		
	
		
			
				
					|  |  |  |  |     """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models | 
			
		
	
		
			
				
					|  |  |  |  |     Keep a moving average of everything in the model state_dict (parameters and buffers). | 
			
		
	
		
			
				
					|  |  |  |  |     This is intended to allow functionality like | 
			
		
	
		
			
				
					|  |  |  |  |     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage | 
			
		
	
		
			
				
					|  |  |  |  |     A smoothed version of the weights is necessary for some training schemes to perform well. | 
			
		
	
		
			
				
					|  |  |  |  |     E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use | 
			
		
	
		
			
				
					|  |  |  |  |     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA | 
			
		
	
		
			
				
					|  |  |  |  |     smoothing of weights to match results. Pay attention to the decay constant you are using | 
			
		
	
		
			
				
					|  |  |  |  |     relative to your update count per epoch. | 
			
		
	
		
			
				
					|  |  |  |  |     To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but | 
			
		
	
		
			
				
					|  |  |  |  |     disable validation of the EMA weights. Validation will have to be done manually in a separate | 
			
		
	
		
			
				
					|  |  |  |  |     process, or after the training stops converging. | 
			
		
	
		
			
				
					|  |  |  |  |     This class is sensitive where it is initialized in the sequence of model init, | 
			
		
	
		
			
				
					|  |  |  |  |     GPU assignment and distributed training wrappers. | 
			
		
	
		
			
				
					|  |  |  |  |     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     def __init__(self, model, decay=0.9999, updates=0): | 
			
		
	
	
		
			
				
					|  |  |  | @ -211,8 +212,6 @@ class ModelEMA: | 
			
		
	
		
			
				
					|  |  |  |  |                     v *= d | 
			
		
	
		
			
				
					|  |  |  |  |                     v += (1. - d) * msd[k].detach() | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     def update_attr(self, model): | 
			
		
	
		
			
				
					|  |  |  |  |     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): | 
			
		
	
		
			
				
					|  |  |  |  |         # Update EMA attributes | 
			
		
	
		
			
				
					|  |  |  |  |         for k, v in model.__dict__.items(): | 
			
		
	
		
			
				
					|  |  |  |  |             if not k.startswith('_') and k not in ["process_group", "reducer"]: | 
			
		
	
		
			
				
					|  |  |  |  |                 setattr(self.ema, k, v) | 
			
		
	
		
			
				
					|  |  |  |  |         copy_attr(self.ema, model, include, exclude) | 
			
		
	
	
		
			
				
					|  |  |  | 
 |