|
|
@ -7,6 +7,7 @@ import torch
|
|
|
|
import torch.backends.cudnn as cudnn
|
|
|
|
import torch.backends.cudnn as cudnn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
import torchvision.models as models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_seeds(seed=0):
|
|
|
|
def init_seeds(seed=0):
|
|
|
@ -120,18 +121,22 @@ def model_info(model, verbose=False):
|
|
|
|
|
|
|
|
|
|
|
|
def load_classifier(name='resnet101', n=2):
|
|
|
|
def load_classifier(name='resnet101', n=2):
|
|
|
|
# Loads a pretrained model reshaped to n-class output
|
|
|
|
# Loads a pretrained model reshaped to n-class output
|
|
|
|
import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
|
|
|
|
model = models.__dict__[name](pretrained=True)
|
|
|
|
model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Display model properties
|
|
|
|
# Display model properties
|
|
|
|
for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']:
|
|
|
|
input_size = [3, 224, 224]
|
|
|
|
|
|
|
|
input_space = 'RGB'
|
|
|
|
|
|
|
|
input_range = [0, 1]
|
|
|
|
|
|
|
|
mean = [0.485, 0.456, 0.406]
|
|
|
|
|
|
|
|
std = [0.229, 0.224, 0.225]
|
|
|
|
|
|
|
|
for x in [input_size, input_space, input_range, mean, std]:
|
|
|
|
print(x + ' =', eval(x))
|
|
|
|
print(x + ' =', eval(x))
|
|
|
|
|
|
|
|
|
|
|
|
# Reshape output to n classes
|
|
|
|
# Reshape output to n classes
|
|
|
|
filters = model.last_linear.weight.shape[1]
|
|
|
|
filters = model.fc.weight.shape[1]
|
|
|
|
model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
|
|
|
|
model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
|
|
|
|
model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
|
|
|
|
model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
|
|
|
|
model.last_linear.out_features = n
|
|
|
|
model.fc.out_features = n
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|