import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
from collections import defaultdict
from utils import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class VGG(nn.Module):
    """Frozen feature extractor returning selected intermediate activations.

    ``features`` is run child by child. The output of every child whose name
    is a key of ``layer_name_mapping`` is collected. The default keys
    ('3', '8', '15', '22') are labelled relu1_2 .. relu4_3, which presumably
    matches torchvision's VGG16 ``features`` indexing -- confirm with caller.
    """

    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features
        # child name -> human-readable label (labels are not used in forward)
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3",
            '22': "relu4_3",
        }
        # The extractor is never trained: freeze every parameter.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Run ``x`` through ``features`` and return the selected activations.

        Returns:
            list of tensors, one per key of ``layer_name_mapping`` found in
            ``features``, in the order those children appear.
        """
        outs = []
        n_wanted = len(self.layer_name_mapping)
        for name, module in self.features._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                outs.append(x)
                # All requested activations collected; later layers would
                # only waste compute without changing the result.
                if len(outs) == n_wanted:
                    break
        return outs


class MyConv2D(nn.Module):
    """2-D convolution whose weight and bias are plain tensors, not parameters.

    Both tensors start as zeros and are meant to be overwritten from outside
    (``TransformNet.set_weights`` assigns ``weight`` and ``bias``). Because
    they are not registered as parameters or buffers, they are not returned by
    ``parameters()`` and are not moved by ``Module.to``; they are created
    directly on the module-level ``device``.

    No padding is applied here; ``ConvLayer`` places a ``ReflectionPad2d``
    before this layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(MyConv2D, self).__init__()
        # Allocate on the target device directly instead of CPU-then-copy.
        self.weight = torch.zeros(
            (out_channels, in_channels, kernel_size, kernel_size), device=device)
        self.bias = torch.zeros(out_channels, device=device)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.stride = (stride, stride)

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)


class ResidualBlock(nn.Module):
    """Two 3x3 conv stages (the second without ReLU) plus an identity skip.

    Input and output both have ``channels`` channels.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            *ConvLayer(channels, channels, kernel_size=3, stride=1),
            *ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x):
        return self.conv(x) + x


def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1,
              upsample=None, instance_norm=True, relu=True, trainable=False):
    """Build the layer list for one conv stage.

    The list is, in order:
        [nn.Upsample(nearest, scale_factor=upsample)]  if ``upsample``
        nn.ReflectionPad2d(kernel_size // 2)
        nn.Conv2d  if ``trainable`` else MyConv2D (externally-set weights)
        [nn.InstanceNorm2d(out_channels)]               if ``instance_norm``
        [nn.ReLU()]                                     if ``relu``

    Returns:
        list of modules, intended to be unpacked into ``nn.Sequential``.
    """
    layers = []
    if upsample:
        layers.append(nn.Upsample(mode='nearest', scale_factor=upsample))
    layers.append(nn.ReflectionPad2d(kernel_size // 2))
    if trainable:
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride))
    else:
        layers.append(MyConv2D(in_channels, out_channels, kernel_size, stride))
    if instance_norm:
        layers.append(nn.InstanceNorm2d(out_channels))
    if relu:
        layers.append(nn.ReLU())
    return layers


class TransformNet(nn.Module):
    """Image transformation network: downsampling -> 5 residual blocks -> upsampling.

    Only the first and last convolutions (kernel 9) are trainable
    ``nn.Conv2d``. Every other convolution is a ``MyConv2D`` whose weights
    must be supplied through ``set_weights``. ``base`` is the channel width of
    the first stage; it doubles at each downsampling stage. Output has 3
    channels.
    """

    def __init__(self, base=8):
        super(TransformNet, self).__init__()
        self.base = base
        # Not used inside this class; kept for any external code that reads it.
        self.weights = []
        self.downsampling = nn.Sequential(
            *ConvLayer(3, base, kernel_size=9, trainable=True),
            *ConvLayer(base, base * 2, kernel_size=3, stride=2),
            *ConvLayer(base * 2, base * 4, kernel_size=3, stride=2),
        )
        self.residuals = nn.Sequential(*[ResidualBlock(base * 4) for i in range(5)])
        self.upsampling = nn.Sequential(
            *ConvLayer(base * 4, base * 2, kernel_size=3, upsample=2),
            *ConvLayer(base * 2, base, kernel_size=3, upsample=2),
            *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False,
                       trainable=True),
        )

    def forward(self, X):
        y = self.downsampling(X)
        y = self.residuals(y)
        y = self.upsampling(y)
        return y

    def get_param_dict(self):
        """Find every MyConv2D layer and count the weights each one needs.

        Returns:
            defaultdict(int) mapping the layer's dotted module path (e.g.
            'residuals.0.conv.1') to the number of elements in its weight
            plus bias.
        """
        param_dict = defaultdict(int)

        def dfs(module, name):
            for name2, layer in module.named_children():
                dfs(layer, '%s.%s' % (name, name2) if name != '' else name2)
            if isinstance(module, MyConv2D):
                param_dict[name] += module.weight.numel()
                param_dict[name] += module.bias.numel()

        dfs(self, '')
        return param_dict

    def set_my_attr(self, name, value):
        """Overwrite the weight and bias of the MyConv2D at dotted path ``name``.

        ``value`` is a 1-D tensor: its first ``weight.numel()`` elements are
        reshaped into the weight and the rest into the bias. The new tensors
        are views of ``value``, so gradients flow back to whatever produced it.
        """
        # Walk a path like 'residuals.0.conv.1' one component at a time:
        # numeric parts index into a Sequential, others are attribute names.
        target = self
        for x in name.split('.'):
            if x.isnumeric():
                target = target[int(x)]
            else:
                target = getattr(target, x)

        # Assign the corresponding weights.
        n_weight = target.weight.numel()
        target.weight = value[:n_weight].view(target.weight.shape)
        target.bias = value[n_weight:].view(target.bias.shape)

    def set_weights(self, weights, i=0):
        """Set the weights of every MyConv2D layer from a weight dict.

        Args:
            weights: mapping of dotted layer path -> batched flat parameter
                tensor (e.g. the output of ``MetaNet.forward``).
            i: which row (batch index) of each tensor to use.
        """
        for name, param in weights.items():
            self.set_my_attr(name, param[i])


class MetaNet(nn.Module):
    """Hypernetwork that predicts flat MyConv2D parameter vectors.

    Args:
        param_dict: mapping of layer path -> parameter count, such as the
            result of ``TransformNet.get_param_dict``.

    One shared Linear layer maps the input to ``128 * len(param_dict)``
    hidden units. Each consecutive 128-wide slice feeds its own Linear head
    (``fc1``, ``fc2``, ...) that outputs the parameters for one layer.
    """

    def __init__(self, param_dict):
        super(MetaNet, self).__init__()
        self.param_num = len(param_dict)
        # 1920 is a hard-coded input width -- presumably the concatenated
        # per-channel mean/std of the VGG feature maps; confirm with caller.
        self.hidden = nn.Linear(1920, 128 * self.param_num)
        # Plain dict: layer path -> head index. The heads themselves are
        # registered as submodules via setattr under 'fcN' names, since dotted
        # paths such as 'residuals.0.conv.1' are not valid submodule names.
        self.fc_dict = {}
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(128, params))

    def forward(self, mean_std_features):
        """Predict parameters for every layer.

        Assumes ``mean_std_features`` is (batch, 1920) -- TODO confirm.

        Returns:
            dict mapping layer path -> (batch, n_params) tensor.
        """
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return filters