You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

57 lines
2.0 KiB

import torch,pdb
import torchvision
import torch.nn.modules
class vgg16bn(torch.nn.Module):
    """Truncated VGG-16 (batch-norm variant) feature extractor.

    Takes the children of ``torchvision.models.vgg16_bn(...).features``,
    keeps positions 0-32 and 34-42 (i.e. drops position 33 and everything
    from position 43 onwards) and chains the survivors in a
    ``torch.nn.Sequential``.

    NOTE(review): in the stock torchvision layout positions 33 and 43 are
    presumably the last two max-pool layers -- confirm against the installed
    torchvision version before relying on that.
    """

    def __init__(self, pretrained=False):
        """Build the truncated feature stack.

        Args:
            pretrained: forwarded unchanged to ``torchvision.models.vgg16_bn``.
        """
        super().__init__()
        backbone = torchvision.models.vgg16_bn(pretrained=pretrained)
        layers = list(backbone.features.children())
        # Skip child 33 and stop before child 43; the remaining order is
        # preserved.
        kept = layers[:33] + layers[34:43]
        self.model = torch.nn.Sequential(*kept)

    def forward(self, x):
        """Run ``x`` through the retained layers and return the result."""
        return self.model(x)
class resnet(torch.nn.Module):
    """ResNet-family backbone that exposes three intermediate outputs.

    A torchvision model is instantiated according to ``layers`` and only its
    ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1``..``layer4``
    submodules are stored on this module. ``forward`` returns the outputs of
    ``layer2``, ``layer3`` and ``layer4``.
    """

    # Accepted ``layers`` spec -> name of the factory in ``torchvision.models``.
    # Names (not callables) are stored so that only the selected factory is
    # looked up; older torchvision releases may lack some of the variants.
    _FACTORY_NAMES = {
        '18': 'resnet18',
        '34': 'resnet34',
        '50': 'resnet50',
        '101': 'resnet101',
        '152': 'resnet152',
        '50next': 'resnext50_32x4d',
        '101next': 'resnext101_32x8d',
        '50wide': 'wide_resnet50_2',
        '101wide': 'wide_resnet101_2',
    }

    def __init__(self, layers, pretrained=False):
        """Select and unpack the requested torchvision model.

        Args:
            layers: variant identifier, one of the keys of ``_FACTORY_NAMES``
                (e.g. ``'18'``, ``'50next'``, ``'101wide'``). The value is
                passed through ``str()``, so ``18`` is accepted like ``'18'``.
            pretrained: forwarded unchanged to the torchvision factory.

        Raises:
            NotImplementedError: if ``layers`` names no supported variant.
        """
        super().__init__()
        factory_name = self._FACTORY_NAMES.get(str(layers))
        if factory_name is None:
            raise NotImplementedError(
                'unsupported resnet variant {!r}; expected one of: {}'.format(
                    layers, ', '.join(sorted(self._FACTORY_NAMES))
                )
            )
        model = getattr(torchvision.models, factory_name)(pretrained=pretrained)

        # Keep only the stem and the four residual stages.
        self.conv1 = model.conv1
        self.bn1 = model.bn1
        self.relu = model.relu
        self.maxpool = model.maxpool
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

    def forward(self, x):
        """Return the ``(layer2, layer3, layer4)`` outputs for input ``x``.

        ``x`` goes through conv1 -> bn1 -> relu -> maxpool -> layer1, then
        each later stage consumes the previous stage's output.
        """
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x1 = self.layer1(stem)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return x2, x3, x4