|
|
import os

from torchvision import utils as vutils

from models import *
from utils import *
|
|
|
|
|
|
# Run on the first GPU when available, otherwise fall back to the CPU.
CORE = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(CORE)

# Build the transform network and load the pretrained Picasso weights.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
transform_net = TransformNet(32).to(device)
transform_net.load_state_dict(
    torch.load('./model/Picasso.pth', map_location=CORE))
# This script only does inference: put the network in evaluation mode so
# that any dropout / batch-norm style layers (if TransformNet has them) use
# their inference behaviour instead of their training behaviour.
transform_net.eval()

# Load the style image (only displayed below) and the content image to stylize.
style_img = read_image('./style/style-Picasso.jpg').to(device)
content_img = read_image('./img/content.jpg').to(device)

# Gradients are never needed here; skipping autograd bookkeeping avoids
# building a graph for the forward pass and reduces memory use.
with torch.no_grad():
    output_img = transform_net(content_img)
|
|
|
|
|
|
# Lay out the style image, the content image and the stylized result as
# three panels in a single row of one wide figure.
plt.figure(figsize=(18, 6))
for _panel_idx, (_panel_img, _panel_title) in enumerate(
        (
            (style_img, 'Style Image'),
            (content_img, 'Content Image'),
            # Detach so the displayed tensor carries no autograd history.
            (output_img.detach(), 'Output Image'),
        ),
        start=1):
    plt.subplot(1, 3, _panel_idx)
    imshow(_panel_img, title=_panel_title)
plt.show()
|
|
|
|
|
|
# def save_image_tensor2cv2(input_tensor: torch.Tensor, filename):
# """
# Save a tensor as an image file using cv2.
# :param input_tensor: the tensor to save (expected shape (1, C, H, W))
# :param filename: the output file name
# """
# assert (len(input_tensor.shape) == 4 and input_tensor.shape[0] == 1)
# # Make a copy so the caller's tensor is not modified
# input_tensor = input_tensor.clone().detach()
# # Move to the CPU
# input_tensor = input_tensor.to(torch.device('cpu'))
# # Undo normalization
# # input_tensor = unnormalize(input_tensor)
# # Remove the batch dimension
# input_tensor = input_tensor.squeeze()
# # Scale from [0, 1] to [0, 255], reorder CHW to HWC, then convert to a uint8 array for cv2
# input_tensor = input_tensor.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy()
# # Convert RGB to BGR (cv2's channel order)
# input_tensor = cv2.cvtColor(input_tensor, cv2.COLOR_RGB2BGR)
# cv2.imwrite(filename, input_tensor)
|
|
|
|
|
|
# Write the stylized image to disk. save_image does not create missing
# parent directories, so make sure ./result exists first.
os.makedirs('./result', exist_ok=True)
vutils.save_image(output_img.detach(), './result/Picasso.jpg')
|