diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 7eea4b4..e069792 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -7,6 +7,7 @@ import torch import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F +import torchvision.models as models def init_seeds(seed=0): @@ -120,18 +121,22 @@ def model_info(model, verbose=False): def load_classifier(name='resnet101', n=2): # Loads a pretrained model reshaped to n-class output - import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision - model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet') + model = models.__dict__[name](pretrained=True) # Display model properties - for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']: + input_size = [3, 224, 224] + input_space = 'RGB' + input_range = [0, 1] + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + for x in [input_size, input_space, input_range, mean, std]: print(x + ' =', eval(x)) # Reshape output to n classes - filters = model.last_linear.weight.shape[1] - model.last_linear.bias = torch.nn.Parameter(torch.zeros(n)) - model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters)) - model.last_linear.out_features = n + filters = model.fc.weight.shape[1] + model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True) + model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True) + model.fc.out_features = n return model