# coding=utf-8
"""Image optimisation script built on VGG19 feature maps (MindSpore).

Running (or importing) this module loads VGG19 weights from ``vgg19.ckpt`` and
then launches one optimisation run via the ``optimize_picture`` call at the
bottom of the file.  The loss itself is implemented in ``method.Optim_Loss``.
"""
import argparse
import os

# import moxing as mox
import numpy as np
import cv2 as cv
import mindspore
from mindspore import nn, Tensor, Parameter, context
from mindspore import load_checkpoint, load_param_into_net, save_checkpoint

from vgg import Vgg19
from method import create_dataset, load_image, gram_matrix, Optim_Loss, save_img, load_parameters

# Execution-mode / cloud-storage setup, currently disabled.
# NOTE(review): the original note here read "set to dynamic-graph mode", but
# the disabled code below selects context.GRAPH_MODE -- confirm which mode is
# intended before re-enabling it.
# parser = argparse.ArgumentParser(description='Training img transfer')
# parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'])
# parser.add_argument('--data_url', required=True, default=None)
# parser.add_argument('--train_url', required=True, default=None)
# # parser.add_argument("model_url", type=str, default=None, help="pretrained checkpoint file path of vgg.")
# args = parser.parse_known_args()[0]
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# if not os.path.exists("./img_data"):
#     os.mkdir("./img_data")
# mox.file.copy_parallel(args.data_url, "./img_data")
#
# if not os.path.exists('./output_img'):
#     os.mkdir("./output_img")

# Position, within the sequence of feature maps returned by ``vgg``, of the map
# used as the content representation.  That map is removed from the style list,
# so every other map feeds a style Gram matrix.
_CONTENT_MAP_INDEX = 4

# Shared feature extractor: built once at module level and used as a global by
# ``optimize_picture``.
vgg = Vgg19()
param_dict = load_parameters("vgg19.ckpt")
print("ckpt load success")
load_param_into_net(vgg, param_dict)
print("param load success")
vgg.set_train(False)  # evaluation mode


def load_img_ckpt(path):
    """Return the parameters stored in the checkpoint at *path*.

    Thin wrapper around ``method.load_parameters``; ``optimize_picture`` reads
    the ``"img"`` entry of the returned mapping.
    """
    return load_parameters(path)


def optimize_picture(content_path, style_path, iter_num, learning_rate, output_path, mode):
    """Optimise an image against content and style feature targets.

    Args:
        content_path: Source of the starting image.  When ``mode == 0`` it is
            read with ``load_image``; otherwise it is treated as a checkpoint
            and its ``"img"`` entry is used (resuming a previous run).
        style_path: Path of the style image, read with ``load_image``.
        iter_num: The loop performs ``iter_num + 1`` training steps.
        learning_rate: Currently unused -- see the NOTE(review) at the
            learning-rate schedule below.
        output_path: Directory handed to ``save_img`` every 50 steps; the final
            image parameter is written to ``output_path + "/img_1.ckpt"``.
        mode: ``0`` to start from an image file, any other value to start from
            a checkpoint.

    Returns:
        None.  Results are written to disk and progress is printed.
    """
    # Starting image: a raw picture or the "img" entry of an earlier checkpoint.
    # The second argument of load_image is presumably the target image size --
    # confirm in method.load_image.
    if mode == 0:
        start_img = load_image(content_path, 400)
    else:
        start_img = load_img_ckpt(content_path)["img"]

    # Both modes derive the content features and the trainable parameter from
    # the starting image in exactly the same way.
    content_features = vgg(Tensor(start_img))
    optimizing_param = Parameter(Tensor(start_img, mindspore.float32), "optimizing_img")

    style_img = load_image(style_path, 400)

    # Feature-map extraction (the original comment says 6 layers).
    # Content target: the designated VGG19 map with its leading axis squeezed.
    content_target = content_features[_CONTENT_MAP_INDEX].squeeze(axis=0)
    # Style targets: Gram matrices of every remaining map of the style image.
    style_features = list(vgg(Tensor(style_img)))
    style_features.pop(_CONTENT_MAP_INDEX)  # drop the map that is not needed for style
    style_gram = [gram_matrix(feature) for feature in style_features]

    loss_net = Optim_Loss(vgg, [content_target, style_gram, optimizing_param])

    # NOTE(review): the schedule is hard-coded as CosineDecayLR(min_lr=0.05,
    # max_lr=2.0, decay_steps=2000); ``learning_rate`` and ``iter_num`` are not
    # consulted.  The sole call in this file passes 0.05 and 2000, which happen
    # to match.  Left unchanged on purpose -- confirm the intended mapping
    # before wiring the parameters in.
    lr_schedule = nn.CosineDecayLR(0.05, 2.0, 2000)
    optimizer = nn.Adam([optimizing_param], learning_rate=lr_schedule)
    # optimizer = nn.Adam([optimizing_param], learning_rate=learning_rate)

    train_net = nn.TrainOneStepCell(loss_net, optimizer)
    train_net.set_train()

    # The log tag uses fixed-width slices of both paths, so it is only
    # meaningful for file names of the lengths used in this project.
    log_prefix = f"{content_path[-13:-4]}_to_{style_path[-8:-4]}:iteration_"
    for step in range(iter_num + 1):
        if step % 50 == 0:
            # Periodic snapshot, taken before the step is applied.
            save_img(step, optimizer, output_path)
        loss = train_net()
        print(log_prefix, str(step), ": loss", loss)

    # Persist the optimised image (the optimizer's only parameter) for resuming.
    final_params = optimizer.parameters
    save_checkpoint([{"name": "img", "data": final_params[0]}], output_path+"/img_1.ckpt")


optimize_picture("output_img/iter_19/img.ckpt", "transfer_img/trans_6/sty_fam_1.jpg", 2000, 0.05, "output_img/iter_19", 1)
# img = cv.imread("transfer_img/trans_5/style_img.jpg")[:, :, ::-1]
# print(img.shape)