# -*- coding: utf-8 -*-
"""
@File  : situation3.py
@Author: csc
@Date  : 2022/6/24
"""
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '4'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
from PIL import Image
import matplotlib.pyplot as plt

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import shutil
from glob import glob

# from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import multiprocessing

import copy
from tqdm import tqdm
from collections import defaultdict

# import horovod.torch as hvd
import torch.utils.data.distributed

from utils import *
from models import *
import time

from pprint import pprint

display = pprint

# hvd.init()
# torch.cuda.set_device(hvd.local_rank())
# device = torch.device("cuda:%s" % hvd.local_rank() if torch.cuda.is_available() else "cpu")
device = 'cpu'
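# Horovod-based distributed training is commented out above; this run is single-process on CPU.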
class ModelConfig:
    vgg19 = True
    resnext = True
    gram = True


is_hvd = False
tag = 'nohvd'
base = 32
if ModelConfig.gram:
    style_weight = 3e5
else:
    style_weight = 50
content_weight = 1
tv_weight = 1e-6
epochs = 22

batch_size = 8
width = 256
verbose_hist_batch = 40  # 100
verbose_image_batch = 40  # 800
model_name = f'metanet_base{base}_style{style_weight}_tv{tv_weight}_tag{tag}'
# print(f'model_name: {model_name}, rank: {hvd.rank()}')
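# Note: raw Gram-matrix MSE values are orders of magnitude smaller than mean/std MSE values,
# which is presumably why style_weight is 3e5 with Gram matrices but 50 otherwise.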
def rmrf(path):
    try:
        shutil.rmtree(path)
    except OSError:
        pass


rmrf('runs/' + model_name)
# feature extractor: VGG16 uses features[:23] (through relu4_3), VGG19 uses features[:27] (through relu4_4)
if ModelConfig.vgg19:
    backbone = models.vgg19(pretrained=False)
    backbone.load_state_dict(torch.load('./models/vgg19-dcbb9e9d.pth'))
    backbone = VGG19(backbone.features[:27]).to(device).eval()
else:
    backbone = models.vgg16(pretrained=False)
    backbone.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
    backbone = VGG(backbone.features[:23].to(device)).eval()
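# TransformNet (from models) is the image-transformation network; its convolutional
# weights are not learned directly but predicted per style by MetaNet below.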
if ModelConfig.resnext:
    transform_net = TransformNet(base, residuals='resnext').to(device)
else:
    transform_net = TransformNet(base).to(device)
transform_net.get_param_dict()

metanet = MetaNet(transform_net.get_param_dict(),
                  backbone=('vgg19' if ModelConfig.vgg19 else 'vgg16')).to(device)
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(width, scale=(256 / 480, 1), ratio=(1, 1)),
    transforms.ToTensor(),
    tensor_normalizer
])
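# tensor_normalizer comes from utils; presumably ImageNet mean/std normalization,
# matching the statistics the pretrained VGG backbone expects.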
style_dataset = torchvision.datasets.ImageFolder('../WikiArt_1000/', transform=data_transform)
content_dataset = torchvision.datasets.ImageFolder('../COCO2014_1000/', transform=data_transform)

content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size,
                                                  shuffle=True, num_workers=0)

print(style_dataset)
print('-' * 20)
print(content_dataset)
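# Shape sanity check: push random noise through backbone -> MetaNet -> TransformNet
# before training to verify that the predicted weights actually fit the network.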
metanet.eval()
transform_net.eval()

rands = torch.rand(8, 3, 256, 256).to(device)
features = backbone(rands)
weights = metanet(mean_std(features))
transform_net.set_weights(weights)
transformed_images = transform_net(torch.rand(8, 3, 256, 256).to(device))

print('features:')
display([x.shape for x in features])
print('weights:')
display([x.shape for x in weights.values()])
print('transformed_images:')
display(transformed_images.shape)
rmrf('runs/' + model_name)
writer = SummaryWriter('runs/' + model_name)

# visualization_style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device)
visualization_content_images = torch.stack([random.choice(content_dataset)[0] for i in range(4)]).to(device)
# writer.add_images('content_image', recover_tensor(visualization_content_images), 0)
# writer.add_graph(transform_net, (rands, ))

del rands, features, weights, transformed_images
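# Collect every parameter that still requires grad across the three networks.
# This assumes the VGG wrapper freezes its own weights; otherwise the frozen
# backbone would be optimized together with MetaNet and TransformNet.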
trainable_params = {}
trainable_param_shapes = {}
for model in [backbone, transform_net, metanet]:
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params[name] = param
            trainable_param_shapes[name] = param.shape

# start training
optimizer = optim.Adam(trainable_params.values(), 1e-3)
n_batch = len(content_data_loader)
metanet.train()
transform_net.train()

# make sure the checkpoint directory exists before the first save
os.makedirs('checkpoints', exist_ok=True)
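# Training loop. Every 20 batches a new style image is sampled; MetaNet maps its
# feature statistics to a fresh set of TransformNet weights, which are trained
# jointly with MetaNet on content batches. Smooth (from utils) presumably keeps a
# running average, so the values shown in the progress bar are smoothed losses.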
for epoch in range(epochs):
    smoother = defaultdict(Smooth)
    with tqdm(enumerate(content_data_loader), total=n_batch) as pbar:
        for batch, (content_images, _) in pbar:
            # size of the current batch
            size = content_images.size()[0]
            n_iter = epoch * n_batch + batch

            # every 20 batches, pick a new random style image and compute its features
            if batch % 20 == 0:
                style_image = random.choice(style_dataset)[0]
                style_image_tensor = style_image.unsqueeze(0).to(device)
                style_features = backbone(style_image_tensor)
                style_mean_std = mean_std(style_features)
                # Gram matrices of the style image, repeated to batch size
                style_grams = [gram_matrix(x) for x in
                               backbone(torch.stack((style_image,) * batch_size).to(device))]

            # the last batch may be smaller than batch_size; recompute the Gram matrices accordingly
            if size != batch_size:
                style_grams = [gram_matrix(x) for x in
                               backbone(torch.stack((style_image,) * size).to(device))]

            # skip batches that contain solid-color images
            x = content_images.cpu().numpy()
            if (x.min(-1).min(-1) == x.max(-1).max(-1)).any():
                continue

            optimizer.zero_grad()
            # generate the style-specific weights from the style image
            weights = metanet(mean_std(style_features))
            transform_net.set_weights(weights, 0)

            # run the content images through the style-transfer network
            content_images = content_images.to(device)
            transformed_images = transform_net(content_images)

            # compute features with the VGG backbone
            content_features = backbone(content_images)
            transformed_features = backbone(transformed_images)
            transformed_mean_std = mean_std(transformed_features)

            # content loss
            content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2])

            # style loss
            if ModelConfig.gram:
                style_loss = 0
                transformed_grams = [gram_matrix(x) for x in transformed_features]
                for a, b in zip(transformed_grams, style_grams):
                    style_loss += F.mse_loss(a, b) * style_weight
                style_loss /= size
            else:
                style_loss = style_weight * F.mse_loss(transformed_mean_std,
                                                       style_mean_std.expand_as(transformed_mean_std))
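            # Total variation regularization: L1 differences between horizontally and
            # vertically adjacent pixels, encouraging spatially smooth outputs.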
            # total variation loss
            y = transformed_images
            tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                                   torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            # total loss
            loss = content_loss + style_loss + tv_loss
            loss.backward()
            optimizer.step()

            smoother['content_loss'] += content_loss.item()
            smoother['style_loss'] += style_loss.item()
            smoother['tv_loss'] += tv_loss.item()
            smoother['loss'] += loss.item()

            max_value = max([x.max().item() for x in weights.values()])
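            # log the losses and the largest MetaNet-predicted weight;
            # a sudden jump in 'loss/max' can signal divergence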
            writer.add_scalar('loss/loss', loss, n_iter)
            writer.add_scalar('loss/content_loss', content_loss, n_iter)
            writer.add_scalar('loss/style_loss', style_loss, n_iter)
            writer.add_scalar('loss/total_variation', tv_loss, n_iter)
            writer.add_scalar('loss/max', max_value, n_iter)

            s = 'Epoch: {} '.format(epoch + 1)
            s += 'Content: {:.2f} '.format(smoother['content_loss'])
            s += 'Style: {:.2f} '.format(smoother['style_loss'])
            s += 'TV: {:.2f} '.format(smoother['tv_loss'])
            s += 'Loss: {:.2f} '.format(smoother['loss'])
            s += 'Max: {:.2f}'.format(max_value)
            # if (batch + 1) % verbose_image_batch == 0:
            #     transform_net.eval()
            #     visualization_transformed_images = transform_net(visualization_content_images)
            #     transform_net.train()
            #     visualization_transformed_images = torch.cat([style_image, visualization_transformed_images])
            #     writer.add_images('debug', recover_tensor(visualization_transformed_images), n_iter)
            #     del visualization_transformed_images

            if (batch + 1) % verbose_hist_batch == 0:
                for name, param in weights.items():
                    writer.add_histogram('transform_net.' + name, param.clone().cpu().data.numpy(),
                                         n_iter, bins='auto')

                for name, param in transform_net.named_parameters():
                    writer.add_histogram('transform_net.' + name, param.clone().cpu().data.numpy(),
                                         n_iter, bins='auto')

                for name, param in metanet.named_parameters():
                    # group histograms by module: drop the trailing 'weight'/'bias' from the tag
                    tag = '.'.join(name.split('.')[:-1])
                    writer.add_histogram('metanet.' + tag, param.clone().cpu().data.numpy(),
                                         n_iter, bins='auto')
            pbar.set_description(s)

            del transformed_images, weights

    torch.save(metanet.state_dict(), 'checkpoints/{}_{}.pth'.format(model_name, epoch + 1))
    torch.save(transform_net.state_dict(),
               'checkpoints/{}_transform_net_{}.pth'.format(model_name, epoch + 1))
torch.save(metanet.state_dict(), 'models/{}.pth'.format(model_name))
torch.save(transform_net.state_dict(), 'models/{}_transform_net.pth'.format(model_name))