You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

237 lines
8.2 KiB

2 years ago
# -*- coding: utf-8 -*-
"""
@File : models.py
@Author: csc
@Date : 2022/6/23
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
from collections import defaultdict
from utils import *
# Prefer the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): this unconditional reassignment discards the detection above, so
# every reader of `device` in this module (e.g. MyConv2D) ends up on the CPU —
# confirm this override is intentional.
device = 'cpu'
class VGG(nn.Module):
    """Frozen feature extractor that returns the outputs of selected layers.

    ``features`` is kept as given (it must expose ``_modules``, as
    ``nn.Sequential`` does) and every parameter reachable from this module
    has ``requires_grad`` turned off. The keys of ``layer_name_mapping`` are
    the child names inside ``features`` whose outputs are collected; the
    values are labels only and are not read by this class.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        tapped_indices = ('3', '8', '15', '22')
        tapped_labels = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
        self.layer_name_mapping = dict(zip(tapped_indices, tapped_labels))
        # Freeze everything: this network is only ever used for inference.
        for weight in self.parameters():
            weight.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through every child of ``features`` in order.

        Returns:
            A list with the intermediate output of each child whose name is a
            key of ``layer_name_mapping``, in the order the children run.
        """
        collected = []
        for child_name, child in self.features._modules.items():
            x = child(x)
            if child_name not in self.layer_name_mapping:
                continue
            collected.append(x)
        return collected
class VGG19(nn.Module):
    """Frozen wrapper around a feature stack that taps four of its layers.

    The wrapped ``features`` container must expose ``_modules`` (as
    ``nn.Sequential`` does). Children named '3', '8', '17' and '26' are the
    tap points; the labels stored next to them in ``layer_name_mapping`` are
    informational and never read here. All parameters are frozen.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        self.layer_name_mapping = dict(
            [
                ('3', "relu1_2"),
                ('8', "relu2_2"),
                ('17', "relu3_4"),
                ('26', "relu4_4"),
            ]
        )
        # No gradients are needed for any weight owned by this wrapper.
        for tensor in self.parameters():
            tensor.requires_grad = False

    def forward(self, x):
        """Feed ``x`` through all children and return the tapped outputs as a list."""
        tap_points = self.layer_name_mapping
        taps = []
        for key, stage in self.features._modules.items():
            x = stage(x)
            if key in tap_points:
                taps += [x]
        return taps
class MyConv2D(nn.Module):
    """2-D convolution whose weight and bias are plain tensors, not parameters.

    ``weight`` and ``bias`` are zero-initialised tensors stored as ordinary
    attributes, so they are not registered as ``nn.Parameter`` objects and can
    be replaced by simple assignment (as ``TransformNet.set_my_attr`` does).
    No padding is applied by this layer.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square kernel.
        stride: Stride applied in both spatial directions.
        groups: Number of blocked connections from input to output channels;
            must divide both ``in_channels`` and ``out_channels``.

    Raises:
        ValueError: If ``groups`` is not positive or does not divide both
            channel counts.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
        super(MyConv2D, self).__init__()
        if groups < 1 or in_channels % groups != 0 or out_channels % groups != 0:
            raise ValueError(
                'groups must be a positive divisor of in_channels and out_channels, '
                'got in_channels={}, out_channels={}, groups={}'.format(
                    in_channels, out_channels, groups))
        # F.conv2d expects a weight of shape (out, in // groups, kH, kW); sizing
        # the second axis as `in_channels` only works when groups == 1.
        weight_shape = (out_channels, in_channels // groups, kernel_size, kernel_size)
        self.weight = torch.zeros(weight_shape).to(device)
        self.bias = torch.zeros(out_channels).to(device)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.stride = (stride, stride)
        self.groups = groups

    def forward(self, x):
        """Convolve ``x`` with the currently assigned weight and bias."""
        return F.conv2d(x, self.weight, self.bias, self.stride, groups=self.groups)

    def extra_repr(self):
        """Describe the layer's configuration for ``repr``/``print``."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.groups != 1:
            # Only mention groups when it deviates from the default.
            s += ', groups={groups}'
        return s.format(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            groups=self.groups,
        )
class ResidualBlock(nn.Module):
    """Two stacked 3x3 ``ConvLayer`` groups with an additive skip connection.

    Both groups keep the channel count at ``channels`` and use stride 1; the
    second one is built without its trailing ReLU.
    """

    def __init__(self, channels):
        super().__init__()
        first_stage = ConvLayer(channels, channels, kernel_size=3, stride=1)
        second_stage = ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False)
        self.conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Return the convolutional branch's output plus the unchanged input."""
        branch = self.conv(x)
        return branch + x
class ResNeXtBlock(nn.Module):
    """Bottleneck block: 1x1 conv -> grouped 3x3 conv -> 1x1 conv, plus a shortcut.

    The two 1x1 convolutions are ``MyConv2D`` layers, while the grouped 3x3
    convolution in the middle is a regular ``nn.Conv2d``. Each convolution is
    followed by batch normalisation.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels of the block's output.
        identity_downsample: Optional callable applied to the input before it
            is added to the main branch; ``None`` adds the input unchanged.
        stride: Stride of the 3x3 convolution (per the original note, 2 halves
            the spatial size and 1 keeps it).
        cardinality: Number of groups of the 3x3 convolution.
        width_per_group: Scales the inner width relative to a baseline of 64.
    """

    def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1, cardinality=32, width_per_group=64):
        super().__init__()
        # Inner channel count of the bottleneck.
        group_width = int(out_channels * (width_per_group / 64))
        width = group_width * cardinality

        # 1x1 convolution: spatial size is unchanged.
        self.conv1 = MyConv2D(in_channels, width, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        # Grouped 3x3 convolution; padding=1 compensates for the kernel size.
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality)
        self.bn2 = nn.BatchNorm2d(width)
        # 1x1 convolution back to the output width: spatial size is unchanged.
        self.conv3 = MyConv2D(width, out_channels, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        """Apply the three conv/BN stages, add the shortcut, then a final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x
        if self.identity_downsample is not None:
            shortcut = self.identity_downsample(shortcut)

        # Residual connection (in place on the main-branch tensor).
        out += shortcut
        return self.relu(out)
def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1,
              upsample=None, instance_norm=True, relu=True, trainable=False):
    """Build the list of modules that make up one convolution stage.

    The list always holds a reflection pad of ``kernel_size // 2`` followed by
    a convolution; the other entries are optional.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Kernel size passed to the convolution.
        stride: Stride passed to the convolution.
        upsample: When truthy, a nearest-neighbour ``nn.Upsample`` with this
            scale factor is placed first.
        instance_norm: Append ``nn.InstanceNorm2d(out_channels)`` after the conv.
        relu: Append ``nn.ReLU()`` last.
        trainable: Use ``nn.Conv2d`` when true, otherwise ``MyConv2D``.

    Returns:
        A plain Python list of modules, in application order.
    """
    stack = [nn.Upsample(mode='nearest', scale_factor=upsample)] if upsample else []
    stack.append(nn.ReflectionPad2d(kernel_size // 2))

    conv_cls = nn.Conv2d if trainable else MyConv2D
    stack.append(conv_cls(in_channels, out_channels, kernel_size, stride))

    if instance_norm:
        stack.append(nn.InstanceNorm2d(out_channels))
    if relu:
        stack.append(nn.ReLU())
    return stack
class TransformNet(nn.Module):
    """Image-to-image network: downsampling, residual stage, upsampling.

    Only the first and last convolutions are regular trainable ``nn.Conv2d``
    layers; every other convolution built through ``ConvLayer`` is a
    ``MyConv2D`` whose weight and bias are assigned from outside via
    ``set_weights``.

    Args:
        base: Channel count after the first convolution; later stages use
            ``base * 2`` and ``base * 4``.
        residuals: ``'resnet'`` for five ``ResidualBlock`` modules or
            ``'resnext'`` for a single ``ResNeXtBlock``.

    Raises:
        ValueError: If ``residuals`` is neither ``'resnet'`` nor ``'resnext'``.
    """

    def __init__(self, base=8, residuals='resnet'):
        super(TransformNet, self).__init__()
        # Fail here rather than with an AttributeError in forward(), which is
        # what an unrecognised value would otherwise lead to.
        if residuals not in ('resnet', 'resnext'):
            raise ValueError(
                "residuals must be 'resnet' or 'resnext', got {!r}".format(residuals))
        self.base = base
        self.weights = []
        self.downsampling = nn.Sequential(
            *ConvLayer(3, base, kernel_size=9, trainable=True),
            *ConvLayer(base, base * 2, kernel_size=3, stride=2),
            *ConvLayer(base * 2, base * 4, kernel_size=3, stride=2),
        )
        if residuals == 'resnet':
            self.residuals = nn.Sequential(*[ResidualBlock(base * 4) for _ in range(5)])
        else:
            self.residuals = nn.Sequential(ResNeXtBlock(base * 4, base * 4))
        self.upsampling = nn.Sequential(
            *ConvLayer(base * 4, base * 2, kernel_size=3, upsample=2),
            *ConvLayer(base * 2, base, kernel_size=3, upsample=2),
            *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False, trainable=True),
        )

    def forward(self, X):
        """Run ``X`` through the downsampling, residual and upsampling stages."""
        y = self.downsampling(X)
        y = self.residuals(y)
        y = self.upsampling(y)
        return y

    def get_param_dict(self):
        """Find every MyConv2D layer and count the values it needs.

        Returns:
            A ``defaultdict(int)`` mapping the dotted path of each MyConv2D
            layer (e.g. ``residuals.0.conv.1``) to the number of elements in
            its weight plus the number of elements in its bias.
        """
        param_dict = defaultdict(int)

        def dfs(module, name):
            # Visit children first, extending the dotted path as we descend.
            for child_name, child in module.named_children():
                dfs(child, '%s.%s' % (name, child_name) if name != '' else child_name)
            if isinstance(module, MyConv2D):
                param_dict[name] += module.weight.numel()
                param_dict[name] += module.bias.numel()

        dfs(self, '')
        return param_dict

    def set_my_attr(self, name, value):
        """Assign weight and bias of the layer at dotted path ``name``.

        Args:
            name: Dotted path such as ``residuals.0.conv.1``; numeric parts
                index into a container, other parts are attribute names.
            value: Flat tensor whose leading elements form the weight and
                whose remaining elements form the bias.
        """
        # Walk the dotted path step by step to reach the target layer.
        target = self
        for part in name.split('.'):
            if part.isnumeric():
                target = target[int(part)]
            else:
                target = getattr(target, part)

        # Split the flat vector and reshape each piece to the layer's shapes.
        n_weight = target.weight.numel()
        target.weight = value[:n_weight].view(target.weight.shape)
        target.bias = value[n_weight:].view(target.bias.shape)

    def set_weights(self, weights, i=0):
        """Set the weights of all MyConv2D layers from a weights dict.

        Args:
            weights: Mapping from layer path to an indexable collection of
                flat value vectors (presumably one per batch sample — confirm
                against the caller).
            i: Index of the entry to use for every layer.
        """
        for name, candidates in weights.items():
            self.set_my_attr(name, candidates[i])
class MetaNet(nn.Module):
    """Maps a feature vector to one flat value vector per named layer.

    A shared linear layer produces ``128 * len(param_dict)`` hidden units;
    each entry of ``param_dict`` then gets its own 128-unit slice of that
    hidden vector and its own linear head ``fc1``, ``fc2``, ... sized to the
    number of values that entry requires.

    Args:
        param_dict: Mapping from layer name to the number of values to
            generate for it; iteration order fixes the head numbering.
        backbone: ``'vgg16'`` or ``'vgg19'``. Both use a 1920-wide input
            (presumably the concatenated per-channel mean and std of the
            tapped VGG layers — confirm against the caller).

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, param_dict, backbone='vgg16'):
        super(MetaNet, self).__init__()
        # An unsupported backbone used to leave `self.hidden` undefined and
        # only surfaced as an AttributeError inside forward().
        if backbone not in ('vgg16', 'vgg19'):
            raise ValueError(
                "backbone must be 'vgg16' or 'vgg19', got {!r}".format(backbone))
        self.param_num = len(param_dict)
        self.hidden = nn.Linear(1920, 128 * self.param_num)
        # Layer name -> position of its 128-unit slice / its fc head.
        self.fc_dict = {}
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(128, params))

    def forward(self, mean_std_features):
        """Generate the value vectors for every layer in ``fc_dict``.

        Args:
            mean_std_features: 2-D tensor of shape ``(batch, 1920)``.

        Returns:
            Dict mapping each layer name to a tensor of shape
            ``(batch, n_values_for_that_layer)``.
        """
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            # Each head reads only its own 128-wide slice of the hidden vector.
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return filters