You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
3.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# coding=utf-8
import os
# import moxing as mox
import numpy as np
import cv2 as cv
import mindspore
from mindspore import nn, Tensor, Parameter, context
from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
from vgg import Vgg19
from method import create_dataset, load_image, gram_matrix, Optim_Loss, save_img, load_parameters
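# Set to dynamic-graph (PyNative) mode. NOTE(review): nothing in this file currently does so — the only set_context call (commented out below) would select GRAPH_MODE, so MindSpore's default mode applies.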
import argparse
# parser = argparse.ArgumentParser(description='Training img transfer')
# parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'])
# parser.add_argument('--data_url', required=True, default=None)
# parser.add_argument('--train_url', required=True, default=None)
# # parser.add_argument("model_url", type=str, default=None, help="pretrained checkpoint file path of vgg.")
# args = parser.parse_known_args()[0]
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# if not os.path.exists("./img_data"):
# os.mkdir("./img_data")
# mox.file.copy_parallel(args.data_url, "./img_data")
#
# if not os.path.exists('./output_img'):
# os.mkdir("./output_img")
def _load_pretrained_vgg(ckpt_path):
    """Build a Vgg19 network, load weights from *ckpt_path* and freeze it for inference.

    Returns:
        (net, weights): the network with the loaded parameters, and the raw
        parameter dict returned by ``load_parameters``.
    """
    net = Vgg19()
    weights = load_parameters(ckpt_path)
    print("ckpt load success")
    load_param_into_net(net, weights)
    print("param load success")
    # Evaluation mode: the network is only used as a fixed feature extractor.
    net.set_train(False)
    return net, weights


# Module-level feature extractor shared by optimize_picture.
vgg, param_dict = _load_pretrained_vgg("vgg19.ckpt")
def load_img_ckpt(path):
    """Load a saved image checkpoint and return its parameter dict.

    The dict is whatever ``load_parameters`` returns for *path*; callers in
    this file read the optimized image from its ``"img"`` entry.
    """
    checkpoint = load_parameters(path)
    return checkpoint
def _load_start_image(content_path, mode):
    """Return the image the optimization starts from.

    Args:
        content_path: Content image file (mode 0) or an image checkpoint
            (any other mode).
        mode: 0 to read the image with ``load_image(content_path, 400)``;
            anything else to take the ``"img"`` entry of the checkpoint.
    """
    if mode == 0:
        return load_image(content_path, 400)
    return load_img_ckpt(content_path)["img"]


def _file_stem(path):
    """Return the file name of *path* without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def optimize_picture(content_path, style_path, iter_num, learning_rate, output_path, mode,
                     max_learning_rate=2.0):
    """Optimize an image so it keeps its content features and adopts a style.

    The image itself is a trainable ``Parameter``. Each ``TrainOneStepCell``
    step updates it with Adam to reduce ``Optim_Loss``.

    Args:
        content_path: Content image (mode 0) or image checkpoint to resume
            from (other modes). In both cases the starting image also
            supplies the content features.
        style_path: Style image file.
        iter_num: Number of steps after the first; the loop runs
            ``iter_num + 1`` times and the cosine schedule decays over
            ``iter_num`` steps.
        learning_rate: Final (minimum) learning rate of the cosine schedule.
        output_path: Directory for intermediate images (``save_img`` every
            50 iterations) and the final ``img_1.ckpt``; created if missing.
        mode: 0 = start from an image file, otherwise from a checkpoint.
        max_learning_rate: Initial (maximum) learning rate of the schedule.

    Raises:
        ValueError: if ``learning_rate >= max_learning_rate`` (the cosine
            schedule requires min < max).
    """
    if learning_rate >= max_learning_rate:
        raise ValueError(
            f"learning_rate ({learning_rate}) must be smaller than "
            f"max_learning_rate ({max_learning_rate})")
    os.makedirs(output_path, exist_ok=True)

    start_img = _load_start_image(content_path, mode)
    # The network yields several feature maps (6 layers, per the original
    # note); index 4 is used for content, the remaining ones for style.
    content_maps = vgg(Tensor(start_img))[4].squeeze(axis=0)
    # The image being optimized.
    optimizing_param = Parameter(Tensor(start_img, mindspore.float32), "optimizing_img")

    style_img = load_image(style_path, 400)
    style_maps = list(vgg(Tensor(style_img)))
    style_maps.pop(4)  # drop the map not used for style
    style_gram = [gram_matrix(x) for x in style_maps]

    loss_opt = Optim_Loss(vgg, [content_maps, style_gram, optimizing_param])
    # Cosine decay from max_learning_rate down to learning_rate over the run.
    # Previously hard-coded as (0.05, 2.0, 2000), ignoring learning_rate and
    # iter_num; for learning_rate=0.05, iter_num=2000 the schedule is unchanged.
    # MindSpore requires float rates and a positive integer step count.
    lr_schedule = nn.CosineDecayLR(float(learning_rate), float(max_learning_rate),
                                   max(int(iter_num), 1))
    optimizer_n = nn.Adam([optimizing_param], learning_rate=lr_schedule)
    train_net = nn.TrainOneStepCell(loss_opt, optimizer_n)
    train_net.set_train()

    # Log prefix from the file names; the old fixed slices
    # (content_path[-13:-4], style_path[-8:-4]) garbled most paths.
    label = _file_stem(content_path) + "_to_" + _file_stem(style_path)
    for i in range(iter_num + 1):
        if i % 50 == 0:
            # save_img gets the optimizer before this step runs
            # (presumably to read the current image — see method.save_img).
            save_img(i, optimizer_n, output_path)
        loss = train_net()
        print(label + ":iteration_", str(i), ": loss", loss)

    final = optimizer_n.parameters
    save_checkpoint([{"name": "img", "data": final[0]}],
                    os.path.join(output_path, "img_1.ckpt"))
# Resume style transfer from a previously saved image checkpoint (mode=1).
optimize_picture(
    content_path="output_img/iter_19/img.ckpt",
    style_path="transfer_img/trans_6/sty_fam_1.jpg",
    iter_num=2000,
    learning_rate=0.05,
    output_path="output_img/iter_19",
    mode=1,
)
# img = cv.imread("transfer_img/trans_5/style_img.jpg")[:, :, ::-1]
# print(img.shape)