"""Stylize a content image with a pre-trained "Picasso" fast style-transfer net.

Loads ``TransformNet`` weights from ``./model/Picasso.pth``, runs the network
on ``./img/content.jpg``, writes the stylized result to
``./result/Picasso.jpg`` and then displays style / content / output side by
side.
"""
import os

import torch
from torchvision import utils as vutils

from models import *  # noqa: F401,F403 -- presumably provides TransformNet; confirm
from utils import *  # noqa: F401,F403 -- presumably provides read_image, imshow, plt; confirm

CORE = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(CORE)

RESULT_PATH = './result/Picasso.jpg'

transform_net = TransformNet(32).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
transform_net.load_state_dict(
    torch.load('./model/Picasso.pth', map_location=CORE))
# Inference only: switch normalization/dropout layers to evaluation behaviour.
transform_net.eval()

style_img = read_image('./style/style-Picasso.jpg').to(device)
content_img = read_image('./img/content.jpg').to(device)

# No autograd graph is needed for a single forward pass; this saves memory
# and yields a tensor that needs no .detach() before plotting/saving.
with torch.no_grad():
    output_img = transform_net(content_img)

# Save before showing the figure: plt.show() blocks, and the result should
# not be lost if the window is killed or the run is interrupted.
os.makedirs(os.path.dirname(RESULT_PATH), exist_ok=True)
vutils.save_image(output_img, RESULT_PATH)

plt.figure(figsize=(18, 6))

plt.subplot(1, 3, 1)
imshow(style_img, title='Style Image')

plt.subplot(1, 3, 2)
imshow(content_img, title='Content Image')

plt.subplot(1, 3, 3)
imshow(output_img, title='Output Image')

plt.show()