|
|
# coding=utf-8
|
|
|
import os
|
|
|
# import moxing as mox
|
|
|
import numpy as np
|
|
|
import cv2 as cv
|
|
|
import mindspore
|
|
|
from mindspore import nn, Tensor, Parameter, context
|
|
|
from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
|
|
|
from vgg import Vgg19
|
|
|
from method import create_dataset, load_image, gram_matrix, Optim_Loss, save_img, load_parameters
|
|
|
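# Execution-mode setup (original note: "set to dynamic-graph mode"). The commented-out
# block below would actually select GRAPH_MODE; while it stays disabled, no mode is set here.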
|
|
|
import argparse
|
|
|
# parser = argparse.ArgumentParser(description='Training img transfer')
|
|
|
# parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'])
|
|
|
# parser.add_argument('--data_url', required=True, default=None)
|
|
|
# parser.add_argument('--train_url', required=True, default=None)
|
|
|
# # parser.add_argument("model_url", type=str, default=None, help="pretrained checkpoint file path of vgg.")
|
|
|
# args = parser.parse_known_args()[0]
|
|
|
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
|
|
|
|
|
# if not os.path.exists("./img_data"):
|
|
|
# os.mkdir("./img_data")
|
|
|
# mox.file.copy_parallel(args.data_url, "./img_data")
|
|
|
#
|
|
|
# if not os.path.exists('./output_img'):
|
|
|
# os.mkdir("./output_img")
|
|
|
|
|
|
# Pretrained VGG19 network; optimize_picture uses this module-level instance as its
# feature extractor for both the content and the style image.
vgg = Vgg19()
# "vgg19.ckpt" is a relative path (presumably resolved against the current working
# directory by load_parameters -- confirm).
param_dict = load_parameters("vgg19.ckpt")
print("ckpt load success")
# NOTE(review): the return value of load_param_into_net is discarded, so the message
# below is printed even if some parameters did not load -- consider checking it.
load_param_into_net(vgg, param_dict)
print("param load success")
vgg.set_train(False)  # evaluation mode: the network weights are not the optimization target
|
|
|
|
|
|
|
|
|
def load_img_ckpt(path):
    """Load a checkpoint that stores a previously optimized image.

    This is a direct alias for ``load_parameters``. In this script the image is
    read back from the returned mapping's ``"img"`` entry.
    """
    checkpoint = load_parameters(path)
    return checkpoint
|
|
|
|
|
|
|
|
|
def _load_start_image(content_path, mode):
    """Return the image the optimization starts from.

    ``mode == 0`` reads a picture file with ``load_image`` at size 400. Any other
    mode resumes from a checkpoint written by a previous run and takes its
    ``"img"`` entry.
    """
    if mode == 0:
        return load_image(content_path, 400)
    return load_img_ckpt(content_path)["img"]


def _path_label(path):
    """Return the file name of *path* without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def optimize_picture(content_path, style_path, iter_num, learning_rate, output_path, mode,
                     max_learning_rate=2.0, decay_steps=2000):
    """Run style transfer by directly optimizing the pixels of an image.

    Args:
        content_path: Content source. When ``mode == 0`` this is a picture file.
            Otherwise it is a checkpoint whose ``"img"`` entry holds a
            previously optimized image, which resumes an earlier run.
        style_path: Picture file that provides the style.
        iter_num: The loop performs ``iter_num + 1`` training steps
            (i = 0 .. iter_num).
        learning_rate: Final (minimum) learning rate of the cosine-decay schedule.
        output_path: Directory for intermediate images (every 50 steps, via
            ``save_img``) and the final ``img_1.ckpt``. Created if missing.
        mode: 0 starts from a picture; any other value resumes from a checkpoint.
        max_learning_rate: Starting (maximum) learning rate of the schedule.
        decay_steps: Number of steps over which the rate decays from
            ``max_learning_rate`` down to ``learning_rate``.
    """
    start_img = _load_start_image(content_path, mode)

    # Content target: output index 4 of the VGG19 features, with axis 0 squeezed away.
    content_maps = vgg(Tensor(start_img))[4].squeeze(axis=0)
    # The image itself is the trainable parameter.
    optimizing_param = Parameter(Tensor(start_img, mindspore.float32), "optimizing_img")

    style_img = load_image(style_path, 400)
    style_maps = list(vgg(Tensor(style_img)))
    style_maps.pop(4)  # index 4 is the content layer; the remaining maps define the style
    style_gram = [gram_matrix(x) for x in style_maps]

    loss_opt = Optim_Loss(vgg, [content_maps, style_gram, optimizing_param])
    # CosineDecayLR(min_lr, max_lr, decay_steps) expects float rates.
    lr_schedule = nn.CosineDecayLR(float(learning_rate), float(max_learning_rate), decay_steps)
    optimizer_n = nn.Adam([optimizing_param], learning_rate=lr_schedule)
    train_net = nn.TrainOneStepCell(loss_opt, optimizer_n)
    train_net.set_train()

    os.makedirs(output_path, exist_ok=True)
    run_label = _path_label(content_path) + "_to_" + _path_label(style_path)
    for i in range(iter_num + 1):
        if i % 50 == 0:
            save_img(i, optimizer_n, output_path)
        loss = train_net()
        print(run_label + ":iteration_", str(i), ": loss", loss)

    final = optimizer_n.parameters
    save_checkpoint([{"name": "img", "data": final[0]}], os.path.join(output_path, "img_1.ckpt"))
|
|
|
|
|
|
|
|
|
# Resume (mode=1) from the image checkpoint saved in output_img/iter_19 and keep
# optimizing it toward the sty_fam_1 style for 2000 more iterations.
optimize_picture(
    content_path="output_img/iter_19/img.ckpt",
    style_path="transfer_img/trans_6/sty_fam_1.jpg",
    iter_num=2000,
    learning_rate=0.05,
    output_path="output_img/iter_19",
    mode=1,
)
|
|
|
|
|
|
# img = cv.imread("transfer_img/trans_5/style_img.jpg")[:, :, ::-1]
|
|
|
# print(img.shape)
|