|
|
# -*- coding: utf-8 -*-
|
|
|
"""
@File : models.py
@Author: csc
@Date : 2022/6/23
"""
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
import torchvision
|
|
|
import torchvision.transforms as transforms
|
|
|
import torchvision.models as models
|
|
|
|
|
|
import numpy as np
|
|
|
from collections import defaultdict
|
|
|
|
|
|
from utils import *
|
|
|
|
|
|
# Default device for tensors created in this module: the GPU when CUDA is
# usable, otherwise the CPU.  (Assign 'cpu' here instead to force CPU use.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
class VGG(nn.Module):
    """Frozen feature extractor that returns selected intermediate activations.

    The children of ``features`` are applied one after another and the output
    of every child whose registered name is a key of ``layer_name_mapping`` is
    collected.  The mapping labels the keys '3', '8', '15' and '22' as
    relu1_2, relu2_2, relu3_3 and relu4_3; presumably these are the indices of
    torchvision's ``vgg16().features`` -- confirm against the caller.

    Args:
        features: container module (e.g. ``nn.Sequential``) whose children are
            run in registration order.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3",
            '22': "relu4_3",
        }
        # Freeze every weight: no gradients are computed for this network.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return the list of tapped activations, in the order they occur.

        Args:
            x: input passed to the first child of ``features``.

        Returns:
            list with one entry per child of ``features`` whose name is in
            ``layer_name_mapping`` (empty if none of the names is present).
        """
        outs = []
        pending = len(self.layer_name_mapping)
        # ``_modules`` is iterated directly (rather than ``named_children``)
        # so that a module instance registered under several names is still
        # applied every time it appears.
        for name, module in self.features._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                outs.append(x)
                pending -= 1
                if pending == 0:
                    # Every requested activation has been collected; the
                    # remaining layers would only produce discarded output.
                    break
        return outs
|
|
|
|
|
|
|
|
|
class VGG19(nn.Module):
    """Frozen feature extractor that returns selected intermediate activations.

    The children of ``features`` are applied one after another and the output
    of every child whose registered name is a key of ``layer_name_mapping`` is
    collected.  The mapping labels the keys '3', '8', '17' and '26' as
    relu1_2, relu2_2, relu3_4 and relu4_4; presumably these are the indices of
    torchvision's ``vgg19().features`` -- confirm against the caller.

    Args:
        features: container module (e.g. ``nn.Sequential``) whose children are
            run in registration order.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '17': "relu3_4",
            '26': "relu4_4",
        }
        # Freeze every weight: no gradients are computed for this network.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return the list of tapped activations, in the order they occur.

        Args:
            x: input passed to the first child of ``features``.

        Returns:
            list with one entry per child of ``features`` whose name is in
            ``layer_name_mapping`` (empty if none of the names is present).
        """
        outs = []
        pending = len(self.layer_name_mapping)
        # ``_modules`` is iterated directly (rather than ``named_children``)
        # so that a module instance registered under several names is still
        # applied every time it appears.
        for name, module in self.features._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                outs.append(x)
                pending -= 1
                if pending == 0:
                    # Every requested activation has been collected; the
                    # remaining layers would only produce discarded output.
                    break
        return outs
|
|
|
|
|
|
|
|
|
class MyConv2D(nn.Module):
    """2-D convolution whose weight and bias are plain tensors, not parameters.

    ``weight`` and ``bias`` are ordinary attributes initialised to zeros on the
    module-level ``device``; they are not registered as ``nn.Parameter`` and
    therefore do not show up in ``parameters()``.  They are meant to be
    overwritten from outside (see ``TransformNet.set_my_attr`` in this file).
    No padding is applied by this layer.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the square kernel.
        stride: stride used in both spatial directions.
        groups: number of convolution groups; ``in_channels`` must be
            divisible by it for ``F.conv2d`` to accept the weight.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
        super().__init__()
        # F.conv2d expects a weight of shape
        # (out_channels, in_channels // groups, kH, kW); using the full
        # in_channels would make every grouped convolution fail.
        self.weight = torch.zeros(
            (out_channels, in_channels // groups, kernel_size, kernel_size),
            device=device,
        )
        self.bias = torch.zeros(out_channels, device=device)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.stride = (stride, stride)
        self.groups = groups

    def forward(self, x):
        """Convolve ``x`` with the current ``weight`` and ``bias``."""
        return F.conv2d(x, self.weight, self.bias, self.stride, groups=self.groups)

    def extra_repr(self):
        """Describe the layer configuration for ``repr``/``print``."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        # Only mention groups when it differs from the default.
        if self.groups != 1:
            s += ', groups={groups}'
        return s.format(**self.__dict__)
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 ``ConvLayer`` stacks plus an identity shortcut.

    The first stack keeps its ReLU, the second is built with ``relu=False``.
    Both map ``channels`` to ``channels`` with stride 1.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        first_stage = ConvLayer(channels, channels, kernel_size=3, stride=1)
        second_stage = ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False)
        self.conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Return ``conv(x) + x``."""
        residual = self.conv(x)
        return residual + x
|
|
|
|
|
|
|
|
|
class ResNeXtBlock(nn.Module):
    """Bottleneck block: 1x1 conv, grouped 3x3 conv, 1x1 conv, plus shortcut.

    The two 1x1 convolutions are ``MyConv2D`` layers; the grouped 3x3
    convolution is a regular ``nn.Conv2d``.  Each convolution is followed by
    batch normalisation, and a single ReLU module is reused after the first
    two stages and after the shortcut addition.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        identity_downsample: optional module applied to the input before it
            is added back; ``None`` uses the input unchanged.
        stride: stride of the 3x3 convolution (kernel 3, padding 1, so
            stride 1 keeps the spatial size and stride 2 halves it).
        cardinality: number of groups of the 3x3 convolution.
        width_per_group: scales the inner width relative to a base of 64.
    """

    def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1, cardinality=32, width_per_group=64):
        super().__init__()
        # Channel count used inside the bottleneck.
        per_group = int(out_channels * (width_per_group / 64))
        width = per_group * cardinality

        # 1x1 convolution: leaves the spatial size unchanged.
        self.conv1 = MyConv2D(in_channels, width, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(width)

        # Grouped 3x3 convolution; the only stage that may change the size.
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=cardinality,
        )
        self.bn2 = nn.BatchNorm2d(width)

        # 1x1 convolution: leaves the spatial size unchanged.
        self.conv3 = MyConv2D(width, out_channels, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        """Run the bottleneck, add the shortcut and apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x
        if self.identity_downsample is not None:
            shortcut = self.identity_downsample(shortcut)

        # Residual connection.
        out += shortcut
        return self.relu(out)
|
|
|
|
|
|
|
|
|
def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1,
              upsample=None, instance_norm=True, relu=True, trainable=False):
    """Build the list of layers that make up one convolution stage.

    The list contains, in order: an optional nearest-neighbour ``nn.Upsample``,
    a ``nn.ReflectionPad2d`` of ``kernel_size // 2``, the convolution itself,
    an optional ``nn.InstanceNorm2d`` and an optional ``nn.ReLU``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel_size: kernel size of the convolution.
        stride: stride of the convolution.
        upsample: scale factor of the leading up-sampling layer; a falsy
            value omits the layer.
        instance_norm: whether to append instance normalisation.
        relu: whether to append a ReLU.
        trainable: use ``nn.Conv2d`` when true, ``MyConv2D`` otherwise.

    Returns:
        A plain ``list`` of modules (not an ``nn.Sequential``).
    """
    stack = [nn.Upsample(mode='nearest', scale_factor=upsample)] if upsample else []
    stack.append(nn.ReflectionPad2d(kernel_size // 2))

    conv_factory = nn.Conv2d if trainable else MyConv2D
    stack.append(conv_factory(in_channels, out_channels, kernel_size, stride))

    if instance_norm:
        stack += [nn.InstanceNorm2d(out_channels)]
    if relu:
        stack += [nn.ReLU()]
    return stack
|
|
|
|
|
|
|
|
|
class TransformNet(nn.Module):
    """Image transformation network: down-sampling, residual stage, up-sampling.

    The first and last convolutions are trainable ``nn.Conv2d`` layers; every
    other ``ConvLayer`` uses ``MyConv2D``, whose weights are supplied from
    outside through ``set_weights``.

    Args:
        base: channel count after the first convolution; later stages use
            ``base * 2`` and ``base * 4``.
        residuals: ``'resnet'`` for five ``ResidualBlock`` modules or
            ``'resnext'`` for a single ``ResNeXtBlock``.

    Raises:
        ValueError: if ``residuals`` is neither ``'resnet'`` nor ``'resnext'``.
    """

    def __init__(self, base=8, residuals='resnet'):
        super().__init__()
        # Fail early: an unknown value would otherwise leave ``self.residuals``
        # undefined and only surface as an AttributeError inside ``forward``.
        if residuals not in ('resnet', 'resnext'):
            raise ValueError(
                "residuals must be 'resnet' or 'resnext', got %r" % (residuals,)
            )
        self.base = base
        self.weights = []  # not used inside this class; kept for existing callers
        self.downsampling = nn.Sequential(
            *ConvLayer(3, base, kernel_size=9, trainable=True),
            *ConvLayer(base, base * 2, kernel_size=3, stride=2),
            *ConvLayer(base * 2, base * 4, kernel_size=3, stride=2),
        )
        if residuals == 'resnet':
            self.residuals = nn.Sequential(*[ResidualBlock(base * 4) for _ in range(5)])
        else:
            self.residuals = nn.Sequential(ResNeXtBlock(base * 4, base * 4))
        self.upsampling = nn.Sequential(
            *ConvLayer(base * 4, base * 2, kernel_size=3, upsample=2),
            *ConvLayer(base * 2, base, kernel_size=3, upsample=2),
            *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False, trainable=True),
        )

    def forward(self, X):
        """Apply down-sampling, the residual stage and up-sampling to ``X``."""
        y = self.downsampling(X)
        y = self.residuals(y)
        y = self.upsampling(y)
        return y

    def get_param_dict(self):
        """Find every MyConv2D layer and count the values it needs.

        Returns:
            ``defaultdict(int)`` mapping the dotted module path of each
            ``MyConv2D`` (e.g. ``'residuals.0.conv.1'``) to the number of
            elements in its weight plus its bias, in traversal order.
        """
        param_dict = defaultdict(int)

        def visit(module, prefix):
            # Children first, so names are built top-down while walking.
            for child_name, child in module.named_children():
                visit(child, '%s.%s' % (prefix, child_name) if prefix else child_name)
            if isinstance(module, MyConv2D):
                param_dict[prefix] += module.weight.numel() + module.bias.numel()

        visit(self, '')
        return param_dict

    def set_my_attr(self, name, value):
        """Overwrite the weight and bias of the layer at dotted path ``name``.

        Args:
            name: dotted module path such as ``'residuals.0.conv.1'``.
            value: flat tensor holding the weight values followed by the bias
                values; it is split and reshaped with ``view``.
        """
        # Walk the path step by step: numeric parts index into a container,
        # all other parts are attribute names.
        target = self
        for part in name.split('.'):
            target = target[int(part)] if part.isnumeric() else getattr(target, part)

        # Install the new tensors, keeping the existing shapes.
        n_weight = target.weight.numel()
        target.weight = value[:n_weight].view(target.weight.shape)
        target.bias = value[n_weight:].view(target.bias.shape)

    def set_weights(self, weights, i=0):
        """Set the weights of the MyConv2D layers named in ``weights``.

        Args:
            weights: dict mapping dotted layer paths to an indexable of flat
                value tensors (presumably the output of ``MetaNet.forward``,
                one row per sample -- confirm against the caller).
            i: index of the entry to install for every layer.
        """
        for name, param in weights.items():
            self.set_my_attr(name, param[i])
|
|
|
|
|
|
|
|
|
class MetaNet(nn.Module):
    """Map a feature vector to one flat value vector per named layer.

    A shared hidden layer produces ``128 * len(param_dict)`` activations; the
    i-th slice of 128 activations feeds its own linear head ``fc{i+1}``, whose
    output size is the value stored for that name in ``param_dict``.

    Args:
        param_dict: mapping from layer name to the number of values that
            layer needs (e.g. the result of ``TransformNet.get_param_dict``).
        backbone: ``'vgg16'`` or ``'vgg19'``.  Both use an input size of 1920.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    # Input size of the hidden layer (identical for both backbones) and the
    # number of hidden units reserved for each output head.
    _IN_FEATURES = 1920
    _GROUP_SIZE = 128

    def __init__(self, param_dict, backbone='vgg16'):
        super().__init__()
        # Fail early: an unknown backbone would otherwise leave ``self.hidden``
        # undefined and only surface as an AttributeError inside ``forward``.
        if backbone not in ('vgg16', 'vgg19'):
            raise ValueError(
                "backbone must be 'vgg16' or 'vgg19', got %r" % (backbone,)
            )
        self.param_num = len(param_dict)
        self.hidden = nn.Linear(self._IN_FEATURES, self._GROUP_SIZE * self.param_num)
        # Layer name -> index of its slice of the hidden activations.
        self.fc_dict = {}
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            # Heads are registered as fc1, fc2, ... (1-based attribute names).
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(self._GROUP_SIZE, params))

    def forward(self, mean_std_features):
        """Compute the value vectors for every layer in ``fc_dict``.

        Args:
            mean_std_features: 2-D input whose last dimension matches the
                hidden layer's input size (it is sliced as ``hidden[:, a:b]``).

        Returns:
            dict mapping each layer name to the output of its linear head.
        """
        hidden = F.relu(self.hidden(mean_std_features))
        size = self._GROUP_SIZE
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * size:(i + 1) * size])
        return filters
|