# -*- coding: utf-8 -*-
"""
@File : situation3_display.py
@Author: csc
@Date : 2022/6/24
"""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from tensorboardX import SummaryWriter
import random
import shutil
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import *   # expected to provide read_image, mean_std, tensor_normalizer, imshow
from models import *  # expected to provide VGG and TransformNet
device = 'cpu'
vgg16 = models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
vgg16 = VGG(vgg16.features[:23]).to(device).eval()
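# The first 23 layers of VGG16 (up to relu4_3) act as a fixed feature
# extractor; the VGG wrapper from models.py is expected to return the
# intermediate activations of four relu blocks rather than a single output.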
base = 32
transform_net = TransformNet(base).to(device)
class MetaNet(nn.Module):
    def __init__(self, param_dict):
        super(MetaNet, self).__init__()
        self.param_num = len(param_dict)
        # 1920-dim input: the concatenated mean/std vector of the style features
        self.hidden = nn.Linear(1920, 128 * self.param_num)
        self.fc_dict = {}
        # One fully connected head per parameter group of the transform net
        for i, (name, params) in enumerate(param_dict.items()):
            self.fc_dict[name] = i
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(128, params))

    # ONNX requires the output to be a tensor or a list; it cannot be a dict
    def forward(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return list(filters.values())

    # Same as forward, but returns the dict (convenient outside of ONNX export)
    def forward2(self, mean_std_features):
        hidden = F.relu(self.hidden(mean_std_features))
        filters = {}
        for name, i in self.fc_dict.items():
            fc = getattr(self, 'fc{}'.format(i + 1))
            filters[name] = fc(hidden[:, i * 128:(i + 1) * 128])
        return filters
metanet = MetaNet(transform_net.get_param_dict()).to(device)
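# A quick shape sanity check (a sketch; assumes mean_std from utils returns a
# 1x1920 vector for the four VGG feature maps):
# with torch.no_grad():
#     dummy_weights = metanet.forward2(torch.rand(1, 1920).to(device))
#     print({name: w.shape for name, w in dummy_weights.items()})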
# Visualization
width = 256
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(width, scale=(256 / 480, 1), ratio=(1, 1)),
    transforms.ToTensor(),
    tensor_normalizer
])
style_dataset = torchvision.datasets.ImageFolder('../WikiArt_500/', transform=data_transform)
content_dataset = torchvision.datasets.ImageFolder('../COCO2014_500/', transform=data_transform)
style_weight = 50
content_weight = 1
tv_weight = 1e-6
batch_size = 8
trainable_params = {}
trainable_param_shapes = {}
for model in [vgg16, transform_net, metanet]:
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params[name] = param
            trainable_param_shapes[name] = param.shape
optimizer = optim.Adam(trainable_params.values(), 1e-3)
content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, shuffle=True)
style_image = read_image('../WikiArt_500/train/29052.jpg', target_width=256).to(device) # 18993
# style_image = read_image('./images/mosaic.png', target_width=256).to(device)
style_features = vgg16(style_image)
style_mean_std = mean_std(style_features)
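# mean_std (from utils) is assumed to concatenate the per-channel mean and std
# of the four feature maps; with the standard VGG16 channel widths
# 64 + 128 + 256 + 512 = 960, that gives the 2 * 960 = 1920-dimensional vector
# expected by MetaNet's nn.Linear(1920, ...) above.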
# metanet.load_state_dict(torch.load('models/metanet_base32_style100000.0_tv1e-06_tagnohvd.pth'))
# transform_net.load_state_dict(torch.load('models/metanet_base32_style100000.0_tv1e-06_tagnohvd_transform_net.pth'))
metanet.load_state_dict(torch.load('../weight/500_22ep/metanet_base32_style50_tv1e-06_tagnohvd.pth'))
transform_net.load_state_dict(torch.load('../weight/500_22ep/metanet_base32_style50_tv1e-06_tagnohvd_transform_net.pth'))
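# The checkpoint filenames encode the training hyper-parameters
# (base=32, style_weight=50, tv_weight=1e-6), matching the settings above.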
n_batch = 20
with tqdm(enumerate(content_data_loader), total=n_batch) as pbar:
    for batch, (content_images, _) in pbar:
        # Skip batches containing constant (single-color) images
        x = content_images.cpu().numpy()
        if (x.min(-1).min(-1) == x.max(-1).max(-1)).any():
            continue

        optimizer.zero_grad()

        # Generate the style-specific weights from the style image
        weights = metanet.forward2(mean_std(style_features))
        transform_net.set_weights(weights, 0)

        # Use the style model to predict the style-transferred images
        content_images = content_images.to(device)
        transformed_images = transform_net(content_images)

        # Compute features with vgg16
        content_features = vgg16(content_images)
        transformed_features = vgg16(transformed_images)
        transformed_mean_std = mean_std(transformed_features)

        # Content loss
        content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2])

        # Style loss
        style_loss = style_weight * F.mse_loss(transformed_mean_std,
                                               style_mean_std.expand_as(transformed_mean_std))

        # Total variation loss: penalizes differences between neighboring pixels
        y = transformed_images
        tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                               torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

        # Sum the three losses
        loss = content_loss + style_loss + tv_loss

        loss.backward()
        optimizer.step()

        if batch > n_batch:
            break
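# The short loop above fine-tunes the loaded metanet and transform_net on the
# single chosen style image for roughly n_batch batches before visualization.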
# content_images = torch.stack([random.choice(content_dataset)[0] for i in range(4)]).to(device)
content_img_path = '../smallCOCO2014/train2014/'
style_img_name = '18993'
content_img_name = 'COCO_train2014_000000505057'
# content_images = read_image(content_img_path + content_img_name + '.jpg', target_width=width).to(device)
content_images = read_image('./images/text.jpg', target_width=width).to(device)
# while content_images.min() < -2:
# print('.', end=' ')
# content_images = torch.stack([random.choice(content_dataset)[0] for i in range(4)]).to(device)
transformed_images = transform_net(content_images)
transformed_images_vis = torch.cat([x for x in transformed_images], dim=-1)
content_images_vis = torch.cat([x for x in content_images], dim=-1)
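# Concatenating along dim=-1 (width) tiles the batch images side by side into
# one wide strip, so content and result can be compared row by row below.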
fig = plt.figure(figsize=(20, 12))
plt.subplot(3, 1, 1)
imshow(style_image)
plt.subplot(3, 1, 2)
imshow(content_images_vis)
plt.subplot(3, 1, 3)
imshow(transformed_images_vis)
fig.savefig('./images/out.png')