import os

import cv2 as cv
import numpy as np
from PIL import Image

import mindspore
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_version
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import Model, context, nn, Tensor, Parameter, load_checkpoint
from mindspore import numpy  # MindSpore's numpy-like tensor API (used in construct)
from mindspore import ops

# import moxing as mox  # ModelArts file transfer, only needed on Huawei cloud


def save_img(i, optimizer, output_path):
    """Dump the image currently being optimized to <output_path>/iter_<i>.jpg.

    Args:
        i: iteration number, embedded in the output file name.
        optimizer: MindSpore optimizer whose first parameter IS the image
            tensor being optimized, shape (1, 3, H, W), zero-centered by
            the same mean used in load_image.
        output_path: directory to write into (created if missing).
    """
    # BUGFIX: os.mkdir fails when parent directories are missing and on a
    # creation race; makedirs(..., exist_ok=True) handles both.
    os.makedirs(output_path, exist_ok=True)
    final_img = optimizer.parameters[0].asnumpy()
    final_img = final_img.squeeze(axis=0)        # (1, 3, H, W) -> (3, H, W)
    final_img = np.moveaxis(final_img, 0, 2)     # (3, H, W)   -> (H, W, 3)
    dump_img = np.copy(final_img)
    # Undo the mean-subtraction applied in load_image (ImageNet RGB mean).
    dump_img += np.array([123.675, 116.28, 103.53]).reshape((1, 1, 3))
    dump_img = np.clip(dump_img, 0, 255).astype('uint8')
    img_path = output_path + "/" + "iter_" + str(i) + ".jpg"
    # OpenCV expects BGR while the array is RGB: reverse the channel axis.
    cv.imwrite(img_path, dump_img[:, :, ::-1])
    # mox.file.copy_parallel(img_path, args.train_url+img_path[12:])


def create_dataset(img):
    """Wrap a numpy image in a single-column MindSpore dataset."""
    dataset = ds.NumpySlicesDataset(data=img, column_names=['data'])
    return dataset


def gram_matrix(x, should_normalize=True):
    """Compute batched Gram matrices of a feature map.

    Used to encode the style of content and style images as channel
    correlations.

    Args:
        x: feature tensor of shape (b, ch, h, w).
        should_normalize: if True, divide by ch*h*w so the loss scale
            is independent of the feature-map size.

    Returns:
        Tensor of shape (b, ch, ch).
    """
    b, ch, h, w = x.shape
    features = x.view(b, ch, w * h)              # flatten the spatial dims
    transpose = ops.Transpose()
    batmatmul = ops.BatchMatMul(transpose_a=False)
    features_t = transpose(features, (0, 2, 1))
    gram = batmatmul(features, features_t)       # (b, ch, ch) channel correlations
    if should_normalize:
        gram /= ch * h * w
    return gram


def load_image(img_path, target_shape=None):
    """Load an image and preprocess it into a (1, 3, H, W) float32 array.

    Pipeline: read (RGB) -> optional resize -> subtract ImageNet mean ->
    HWC->CHW conversion with rescale back to the 0..255 value range.

    Args:
        img_path: path to the image file.
        target_shape: None for no resize, an int for a target height
            (width keeps the aspect ratio), or an (H, W) pair.

    Raises:
        Exception: if img_path does not exist.
    """
    if not os.path.exists(img_path):
        raise Exception(f'Path not found: {img_path}')
    img = cv.imread(img_path)[:, :, ::-1]  # convert BGR to RGB when reading
    if target_shape is not None:
        if isinstance(target_shape, int) and target_shape != -1:
            # Scale to the given height, preserving the aspect ratio.
            current_height, current_width = img.shape[:2]
            new_height = target_shape
            new_width = int(current_width * (new_height / current_height))
            img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
        else:
            img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
    img = img.astype(np.float32)
    to_tensor = py_vision.ToTensor()  # HWC->CHW; scales values into [0, 1]
    normalize = c_version.Normalize(mean=[123.675, 116.28, 103.53], std=[1, 1, 1])
    img = normalize(img)               # e.g. (400, 533, 3), zero-centered
    # BUGFIX: the original multiplied by 225 — a typo for 255. ToTensor
    # divides by 255, so *255 restores the zero-centered value scale
    # (the torchvision pipeline this was ported from used x.mul(255)).
    img = to_tensor(img) * 255         # e.g. (3, 400, 533)
    img = np.expand_dims(img, axis=0)  # add batch dimension -> (1, 3, H, W)
    return img


class Optim_Loss(nn.Cell):
    """Total neural-style-transfer loss: content + style + total variation.

    The "parameter" being optimized is the image itself (target_maps[2]);
    the feature network `net` is frozen.
    """

    def __init__(self, net, target_maps):
        """
        Args:
            net: frozen feature extractor returning 6 feature maps
                (style layers at indices 0-3 and 5, the content layer
                conv4_2 at index 4).
            target_maps: [target content features, list of 5 target style
                Gram matrices, the image tensor to optimize].
        """
        super(Optim_Loss, self).__init__()
        self.net = net
        self.target_maps = target_maps[:-1]   # keep [content maps, style grams]
        # Loss weights: content, style, total-variation.
        self.weight = [100000.0, 30000.0, 1.0]
        self.get_style_loss = nn.MSELoss(reduction='sum')
        self.get_content_loss = nn.MSELoss(reduction='mean')
        self.cast = ops.Cast()                # cast the image to float32 each step
        self.ct = target_maps[2]              # the image being optimized

    def construct(self):
        """Return the (rescaled) total loss of the current image."""
        optimize_img = self.ct
        # 6 feature maps of the current image (index 4 is the content layer).
        current_maps = self.net(self.cast(optimize_img, mindspore.float32))
        current_content_maps = current_maps[4].squeeze(axis=0)  # conv4_2 content features
        # Style layers are compared via Gram matrices; the content layer
        # (index 4) stays as raw features.
        for i in range(len(current_maps)):
            if i != 4:
                current_maps[i] = gram_matrix(current_maps[i])
        target_content_maps = self.target_maps[0]   # target content features
        target_content_gram = self.target_maps[1]   # target style Gram matrices (5 of them)
        content_loss = self.get_content_loss(current_content_maps, target_content_maps)
        style_loss = 0
        # Style layers 0-3 match target grams 0-3; layer 5 matches gram 4.
        # Layer 4 is the content layer and is skipped.
        for j in range(6):
            if j < 4:
                style_loss += self.get_style_loss(current_maps[j], target_content_gram[j])
            elif j == 5:
                style_loss += self.get_style_loss(current_maps[j], target_content_gram[j - 1])
        style_loss /= 5
        # Total variation: L1 difference between neighbouring pixels,
        # horizontally and vertically — encourages a smooth image.
        tv_loss = numpy.sum(numpy.abs(optimize_img[:, :, :, :-1] - optimize_img[:, :, :, 1:])) \
            + numpy.sum(numpy.abs(optimize_img[:, :, :-1, :] - optimize_img[:, :, 1:, :]))
        total_loss = content_loss * self.weight[0] + style_loss * self.weight[1] + tv_loss * self.weight[2]
        # Constant rescale keeps the loss magnitude manageable for the optimizer.
        return total_loss / 130001


def load_parameters(file_name):
    """Load a checkpoint and remap parameter names for the current model.

    Drops optimizer moment entries and renames keys starting with
    "layers." to start with "l" (e.g. "layers.0.weight" -> "l0.weight").

    Args:
        file_name: path to the .ckpt file.

    Returns:
        dict mapping remapped parameter names to their values.
    """
    param_dict = load_checkpoint(file_name)
    param_dict_new = {}
    for key, values in param_dict.items():
        if key.startswith('moments.'):
            continue  # optimizer state, not model weights
        elif key.startswith("layers."):
            param_dict_new['l' + key[7:]] = values
        else:
            param_dict_new[key] = values
    return param_dict_new