@@ -90,7 +90,7 @@ def prune(model, amount=0.3):
     import torch.nn.utils.prune as prune
     print('Pruning model... ', end='')
     for name, m in model.named_modules():
-        if isinstance(m, torch.nn.Conv2d):
+        if isinstance(m, nn.Conv2d):
             prune.l1_unstructured(m, name='weight', amount=amount)  # prune
             prune.remove(m, 'weight')  # make permanent
     print(' %.3g global sparsity' % sparsity(model))
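Note: the final print relies on a sparsity() helper defined elsewhere in this file. A minimal sketch of what such a helper computes, assuming it simply reports the fraction of parameters that are exactly zero after pruning:

    def sparsity(model):
        # Global sparsity: fraction of all parameters that are exactly zero
        # (a sketch; the repo's own helper may differ in detail)
        a, b = 0, 0
        for p in model.parameters():
            a += p.numel()
            b += (p == 0).sum().item()
        return b / a

Calling prune(model, amount=0.3) then zeroes roughly 30% of each Conv2d's weights by L1 magnitude and reports that figure.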
			
		
	
	
		
			
				
@@ -100,12 +100,12 @@ def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
         # init
-        fusedconv = torch.nn.Conv2d(conv.in_channels,
-                                    conv.out_channels,
-                                    kernel_size=conv.kernel_size,
-                                    stride=conv.stride,
-                                    padding=conv.padding,
-                                    bias=True)
+        fusedconv = nn.Conv2d(conv.in_channels,
+                              conv.out_channels,
+                              kernel_size=conv.kernel_size,
+                              stride=conv.stride,
+                              padding=conv.padding,
+                              bias=True).to(conv.weight.device)

         # prepare filters
         w_conv = conv.weight.clone().view(conv.out_channels, -1)
			
		
	
	
		
			
				
@@ -113,10 +113,7 @@ def fuse_conv_and_bn(conv, bn):
         fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

         # prepare spatial bias
-        if conv.bias is not None:
-            b_conv = conv.bias
-        else:
-            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
+        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
         b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
         fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

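Note: numerically, this folds the inference-mode BatchNorm affine transform into the convolution: W_fused = diag(γ/√(σ² + ε)) · W and b_fused = diag(γ/√(σ² + ε)) · b_conv + β − γμ/√(σ² + ε), which is exactly what the two copy_() calls compute. A quick sanity check that the fused layer reproduces conv followed by bn on a random input; this assumes fuse_conv_and_bn returns fusedconv, and the layer sizes are arbitrary choices for the test:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16).eval()      # fusion assumes frozen (inference) BN statistics
    bn.running_mean.uniform_(-1, 1)     # give BN non-trivial stats so the check is meaningful
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)

    fused = fuse_conv_and_bn(conv, bn)
    x = torch.randn(1, 8, 32, 32)
    with torch.no_grad():
        err = (fused(x) - bn(conv(x))).abs().max().item()
    print('max abs error: %.2e' % err)  # expected around 1e-6, i.e. float32 rounding only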
			
		
	
	
		
			
				
@@ -159,8 +156,8 @@ def load_classifier(name='resnet101', n=2):

     # Reshape output to n classes
     filters = model.fc.weight.shape[1]
-    model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
-    model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
+    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
+    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
     model.fc.out_features = n
     return model

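Note: the head reshape above keeps the original fc module and swaps its Parameter tensors in place (zero-initialised). An equivalent, perhaps more familiar pattern is to replace the layer outright; a minimal sketch assuming a torchvision ResNet backbone, which the default name='resnet101' suggests:

    import torch
    import torch.nn as nn
    import torchvision

    n = 2
    model = torchvision.models.resnet101(pretrained=False)  # pretrained=True to start from ImageNet weights
    filters = model.fc.weight.shape[1]  # in_features of the existing 1000-class head
    model.fc = nn.Linear(filters, n)    # fresh n-class head (randomly initialised, not zeros)
    x = torch.randn(4, 3, 224, 224)
    print(model(x).shape)               # torch.Size([4, 2])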
			
		
	
	
		
			
				