# -*- coding: utf-8 -*-
"""
@File  : comparison.py
@Author: csc
@Date  : 2022/6/28
"""
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '5'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from tensorboardX import SummaryWriter
import random
import shutil
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt  # used for the comparison figure at the bottom

from utils import *
from models import *

device = 'cpu'


class Config:
    """One comparison setting: backbone, transform net, trained weights and data paths."""

    def __init__(self, name, backbone_name, transform_net, meta_path,
                 content_path, style_path, transform_path, style_weight):
        self.name = name
        self.backbone_name = backbone_name
        if backbone_name == 'vgg16':
            vgg16 = models.vgg16(pretrained=False)
            vgg16.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
            self.backbone = VGG(vgg16.features[:23]).to(device).eval()
        elif backbone_name == 'vgg19':
            vgg19 = models.vgg19(pretrained=False)
            vgg19.load_state_dict(torch.load('./models/vgg19-dcbb9e9d.pth'))
            self.backbone = VGG19(vgg19.features[:30]).to(device).eval()
        self.transform_net = transform_net
        self.meta_path = meta_path
        self.transform_path = transform_path
        self.content_path = content_path
        self.style_path = style_path
        self.style_weight = style_weight
        self.transformed_images = None


base = 32  # base channel width of the TransformNet
configs = [
    # Config(name='vgg16_500',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/new_500_22ep/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/new_500_22ep/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=50),
    # Config(name='vgg16_1000',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/1000_22ep/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/1000_22ep/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_1000/',
    #        style_path='../WikiArt_1000/',
    #        style_weight=50),
    # Config(name='resnext_gram_500',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/resnext_gram_500_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/resnext_gram_500_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    # Config(name='resnext_gram_1000',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/resnext_gram_1000_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/resnext_gram_1000_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_1000/',
    #        style_path='../WikiArt_1000/',
    #        style_weight=3e5),
    # Config(name='vgg19_resnext_gram_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/vgg19_resnext_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnext_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    Config(name='vgg19_resnext_gram_1000',
           backbone_name='vgg19',
           transform_net=TransformNet(base, residuals='resnext').to(device),
           meta_path='../weight/vgg19_resnext_gram_1000/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
           transform_path='../weight/vgg19_resnext_gram_1000/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
           content_path='../COCO2014_1000/',
           style_path='../WikiArt_1000/',
           style_weight=3e5),
    # ablation
    # Config(name='vgg19_resnet_gram_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/vgg19_resnet_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnet_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    # Config(name='vgg19_resnext_mse_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/vgg19_resnext_mse_500/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnext_mse_500/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=50),
]

content_weight = 1
tv_weight = 1e-6
batch_size = 8

# visualization
width = 256

data_transform = transforms.Compose([
    transforms.RandomResizedCrop(width, scale=(256 / 480, 1), ratio=(1, 1)),
    transforms.ToTensor(),
    tensor_normalizer
])


class MetaNet(nn.Module):
    def __init__(self, param_dict, backbone='vgg16'):
        super(MetaNet, self).__init__()
        self.param_num = len(param_dict)
        # the concatenated mean/std feature vector of the backbone is
        # 1920-dimensional for both the VGG16 and the VGG19 extractors used here
        if backbone == 'vgg16':
            self.hidden = nn.Linear(1920, 128 * self.param_num)
        elif backbone == 'vgg19':
            self.hidden = nn.Linear(1920, 128 * self.param_num)
        self.fc_dict = {}
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(128, params))

    # ONNX requires the output to be a tensor or a list, not a dict
    def forward(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return list(filters.values())

    def forward2(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return filters


graph = [[] for i in range(3)]

for (index, config) in enumerate(configs):
    backbone_name = config.backbone_name
    backbone = config.backbone
    transform_net = config.transform_net
    meta_path = config.meta_path
    transform_path = config.transform_path
    content_path = config.content_path
    style_path = config.style_path
    style_weight = config.style_weight

    style_dataset = torchvision.datasets.ImageFolder(style_path, transform=data_transform)
    content_dataset = torchvision.datasets.ImageFolder(content_path, transform=data_transform)
    content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, shuffle=True)

    # style_img_name = '30925'
    # test_style_image = read_image(style_path + 'train/' + style_img_name + '.jpg', target_width=width).to(device)
    test_style_image = read_image('./images/pearl.jpg', target_width=width).to(device)
    style_features = backbone(test_style_image)
    style_mean_std = mean_std(style_features)

    metanet = MetaNet(transform_net.get_param_dict(), backbone=backbone_name).to(device)

    trainable_params = {}
    trainable_param_shapes = {}
    for model in [backbone, transform_net, metanet]:
        for name, param in model.named_parameters():
            if param.requires_grad:
                trainable_params[name] = param
                trainable_param_shapes[name] = param.shape

    optimizer = optim.Adam(trainable_params.values(), 1e-3)

    metanet.load_state_dict(torch.load(meta_path))
    transform_net.load_state_dict(torch.load(transform_path))

    # fine-tune for a few batches on the test style before visualisation
    n_batch = 20
    with tqdm(enumerate(content_data_loader), total=n_batch) as pbar:
        for batch, (content_images, _) in pbar:
            # skip batches that contain constant (blank) images
            x = content_images.cpu().numpy()
            if (x.min(-1).min(-1) == x.max(-1).max(-1)).any():
                continue

            optimizer.zero_grad()

            # generate the style-specific weights from the style image
            weights = metanet.forward2(mean_std(style_features))
            transform_net.set_weights(weights, 0)

            # predict the style-transferred images with the style-specific model
            content_images = content_images.to(device)
            transformed_images = transform_net(content_images)

            # compute features
            content_features = backbone(content_images)
            transformed_features = backbone(transformed_images)
            transformed_mean_std = mean_std(transformed_features)

            # content loss
            content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2])

            # style loss
            style_loss = style_weight * F.mse_loss(transformed_mean_std,
                                                   style_mean_std.expand_as(transformed_mean_std))

            # total variation loss
            y = transformed_images
            tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                                   torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            # total loss
            loss = content_loss + style_loss + tv_loss

            loss.backward()
            optimizer.step()

            if batch > n_batch:
                break

    content_img_path = content_path + 'train2014/'
    content_img_names = [
        content_img_path + 'COCO_train2014_000000081611.jpg',
        content_img_path + 'COCO_train2014_000000149739.jpg',
        content_img_path + 'COCO_train2014_000000505057.jpg',
        content_img_path + 'COCO_train2014_000000421773.jpg',
        './images/dancing.png',
        './images/boat.png',
        './images/ecnu.jpg',
        './images/text.jpg'
    ]
    test_content_images = torch.stack(
        [read_image(name, target_width=width) for name in content_img_names]).to(device)
    content_images_vis = torch.cat([x for x in test_content_images], dim=-1)
    # config.transformed_images = transform_net(test_content_images)
    config.transformed_images = [transform_net(read_image(name, target_width=width).to(device))
                                 for name in content_img_names]

# comparison figure: the style image, the content images, and the stylised outputs of each config
fig = plt.figure(figsize=(50, 20))
size = 310
line_len = len(configs)
plt.subplot(size + 1)
imshow(test_style_image)
plt.subplot(size + line_len + 1)
imshow(content_images_vis)
for (index, config) in enumerate(configs):
    plt.subplot(size + line_len * (index + 2) + 1)
    plt.title(config.name)
    transformed_images_vis = torch.cat([x for x in config.transformed_images], dim=-1)
    imshow(transformed_images_vis)
fig.savefig('./images/out.png')
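
# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the comparison run above): a minimal helper
# that stylises a single content image with the weights loaded for the last
# config. It repeats the same inference path as the loop above: the MetaNet
# predicts the transform net's per-layer weights from the style image's
# mean/std features, the weights are loaded into the transform net, and the
# content image is pushed through it. The helper name, its arguments and the
# use of torchvision.utils.save_image are assumptions for illustration only.
# ---------------------------------------------------------------------------
def stylize_single(content_img_file, style_img_file, out_file):
    style_img = read_image(style_img_file, target_width=width).to(device)
    content_img = read_image(content_img_file, target_width=width).to(device)
    with torch.no_grad():
        # mean/std statistics of the style image's backbone feature maps
        style_ms = mean_std(backbone(style_img))
        # predict and load the style-specific weights of the transform net
        transform_net.set_weights(metanet.forward2(style_ms), 0)
        stylized = transform_net(content_img)
    # note: depending on tensor_normalizer, the output may need de-normalising
    # before it matches the colours shown by imshow above
    torchvision.utils.save_image(stylized, out_file)

# Example call (hypothetical output path):
# stylize_single('./images/boat.png', './images/pearl.jpg', './images/boat_stylized.png')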