# -*- coding: utf-8 -*-
"""
@File : comparison.py
@Author: csc
@Date : 2022/6/28
"""
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '5'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# matplotlib is needed for the comparison figure at the end of this script;
# imported explicitly here in case `from utils import *` does not provide plt
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter

import random
import shutil
from glob import glob
from tqdm import tqdm

from utils import *
from models import *
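
# note: everything below runs on the CPU; set device = 'cuda' to use the GPU selected above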
device = 'cpu'
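

# A Config bundles one comparison setting: the perceptual-loss backbone
# (VGG16 or VGG19), the transformation network variant, and the checkpoint
# and dataset paths that belong to it.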
class Config:
    def __init__(self, name, backbone_name, transform_net, meta_path, content_path, style_path, transform_path, style_weight):
        self.name = name
        self.backbone_name = backbone_name
        if backbone_name == 'vgg16':
            vgg16 = models.vgg16(pretrained=False)
            vgg16.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
            self.backbone = VGG(vgg16.features[:23]).to(device).eval()
        elif backbone_name == 'vgg19':
            vgg19 = models.vgg19(pretrained=False)
            vgg19.load_state_dict(torch.load('./models/vgg19-dcbb9e9d.pth'))
            self.backbone = VGG19(vgg19.features[:30]).to(device).eval()
        self.transform_net = transform_net
        self.meta_path = meta_path
        self.transform_path = transform_path
        self.content_path = content_path
        self.style_path = style_path
        self.style_weight = style_weight
        self.transformed_images = None


base = 32
configs = [
    # Config(name='vgg16_500',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/new_500_22ep/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/new_500_22ep/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=50),
    # Config(name='vgg16_1000',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/1000_22ep/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/1000_22ep/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_1000/',
    #        style_path='../WikiArt_1000/',
    #        style_weight=50),
    # Config(name='resnext_gram_500',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/resnext_gram_500_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/resnext_gram_500_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    # Config(name='resnext_gram_1000',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/resnext_gram_1000_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/resnext_gram_1000_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_1000/',
    #        style_path='../WikiArt_1000/',
    #        style_weight=3e5),
    # Config(name='vgg19_resnext_gram_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/vgg19_resnext_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnext_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    Config(name='vgg19_resnext_gram_1000',
           backbone_name='vgg19',
           transform_net=TransformNet(base, residuals='resnext').to(device),
           meta_path='../weight/vgg19_resnext_gram_1000/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
           transform_path='../weight/vgg19_resnext_gram_1000/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
           content_path='../COCO2014_1000/',
           style_path='../WikiArt_1000/',
           style_weight=3e5),
    # ablation study configurations
    # Config(name='vgg19_resnet_gram_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/vgg19_resnet_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnet_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    # Config(name='vgg19_resnext_mse_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/vgg19_resnext_mse_500/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnext_mse_500/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=50),
]

content_weight = 1
tv_weight = 1e-6
batch_size = 8

# visualization settings
width = 256
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(width, scale=(256 / 480, 1), ratio=(1, 1)),
    transforms.ToTensor(),
    tensor_normalizer
])
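

# MetaNet maps the style image's feature statistics to the weights of the
# transformation network: one shared hidden layer, then a separate fully
# connected head per parameter group of the transform net.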
class MetaNet(nn.Module):
    def __init__(self, param_dict, backbone='vgg16'):
        super(MetaNet, self).__init__()
        self.param_num = len(param_dict)
        # both backbones feed a 1920-dim mean/std feature vector into the hidden layer
        if backbone == 'vgg16':
            self.hidden = nn.Linear(1920, 128 * self.param_num)
        elif backbone == 'vgg19':
            self.hidden = nn.Linear(1920, 128 * self.param_num)
        self.fc_dict = {}
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(128, params))

    # ONNX requires the output to be a tensor or a list, not a dict
    def forward(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return list(filters.values())

    def forward2(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return filters


graph = [[] for i in range(3)]  # note: not used later in this script
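

# For each configuration: load its trained weights, briefly fine-tune on the
# test style image, then stylize a fixed set of content images for comparison.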
for (index, config) in enumerate(configs):
    backbone_name = config.backbone_name
    backbone = config.backbone
    transform_net = config.transform_net
    meta_path = config.meta_path
    transform_path = config.transform_path
    content_path = config.content_path
    style_path = config.style_path
    style_weight = config.style_weight

    style_dataset = torchvision.datasets.ImageFolder(style_path, transform=data_transform)
    content_dataset = torchvision.datasets.ImageFolder(content_path, transform=data_transform)

    content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, shuffle=True)
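
    # fixed test style image, shared by all configurations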
    # style_img_name = '30925'
    # test_style_image = read_image(style_path + 'train/' + style_img_name + '.jpg', target_width=width).to(device)
    test_style_image = read_image('./images/pearl.jpg', target_width=width).to(device)

    style_features = backbone(test_style_image)
    style_mean_std = mean_std(style_features)
    metanet = MetaNet(transform_net.get_param_dict(), backbone=backbone_name).to(device)

    trainable_params = {}
    trainable_param_shapes = {}
    for model in [backbone, transform_net, metanet]:
        for name, param in model.named_parameters():
            if param.requires_grad:
                trainable_params[name] = param
                trainable_param_shapes[name] = param.shape

    optimizer = optim.Adam(trainable_params.values(), 1e-3)
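
    # restore the trained MetaNet and TransformNet weights for this configuration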
    metanet.load_state_dict(torch.load(meta_path))
    transform_net.load_state_dict(torch.load(transform_path))

    n_batch = 20
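    # briefly fine-tune on this style for n_batch batches before rendering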
    with tqdm(enumerate(content_data_loader), total=n_batch) as pbar:
        for batch, (content_images, _) in pbar:
            x = content_images.cpu().numpy()
            # skip batches containing a solid-color image (min == max over all pixels)
            if (x.min(-1).min(-1) == x.max(-1).max(-1)).any():
                continue

            optimizer.zero_grad()

            # generate the style-specific weights from the style image
            weights = metanet.forward2(mean_std(style_features))
            transform_net.set_weights(weights, 0)

            # stylize the content images with the style-specific model
            content_images = content_images.to(device)
            transformed_images = transform_net(content_images)

            # compute backbone features
            content_features = backbone(content_images)
            transformed_features = backbone(transformed_images)
            transformed_mean_std = mean_std(transformed_features)

            # content loss
            content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2])

            # style loss
            style_loss = style_weight * F.mse_loss(transformed_mean_std,
                                                   style_mean_std.expand_as(transformed_mean_std))

            # total variation loss: absolute differences between neighboring pixels
            y = transformed_images
            tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                                   torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            # total loss
            loss = content_loss + style_loss + tv_loss

            loss.backward()
            optimizer.step()

            # stop after n_batch batches, matching the progress bar total
            if batch + 1 >= n_batch:
                break

    content_img_path = content_path + 'train2014/'
    content_img_names = [
        content_img_path + 'COCO_train2014_000000081611.jpg',
        content_img_path + 'COCO_train2014_000000149739.jpg',
        content_img_path + 'COCO_train2014_000000505057.jpg',
        content_img_path + 'COCO_train2014_000000421773.jpg',
        './images/dancing.png',
        './images/boat.png',
        './images/ecnu.jpg',
        './images/text.jpg'
    ]
    test_content_images = torch.stack([read_image(name, target_width=width) for name in content_img_names]).to(device)
    content_images_vis = torch.cat([x for x in test_content_images], dim=-1)

    # config.transformed_images = transform_net(test_content_images)
    config.transformed_images = [transform_net(read_image(name, target_width=width).to(device)) for name in content_img_names]
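

# Assemble the comparison figure: first row the style image, second row the
# content images side by side, then one row of stylized outputs per config.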
fig = plt.figure(figsize=(50, 20))
size = 310  # 3-row, 1-column subplot grid; sized for a single active config
line_len = len(configs)
plt.subplot(size + 1)
imshow(test_style_image)
plt.subplot(size + line_len + 1)
imshow(content_images_vis)
for (index, config) in enumerate(configs):
    plt.subplot(size + line_len * (index + 2) + 1)
    plt.title(config.name)
    transformed_images_vis = torch.cat([x for x in config.transformed_images], dim=-1)
    imshow(transformed_images_vis)

fig.savefig('./images/out.png')