You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
2.0 KiB
57 lines
2.0 KiB
import torch,pdb
|
|
import torchvision
|
|
import torch.nn.modules
|
|
|
|
class vgg16bn(torch.nn.Module):
    """Feature extractor built from a subset of torchvision's VGG16-BN layers.

    The ``features`` children of ``torchvision.models.vgg16_bn`` are
    flattened into a list; the child at index 33 and every child from
    index 43 onwards are discarded, and the remainder is wrapped in a
    single ``torch.nn.Sequential`` stored as ``self.model``.

    NOTE(review): in the stock torchvision VGG16-BN layout indices 33 and
    43 are presumably the 4th and 5th max-pool layers -- confirm against
    the installed torchvision version.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.vgg16_bn``.
    """

    def __init__(self, pretrained=False):
        super(vgg16bn, self).__init__()
        backbone = torchvision.models.vgg16_bn(pretrained=pretrained)
        feature_layers = list(backbone.features.children())
        # Keep indices 0..32 and 34..42 (same selection as [:33] + [34:43]).
        kept_layers = [
            layer
            for index, layer in enumerate(feature_layers[:43])
            if index != 33
        ]
        self.model = torch.nn.Sequential(*kept_layers)

    def forward(self, x):
        """Run ``x`` through the retained layers and return the result."""
        features = self.model(x)
        return features
|
|
class resnet(torch.nn.Module):
    """ResNet-family backbone returning the layer2, layer3 and layer4 outputs.

    A torchvision model is built according to ``layers`` and only its
    ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1``..``layer4``
    submodules are kept on this module; nothing else from the torchvision
    model is registered here.

    Args:
        layers: string key selecting the architecture. Supported keys are
            '18', '34', '50', '101', '152', '50next', '101next', '50wide'
            and '101wide'.
        pretrained: forwarded unchanged to the torchvision constructor.

    Raises:
        NotImplementedError: if ``layers`` is not one of the supported keys.
    """

    # Maps each supported ``layers`` key to the name of the constructor in
    # ``torchvision.models``. Names (not callables) are stored so that the
    # key is validated before torchvision is touched.
    _CONSTRUCTORS = {
        '18': 'resnet18',
        '34': 'resnet34',
        '50': 'resnet50',
        '101': 'resnet101',
        '152': 'resnet152',
        '50next': 'resnext50_32x4d',
        '101next': 'resnext101_32x8d',
        '50wide': 'wide_resnet50_2',
        '101wide': 'wide_resnet101_2',
    }

    def __init__(self, layers, pretrained=False):
        super(resnet, self).__init__()
        try:
            constructor_name = self._CONSTRUCTORS[layers]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which the former if/elif
            # chain also rejected with NotImplementedError.
            raise NotImplementedError(
                "unsupported layers={!r}; expected one of: {}".format(
                    layers, ", ".join(sorted(self._CONSTRUCTORS))
                )
            )
        build = getattr(torchvision.models, constructor_name)
        model = build(pretrained=pretrained)

        # Stem.
        self.conv1 = model.conv1
        self.bn1 = model.bn1
        self.relu = model.relu
        self.maxpool = model.maxpool
        # Residual stages.
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

    def forward(self, x):
        """Return the ``(layer2, layer3, layer4)`` outputs for input ``x``.

        The stem and ``layer1`` are applied first; their outputs are not
        returned.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return x2, x3, x4
|