|  |  | @ -6,6 +6,9 @@ Usage: | 
			
		
	
		
		
			
				
					
					|  |  |  | """ |  |  |  | """ | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | dependencies = ['torch', 'yaml'] |  |  |  | dependencies = ['torch', 'yaml'] | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | import os | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | import torch |  |  |  | import torch | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | from models.yolo import Model |  |  |  | from models.yolo import Model | 
			
		
	
	
		
		
			
				
					|  |  | @ -24,11 +27,12 @@ def create(name, pretrained, channels, classes): | 
			
		
	
		
		
			
				
					
					|  |  |  |     Returns: |  |  |  |     Returns: | 
			
		
	
		
		
			
				
					
					|  |  |  |         pytorch model |  |  |  |         pytorch model | 
			
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  |     model = Model('models/%s.yaml' % name, channels, classes) |  |  |  |     config = os.path.join(os.path.dirname(__file__), 'models', '%s.yaml' % name)  # model.yaml path | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     model = Model(config, channels, classes) | 
			
		
	
		
		
			
				
					
					|  |  |  |     if pretrained: |  |  |  |     if pretrained: | 
			
		
	
		
		
			
				
					
					|  |  |  |         ckpt = '%s.pt' % name  # checkpoint filename |  |  |  |         ckpt = '%s.pt' % name  # checkpoint filename | 
			
		
	
		
		
			
				
					
					|  |  |  |         google_utils.attempt_download(ckpt)  # download if not found locally |  |  |  |         google_utils.attempt_download(ckpt)  # download if not found locally | 
			
		
	
		
		
			
				
					
					|  |  |  |         state_dict = torch.load(ckpt)['model'].state_dict() |  |  |  |         state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].state_dict() | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |         state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].numel() == v.numel()}  # filter |  |  |  |         state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].numel() == v.numel()}  # filter | 
			
		
	
		
		
			
				
					
					|  |  |  |         model.load_state_dict(state_dict, strict=False)  # load |  |  |  |         model.load_state_dict(state_dict, strict=False)  # load | 
			
		
	
		
		
			
				
					
					|  |  |  |     return model |  |  |  |     return model | 
			
		
	
	
		
		
			
				
					|  |  | 
 |