"""Utility functions for MindSpore-based neural style transfer:
image loading/saving, Gram matrix computation, dataset creation and
checkpoint parameter loading."""
import numpy as np
import mindspore
# import moxing as mox
from mindspore import numpy
from PIL import Image
from mindspore import ops
import os
import cv2 as cv
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_version
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import Model, context, nn, Tensor, Parameter, load_checkpoint
def save_img(i, optimizer, output_path):
    """Dump the image being optimized to ``output_path/iter_<i>.jpg``.

    Args:
        i: iteration index, used in the output file name.
        optimizer: MindSpore optimizer whose ``parameters[0]`` holds the
            image tensor being optimized — assumed shape (1, 3, H, W),
            TODO confirm with caller.
        output_path: directory the JPEG is written into (created if missing).
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the race
    # between the original exists() check and os.mkdir().
    os.makedirs(output_path, exist_ok=True)
    final_img = optimizer.parameters[0].asnumpy()
    final_img = final_img.squeeze(axis=0)      # (1, 3, H, W) -> (3, H, W)
    final_img = np.moveaxis(final_img, 0, 2)   # (3, H, W) -> (H, W, 3)
    dump_img = np.copy(final_img)
    # Undo the mean subtraction applied in load_image (ImageNet RGB means).
    dump_img += np.array([123.675, 116.28, 103.53]).reshape((1, 1, 3))
    dump_img = np.clip(dump_img, 0, 255).astype('uint8')
    img_path = os.path.join(output_path, "iter_" + str(i) + ".jpg")
    cv.imwrite(img_path, dump_img[:, :, ::-1])  # RGB -> BGR for OpenCV
def create_dataset(img):
    """Wrap an image array in a single-column MindSpore dataset.

    Args:
        img: array-like image data.

    Returns:
        A NumpySlicesDataset with one column named 'data'.
    """
    return ds.NumpySlicesDataset(data=img, column_names=['data'])
def gram_matrix(x, should_normalize=True):
    """Compute the batched Gram matrix of a feature map.

    Encodes feature-channel correlations, used as the style representation.

    Args:
        x: feature tensor of shape (b, ch, h, w).
        should_normalize: if True, divide by ch * h * w.

    Returns:
        Tensor of shape (b, ch, ch).
    """
    batch, channels, height, width = x.shape
    # Flatten each channel's spatial dimensions into one axis.
    flat = x.view(batch, channels, height * width)
    flat_t = ops.Transpose()(flat, (0, 2, 1))
    # Gram = F @ F^T per batch element.
    gram = ops.BatchMatMul(transpose_a=False)(flat, flat_t)
    if should_normalize:
        gram /= channels * height * width
    return gram
def load_image(img_path, target_shape=None):
    """Load and preprocess an image for style transfer.

    The image is read as RGB, optionally resized, mean-subtracted with
    ImageNet RGB means, converted to CHW and scaled, then given a leading
    batch axis.

    Args:
        img_path: path to the image file.
        target_shape: None to keep the original size; an int (and not -1)
            to set the new height with width scaled to preserve aspect
            ratio; otherwise an (H, W) pair.

    Returns:
        float32 numpy array of shape (1, 3, H, W).

    Raises:
        Exception: if img_path does not exist.
    """
    if not os.path.exists(img_path):
        raise Exception(f'Path not found: {img_path}')
    img = cv.imread(img_path)[:, :, ::-1]  # convert BGR to RGB when reading
    if target_shape is not None:
        if isinstance(target_shape, int) and target_shape != -1:
            current_height, current_width = img.shape[:2]
            new_height = target_shape
            new_width = int(current_width * (new_height / current_height))
            img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
        else:
            img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
    img = img.astype(np.float32)
    to_tensor = py_vision.ToTensor()  # HWC -> CHW, scales values to [0, 1]
    normalize = c_version.Normalize(mean=[123.675, 116.28, 103.53], std=[1, 1, 1])
    img = normalize(img)  # subtract ImageNet means, still HWC float32
    # BUG FIX: original multiplied by 225 — a typo for 255. ToTensor scales
    # into [0, 1], and both the author's commented-out torch transform
    # (x.mul(255)) and the un-normalization in save_img (which adds the
    # means back with no 225 factor) show 255 is the intended scale.
    img = to_tensor(img) * 255
    img = np.expand_dims(img, axis=0)  # add batch axis -> (1, 3, H, W)
    return img
class Optim_Loss(nn.Cell):
    """Loss cell combining content, style and total-variation terms for
    optimization-based neural style transfer.

    The image being optimized is held in ``self.ct`` and passed through
    ``net`` on every ``construct()`` call; the content/style targets are
    fixed at construction time.
    """

    def __init__(self, net, target_maps):
        super(Optim_Loss, self).__init__()
        self.net = net
        # Drop the last entry; assumed layout (TODO confirm with caller):
        #   target_maps[0] -> target content feature map
        #   target_maps[1] -> target style Gram matrices
        #   target_maps[2] -> the image tensor being optimized
        self.target_maps = target_maps[:-1]
        # Loss weights: [content, style, total-variation].
        self.weight = [100000.0, 30000.0, 1.0]
        self.get_style_loss = nn.MSELoss(reduction='sum')
        self.get_content_loss = nn.MSELoss(reduction='mean')
        self.cast = ops.Cast()  # convert to mindspore float32 tensor
        self.ct = target_maps[2]

    def construct(self):
        """Return the weighted total loss for the current image."""
        optimize_img = self.ct
        current_maps = self.net(self.cast(optimize_img, mindspore.float32))  # 6 feature maps
        # Features of the current image.
        current_content_maps = current_maps[4].squeeze(axis=0)  # content features (the conv4_2 map)
        # Replace every map except index 4 (the content map) with its Gram matrix.
        for i in range(len(current_maps)):  # indices 0..5 -> layers 1, 2, 3, 4, 4_2, 5
            if i != 4:
                current_maps[i] = gram_matrix(current_maps[i])
        target_content_maps = self.target_maps[0]  # target content features
        target_content_gram = self.target_maps[1]  # target style Gram matrices
        content_loss = self.get_content_loss(current_content_maps, target_content_maps)
        style_loss = 0
        # Style loss over the 5 style layers: index 4 (content layer) is
        # skipped, and current index 5 pairs with target Gram index 4.
        for j in range(6):
            if j == 5:
                style_loss += self.get_style_loss(current_maps[j], target_content_gram[j-1])
            if j < 4:
                style_loss += self.get_style_loss(current_maps[j], target_content_gram[j])
        style_loss /= 5
        # Total variation: sum of absolute horizontal + vertical neighbor
        # differences, encouraging spatial smoothness of the image.
        tv_loss = numpy.sum(numpy.abs(optimize_img[:, :, :, :-1] - optimize_img[:, :, :, 1:])) \
            + numpy.sum(numpy.abs(optimize_img[:, :, :-1, :] - optimize_img[:, :, 1:, :]))
        total_loss = content_loss * self.weight[0] + style_loss * self.weight[1] + tv_loss * self.weight[2]
        # NOTE(review): 130001 looks like an ad-hoc scaling constant —
        # confirm its origin before changing.
        return total_loss/130001
def load_parameters(file_name):
    """Load a checkpoint and adapt its parameter names.

    Optimizer moment entries ("moments.*") are dropped, and keys starting
    with "layers." are shortened to "l...", e.g. "layers.0.weight" ->
    "l0.weight". All other entries are copied through unchanged.

    Args:
        file_name: path to the checkpoint file.

    Returns:
        dict mapping the adapted parameter names to their values.
    """
    renamed = {}
    for name, param in load_checkpoint(file_name).items():
        if name.startswith('moments.'):
            continue  # skip optimizer state
        if name.startswith('layers.'):
            renamed['l' + name[len('layers.'):]] = param
        else:
            renamed[name] = param
    return renamed