# -*- coding: utf-8 -*-
"""
@File : comparison.py
@Author: csc
@Date : 2022/6/28
"""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from tensorboardX import SummaryWriter
import random
import shutil
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import *
from models import *

device = 'cpu'
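

# A Config bundles one trained experiment: the feature-extraction backbone,
# the transformation network, checkpoint paths for MetaNet and TransformNet,
# dataset locations, and the style-loss weight it was trained with.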
class Config:
    def __init__(self, name, backbone_name, transform_net, meta_path,
                 content_path, style_path, transform_path, style_weight):
        self.name = name
        self.backbone_name = backbone_name
        if backbone_name == 'vgg16':
            vgg16 = models.vgg16(pretrained=False)
            vgg16.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
            self.backbone = VGG(vgg16.features[:23]).to(device).eval()
        elif backbone_name == 'vgg19':
            vgg19 = models.vgg19(pretrained=False)
            vgg19.load_state_dict(torch.load('./models/vgg19-dcbb9e9d.pth'))
            self.backbone = VGG19(vgg19.features[:30]).to(device).eval()
        self.transform_net = transform_net
        self.meta_path = meta_path
        self.transform_path = transform_path
        self.content_path = content_path
        self.style_path = style_path
        self.style_weight = style_weight
        self.transformed_images = None


base = 32
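
# Experiment configurations to compare. All but the currently selected one are
# commented out; uncomment entries to add them to the comparison figure.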
configs = [
    # Config(name='vgg16_500',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/new_500_22ep/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/new_500_22ep/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=50),
    # Config(name='vgg16_1000',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/1000_22ep/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/1000_22ep/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_1000/',
    #        style_path='../WikiArt_1000/',
    #        style_weight=50),
    # Config(name='resnext_gram_500',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/resnext_gram_500_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/resnext_gram_500_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    # Config(name='resnext_gram_1000',
    #        backbone_name='vgg16',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/resnext_gram_1000_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/resnext_gram_1000_22ep_3e5/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_1000/',
    #        style_path='../WikiArt_1000/',
    #        style_weight=3e5),
    # Config(name='vgg19_resnext_gram_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/vgg19_resnext_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnext_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    Config(name='vgg19_resnext_gram_1000',
           backbone_name='vgg19',
           transform_net=TransformNet(base, residuals='resnext').to(device),
           meta_path='../weight/vgg19_resnext_gram_1000/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
           transform_path='../weight/vgg19_resnext_gram_1000/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
           content_path='../COCO2014_1000/',
           style_path='../WikiArt_1000/',
           style_weight=3e5),
    # Ablation
    # Config(name='vgg19_resnet_gram_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base).to(device),
    #        meta_path='../weight/vgg19_resnet_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnet_gram_500/metanet_base32_style300000.0_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=3e5),
    # Config(name='vgg19_resnext_mse_500',
    #        backbone_name='vgg19',
    #        transform_net=TransformNet(base, residuals='resnext').to(device),
    #        meta_path='../weight/vgg19_resnext_mse_500/metanet_base32_style50_tv1e-06_tagnohvd.pth',
    #        transform_path='../weight/vgg19_resnext_mse_500/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth',
    #        content_path='../COCO2014_500/',
    #        style_path='../WikiArt_500/',
    #        style_weight=50),
]

content_weight = 1
tv_weight = 1e-6
batch_size = 8

# Visualization
width = 256
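
# Random square crops resized to `width`, then ToTensor and normalization;
# tensor_normalizer is assumed to be the ImageNet normalization helper from utils.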
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(width, scale=(256 / 480, 1), ratio=(1, 1)),
    transforms.ToTensor(),
    tensor_normalizer
])
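

# MetaNet maps a style image's mean/std feature statistics to the filter weights
# of the transformation network, one fully connected head per parameter group.
# The 1920-dimensional input is assumed to be the concatenation of per-channel
# mean and std values computed by mean_std() over the backbone's feature maps.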
class MetaNet(nn.Module):
    def __init__(self, param_dict, backbone='vgg16'):
        super(MetaNet, self).__init__()
        self.param_num = len(param_dict)
        if backbone == 'vgg16':
            self.hidden = nn.Linear(1920, 128 * self.param_num)
        elif backbone == 'vgg19':
            self.hidden = nn.Linear(1920, 128 * self.param_num)
        self.fc_dict = {}
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(128, params))

    # ONNX requires the output to be a tensor or a list, not a dict
    def forward(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return list(filters.values())

    def forward2(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return filters


graph = [[] for i in range(3)]
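
# For each configuration: load the trained MetaNet / TransformNet checkpoints,
# fine-tune them for a few batches (n_batch) on the fixed test style image,
# then stylize a fixed set of content images for the comparison figure.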
for (index, config) in enumerate(configs):
    backbone_name = config.backbone_name
    backbone = config.backbone
    transform_net = config.transform_net
    meta_path = config.meta_path
    transform_path = config.transform_path
    content_path = config.content_path
    style_path = config.style_path
    style_weight = config.style_weight

    style_dataset = torchvision.datasets.ImageFolder(style_path, transform=data_transform)
    content_dataset = torchvision.datasets.ImageFolder(content_path, transform=data_transform)
    content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, shuffle=True)

    # style_img_name = '30925'
    # test_style_image = read_image(style_path + 'train/' + style_img_name + '.jpg', target_width=width).to(device)
    test_style_image = read_image('./images/pearl.jpg', target_width=width).to(device)
    style_features = backbone(test_style_image)
    style_mean_std = mean_std(style_features)

    metanet = MetaNet(transform_net.get_param_dict(), backbone=backbone_name).to(device)

    trainable_params = {}
    trainable_param_shapes = {}
    for model in [backbone, transform_net, metanet]:
        for name, param in model.named_parameters():
            if param.requires_grad:
                trainable_params[name] = param
                trainable_param_shapes[name] = param.shape
    optimizer = optim.Adam(trainable_params.values(), 1e-3)

    metanet.load_state_dict(torch.load(meta_path))
    transform_net.load_state_dict(torch.load(transform_path))

    n_batch = 20
    with tqdm(enumerate(content_data_loader), total=n_batch) as pbar:
        for batch, (content_images, _) in pbar:
            # Skip batches that contain constant (blank) images
            x = content_images.cpu().numpy()
            if (x.min(-1).min(-1) == x.max(-1).max(-1)).any():
                continue
            optimizer.zero_grad()

            # Generate the style-specific weights from the style image features
            weights = metanet.forward2(mean_std(style_features))
            transform_net.set_weights(weights, 0)

            # Use the style-specific model to predict the style-transferred images
            content_images = content_images.to(device)
            transformed_images = transform_net(content_images)

            # Compute features
            content_features = backbone(content_images)
            transformed_features = backbone(transformed_images)
            transformed_mean_std = mean_std(transformed_features)

            # content loss
            content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2])

            # style loss
            style_loss = style_weight * F.mse_loss(transformed_mean_std,
                                                   style_mean_std.expand_as(transformed_mean_std))

            # total variation loss
            y = transformed_images
            tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                                   torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            # Sum up the losses
            loss = content_loss + style_loss + tv_loss

            loss.backward()
            optimizer.step()

            if batch > n_batch:
                break

    content_img_path = content_path + 'train2014/'
    content_img_names = [
        content_img_path + 'COCO_train2014_000000081611.jpg',
        content_img_path + 'COCO_train2014_000000149739.jpg',
        content_img_path + 'COCO_train2014_000000505057.jpg',
        content_img_path + 'COCO_train2014_000000421773.jpg',
        './images/dancing.png',
        './images/boat.png',
        './images/ecnu.jpg',
        './images/text.jpg'
    ]
    test_content_images = torch.stack([read_image(name, target_width=width) for name in content_img_names]).to(device)
    content_images_vis = torch.cat([x for x in test_content_images], dim=-1)
    # config.transformed_images = transform_net(test_content_images)
    config.transformed_images = [transform_net(read_image(name, target_width=width).to(device)) for name in content_img_names]
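
# Figure layout: row 1 shows the style image, row 2 the content images, and one
# further row per configuration with its stylized outputs (the '310' base gives
# a 3-row grid, matching the single active configuration).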
fig = plt.figure(figsize=(50, 20))
size = 310
line_len = len(configs)

plt.subplot(size + 1)
imshow(test_style_image)

plt.subplot(size + line_len + 1)
imshow(content_images_vis)

for (index, config) in enumerate(configs):
    plt.subplot(size + line_len * (index + 2) + 1)
    plt.title(config.name)
    transformed_images_vis = torch.cat([x for x in config.transformed_images], dim=-1)
    imshow(transformed_images_vis)

fig.savefig('./images/out.png')