# -*- coding: utf-8 -*-
"""
@File  : models.py
@Author: csc
@Date  : 2022/6/23

Model definitions: frozen VGG feature extractors, a TransformNet whose
non-trainable convolutions (MyConv2D) receive their weights at runtime, and a
MetaNet that maps a feature vector to the flat weights of those convolutions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
from collections import defaultdict

from utils import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): this unconditional override forces the module onto the CPU even
# when CUDA is available. It looks like a debugging leftover, but other modules
# may import `device`, so it is kept as-is.
device = 'cpu'


def _collect_layer_outputs(features, layer_name_mapping, x):
    """Run ``x`` through every child of ``features`` in order.

    Returns a list with the outputs of the children whose names are keys of
    ``layer_name_mapping``, in forward order.
    """
    outs = []
    # _modules.items() (rather than named_children()) keeps the exact iteration
    # semantics of the original loop, including any repeated module instances.
    for name, module in features._modules.items():
        x = module(x)
        if name in layer_name_mapping:
            outs.append(x)
    return outs


class VGG(nn.Module):
    """Frozen feature extractor returning four intermediate activations.

    The indices in ``layer_name_mapping`` correspond to ReLU positions in
    torchvision's VGG-16 ``features`` (assumed to be what is passed in -- TODO
    confirm with the caller).
    """

    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3",
            '22': "relu4_3",
        }
        # Freeze every weight so no gradients are computed for the extractor.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return the activations at the mapped layers as a list."""
        return _collect_layer_outputs(self.features, self.layer_name_mapping, x)


class VGG19(nn.Module):
    """Frozen feature extractor returning four intermediate activations.

    The indices in ``layer_name_mapping`` correspond to ReLU positions in
    torchvision's VGG-19 ``features`` (assumed to be what is passed in -- TODO
    confirm with the caller).
    """

    def __init__(self, features):
        super(VGG19, self).__init__()
        self.features = features
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '17': "relu3_4",
            '26': "relu4_4",
        }
        # Freeze every weight so no gradients are computed for the extractor.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return the activations at the mapped layers as a list."""
        return _collect_layer_outputs(self.features, self.layer_name_mapping, x)


class MyConv2D(nn.Module):
    """2-D convolution whose weight and bias are plain tensors, not parameters.

    The tensors start as zeros and are overwritten from outside by
    ``TransformNet.set_my_attr``. Because they are neither parameters nor
    buffers, they do not appear in ``parameters()``/``state_dict()`` and are not
    moved by ``Module.to()``. No padding is applied; ``ConvLayer`` puts an
    explicit reflection pad in front of this layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
        super(MyConv2D, self).__init__()
        self.weight = torch.zeros((out_channels, in_channels, kernel_size, kernel_size)).to(device)
        self.bias = torch.zeros(out_channels).to(device)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.stride = (stride, stride)
        self.groups = groups

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, groups=self.groups)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)


class ResidualBlock(nn.Module):
    """Two 3x3 ConvLayer stages (the second without ReLU) plus an identity skip."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            *ConvLayer(channels, channels, kernel_size=3, stride=1),
            *ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False),
        )

    def forward(self, x):
        return self.conv(x) + x


class ResNeXtBlock(nn.Module):
    """ResNeXt-style bottleneck: 1x1 -> grouped 3x3 -> 1x1, each followed by BatchNorm.

    ``conv1`` and ``conv3`` are MyConv2D layers, so their weights are supplied
    externally. ``conv2`` is a regular trainable ``nn.Conv2d``. If ``stride != 1``
    or ``in_channels != out_channels``, the caller must pass an
    ``identity_downsample`` module that maps the input to the output shape;
    otherwise the residual addition fails on a shape mismatch.
    """

    def __init__(self, in_channels, out_channels, identity_downsample=None,
                 stride=1, cardinality=32, width_per_group=64):
        super(ResNeXtBlock, self).__init__()
        # Bottleneck width (channel count of the grouped convolution).
        width = int(out_channels * (width_per_group / 64)) * cardinality

        # 1x1 convolution: changes the channel count, keeps the spatial size.
        self.conv1 = MyConv2D(in_channels, width, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        # Grouped 3x3 convolution: stride=2 halves the spatial size, stride=1
        # keeps it. This is an nn.Conv2d because MyConv2D has no padding; a
        # 3x3 MyConv2D would shrink the feature map and break the residual add.
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=cardinality)
        self.bn2 = nn.BatchNorm2d(width)
        # 1x1 convolution back to out_channels, keeps the spatial size.
        self.conv3 = MyConv2D(width, out_channels, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        # Residual connection.
        x += identity
        x = self.relu(x)
        return x


def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1,
              upsample=None, instance_norm=True, relu=True, trainable=False):
    """Build the layer list for one convolution stage.

    The returned list is meant to be unpacked into ``nn.Sequential``. Order:
    optional nearest-neighbour upsampling by ``upsample``, reflection padding
    of ``kernel_size // 2``, the convolution (a trainable ``nn.Conv2d`` when
    ``trainable`` is true, otherwise a ``MyConv2D``), an optional
    ``InstanceNorm2d``, and an optional ``ReLU``.
    """
    layers = []
    if upsample:
        layers.append(nn.Upsample(mode='nearest', scale_factor=upsample))
    layers.append(nn.ReflectionPad2d(kernel_size // 2))
    if trainable:
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride))
    else:
        layers.append(MyConv2D(in_channels, out_channels, kernel_size, stride))
    if instance_norm:
        layers.append(nn.InstanceNorm2d(out_channels))
    if relu:
        layers.append(nn.ReLU())
    return layers


class TransformNet(nn.Module):
    """Image-to-image network built as downsampling -> residual stack -> upsampling.

    The first and last convolutions are trainable ``nn.Conv2d``. Every other
    convolution is a ``MyConv2D`` whose weights must be provided via
    ``set_weights``.

    Args:
        base: channel multiplier; the residual stage runs at ``base * 4``.
        residuals: ``'resnet'`` (five ResidualBlocks) or ``'resnext'``
            (a single ResNeXtBlock).

    Raises:
        ValueError: if ``residuals`` is not one of the supported values.
    """

    def __init__(self, base=8, residuals='resnet'):
        super(TransformNet, self).__init__()
        # Fail fast: previously an unknown value left self.residuals unset and
        # only surfaced as an AttributeError inside forward().
        if residuals not in ('resnet', 'resnext'):
            raise ValueError(
                "residuals must be 'resnet' or 'resnext', got %r" % (residuals,))
        self.base = base
        self.weights = []
        self.downsampling = nn.Sequential(
            *ConvLayer(3, base, kernel_size=9, trainable=True),
            *ConvLayer(base, base * 2, kernel_size=3, stride=2),
            *ConvLayer(base * 2, base * 4, kernel_size=3, stride=2),
        )
        if residuals == 'resnet':
            self.residuals = nn.Sequential(*[ResidualBlock(base * 4) for _ in range(5)])
        else:
            self.residuals = nn.Sequential(ResNeXtBlock(base * 4, base * 4))
        self.upsampling = nn.Sequential(
            *ConvLayer(base * 4, base * 2, kernel_size=3, upsample=2),
            *ConvLayer(base * 2, base, kernel_size=3, upsample=2),
            *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False, trainable=True),
        )

    def forward(self, X):
        y = self.downsampling(X)
        y = self.residuals(y)
        y = self.upsampling(y)
        return y

    def get_param_dict(self):
        """Find every MyConv2D layer in the network and count the weights it needs.

        Returns:
            A ``defaultdict(int)`` that maps each layer's dotted path (for
            example ``'residuals.0.conv.1'``) to ``weight.numel() + bias.numel()``,
            in module-traversal order.
        """
        param_dict = defaultdict(int)

        def dfs(module, name):
            for child_name, child in module.named_children():
                dfs(child, '%s.%s' % (name, child_name) if name != '' else child_name)
            if isinstance(module, MyConv2D):
                param_dict[name] += module.weight.numel()
                param_dict[name] += module.bias.numel()

        dfs(self, '')
        return param_dict

    def set_my_attr(self, name, value):
        """Assign a flat parameter vector to the MyConv2D layer at dotted path ``name``.

        ``value`` is expected to be a 1-D tensor holding the flattened weight
        followed by the bias, in the sizes counted by ``get_param_dict``. After
        the call the layer's ``weight`` and ``bias`` are views of ``value``, so
        they keep its autograd history.
        """
        # Walk a path such as 'residuals.0.conv.1'. Numeric components index
        # into containers (e.g. nn.Sequential); the rest are attribute names.
        target = self
        for part in name.split('.'):
            target = target[int(part)] if part.isnumeric() else getattr(target, part)

        # Split value into the weight and bias slices for that layer.
        n_weight = target.weight.numel()
        target.weight = value[:n_weight].view(target.weight.shape)
        target.bias = value[n_weight:].view(target.bias.shape)

    def set_weights(self, weights, i=0):
        """Set the weights of every MyConv2D layer from a dict of weights.

        ``weights`` maps layer paths to tensors that are indexed by ``i`` along
        their first dimension. Presumably this is the dict returned by
        ``MetaNet.forward``, with ``i`` selecting one sample of the batch --
        confirm with the caller.
        """
        for name, param in weights.items():
            self.set_my_attr(name, param[i])


class MetaNet(nn.Module):
    """Map a feature vector to the flat parameters of every target layer.

    ``param_dict`` maps layer names to parameter counts; presumably it is the
    output of ``TransformNet.get_param_dict()``. A shared hidden layer produces
    128 features per target layer, and each 128-wide slice feeds that layer's
    own ``fc<i>`` head.

    Raises:
        ValueError: if ``backbone`` is not ``'vgg16'`` or ``'vgg19'``.
    """

    def __init__(self, param_dict, backbone='vgg16'):
        super(MetaNet, self).__init__()
        # Input width per supported backbone. 1920 == 2 * (64 + 128 + 256 + 512),
        # presumably the per-channel mean and std of the four VGG activations
        # (identical for VGG-16 and VGG-19) -- confirm with the feature code.
        in_features_by_backbone = {'vgg16': 1920, 'vgg19': 1920}
        if backbone not in in_features_by_backbone:
            # Previously an unknown backbone left self.hidden unset and only
            # failed later, inside forward().
            raise ValueError(
                "backbone must be 'vgg16' or 'vgg19', got %r" % (backbone,))
        self.param_num = len(param_dict)
        self.hidden = nn.Linear(in_features_by_backbone[backbone], 128 * self.param_num)
        # Map each layer name to the index of its fc head (the head itself is
        # registered as fc{index + 1}).
        self.fc_dict = {}
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(128, params))

    def forward(self, mean_std_features):
        """Return ``{layer name: (batch, n_params) tensor}`` predicted from the features."""
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return filters