|
|
|
|
#--------------------------导包
|
|
|
|
|
import torch
|
|
|
|
|
from torch.autograd import Variable
|
|
|
|
|
import torchvision
|
|
|
|
|
from torchvision import transforms, models
|
|
|
|
|
import copy
|
|
|
|
|
from PIL import Image
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
#--------------------------1.数据预处理,加载数据
|
|
|
|
|
transform = transforms.Compose([transforms.Resize([224,224]),
|
|
|
|
|
transforms.ToTensor()])
|
|
|
|
|
|
|
|
|
|
def loadimg(path = None):
|
|
|
|
|
img = Image.open(path)
|
|
|
|
|
img = transform(img)
|
|
|
|
|
img = img.unsqueeze(0)
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
content_img = loadimg('images/1.jpg') #入参是自己存放图片的位置
|
|
|
|
|
content_img = Variable(content_img).cuda()
|
|
|
|
|
style_img = loadimg('images/2.jpg')
|
|
|
|
|
style_img = Variable(style_img).cuda()
|
|
|
|
|
|
|
|
|
|
#--------------------------2.定义内容损失和风格损失
|
|
|
|
|
class Content_loss(torch.nn.Module):
|
|
|
|
|
def __init__(self,weight,target):
|
|
|
|
|
super(Content_loss,self).__init__()
|
|
|
|
|
self.weight = weight
|
|
|
|
|
self.target = target.detach()*weight
|
|
|
|
|
self.loss_fn = torch.nn.MSELoss()
|
|
|
|
|
|
|
|
|
|
def forward(self,in_put):
|
|
|
|
|
self.loss = self.loss_fn(in_put*self.weight,self.target)
|
|
|
|
|
return in_put
|
|
|
|
|
|
|
|
|
|
def backward(self):
|
|
|
|
|
self.loss.backward(retain_graph = True)
|
|
|
|
|
return self.loss
|
|
|
|
|
|
|
|
|
|
class Gram_matrix(torch.nn.Module):
|
|
|
|
|
def forward(self,in_put):
|
|
|
|
|
a,b,c,d = in_put.size()
|
|
|
|
|
feature = in_put.view(a*b,c*d)
|
|
|
|
|
gram = torch.mm(feature,feature.t())
|
|
|
|
|
return gram.div(a*b*c*d)
|
|
|
|
|
|
|
|
|
|
class Style_loss(torch.nn.Module):
|
|
|
|
|
def __init__(self,weight,target):
|
|
|
|
|
super(Style_loss,self).__init__()
|
|
|
|
|
self.weight = weight
|
|
|
|
|
self.target = target.detach()*weight
|
|
|
|
|
self.loss_fn = torch.nn.MSELoss()
|
|
|
|
|
self.gram = Gram_matrix()
|
|
|
|
|
|
|
|
|
|
def forward(self,in_put):
|
|
|
|
|
self.Gram = self.gram(in_put.clone())
|
|
|
|
|
self.Gram.mul_(self.weight)
|
|
|
|
|
self.loss = self.loss_fn(self.Gram,self.target)
|
|
|
|
|
return in_put
|
|
|
|
|
|
|
|
|
|
def backward(self):
|
|
|
|
|
self.loss.backward(retain_graph = True)
|
|
|
|
|
return self.loss
|
|
|
|
|
|
|
|
|
|
#--------------------------3.模型搭建
|
|
|
|
|
cnn = models.vgg16(pretrained = True).features #迁移VGG16架构的特征提取部分
|
|
|
|
|
#指定整个卷积过程中分别在哪一层提取内容和风格
|
|
|
|
|
content_layer = ["Conv_3"]
|
|
|
|
|
style_layer = ["Conv_1","Conv_2","Conv_3","Conv_4"]
|
|
|
|
|
#定义保存内容损失和风格损失的列表
|
|
|
|
|
content_losses = []
|
|
|
|
|
style_losses = []
|
|
|
|
|
#指定内容损失和风格损失对最后得到的融合图片的影响权重
|
|
|
|
|
content_weight = 1
|
|
|
|
|
style_weight = 1000
|
|
|
|
|
|
|
|
|
|
#搭建图像风格迁移模型的代码如下:
|
|
|
|
|
new_model = torch.nn.Sequential() #建立空的模型
|
|
|
|
|
model = copy.deepcopy(cnn)
|
|
|
|
|
#deepcopy深复制,将被复制对象完全再复制一遍作为独立的新个体单独存在,改变原有被复制对象不会对已经复制出来的新对象产生影响。
|
|
|
|
|
#copy浅复制,并不会产生一个独立的对象单独存在,他只是将原有的数据块打上一个新标签
|
|
|
|
|
#所以当其中一个标签被改变的时候,数据块就会发生变化,另一个标签也会随之改变。
|
|
|
|
|
gram = Gram_matrix()
|
|
|
|
|
|
|
|
|
|
use_gpu = torch.cuda.is_available()
|
|
|
|
|
if use_gpu:
|
|
|
|
|
model = model.cuda()
|
|
|
|
|
new_model = new_model.cuda()
|
|
|
|
|
gram = gram.cuda()
|
|
|
|
|
|
|
|
|
|
index = 1
|
|
|
|
|
#只使用迁移模型特征提取部分的前8层
|
|
|
|
|
for layer in list(model)[:8]:
|
|
|
|
|
if isinstance(layer,torch.nn.Conv2d):
|
|
|
|
|
name = "Conv_" + str(index)
|
|
|
|
|
#使用add_module方法向空的模型加入指定的层次模块
|
|
|
|
|
new_model.add_module(name,layer)
|
|
|
|
|
if name in content_layer:
|
|
|
|
|
target = new_model(content_img).clone()
|
|
|
|
|
content_loss = Content_loss(content_weight,target)
|
|
|
|
|
new_model.add_module("content_loss_"+str(index),content_loss)
|
|
|
|
|
content_losses.append(content_loss)
|
|
|
|
|
|
|
|
|
|
if name in style_layer:
|
|
|
|
|
target = new_model(style_img).clone()
|
|
|
|
|
target = gram(target)
|
|
|
|
|
style_loss = Style_loss(style_weight,target)
|
|
|
|
|
new_model.add_module("style_loss_"+str(index),style_loss)
|
|
|
|
|
style_losses.append(style_loss)
|
|
|
|
|
|
|
|
|
|
if isinstance(layer,torch.nn.ReLU):
|
|
|
|
|
name = "ReLU_"+str(index)
|
|
|
|
|
new_model.add_module(name,layer)
|
|
|
|
|
index = index + 1
|
|
|
|
|
|
|
|
|
|
if isinstance(layer,torch.nn.MaxPool2d):
|
|
|
|
|
name = "MaxPool_"+str(index)
|
|
|
|
|
new_model.add_module(name,layer)
|
|
|
|
|
#构造优化器
|
|
|
|
|
input_img = content_img.clone()
|
|
|
|
|
parameter = torch.nn.Parameter(input_img.data)
|
|
|
|
|
optimizer = torch.optim.LBFGS([parameter])
|
|
|
|
|
|
|
|
|
|
#--------------------------4.模型训练和参数优化
|
|
|
|
|
epoch_n = 300
|
|
|
|
|
epoch = [0]
|
|
|
|
|
while epoch[0] <= epoch_n:
|
|
|
|
|
|
|
|
|
|
def closure():
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
style_score = 0
|
|
|
|
|
content_score = 0
|
|
|
|
|
parameter.data.clamp_(0,1)
|
|
|
|
|
new_model(parameter)
|
|
|
|
|
for sl in style_losses:
|
|
|
|
|
style_score += sl.backward()
|
|
|
|
|
|
|
|
|
|
for cl in content_losses:
|
|
|
|
|
content_score += cl.backward()
|
|
|
|
|
|
|
|
|
|
epoch[0] += 1
|
|
|
|
|
if epoch[0] % 50 == 0:
|
|
|
|
|
print('Epoch:{} Style_loss: {:4f} Content_loss: {:.4f}'.format(epoch[0], style_score.data.item(),
|
|
|
|
|
content_score.data.item()))
|
|
|
|
|
return style_score + content_score
|
|
|
|
|
|
|
|
|
|
optimizer.step(closure)
|
|
|
|
|
|
|
|
|
|
#--------------------------5.对风格迁移图片输出
|
|
|
|
|
output = parameter.data
|
|
|
|
|
unloader = transforms.ToPILImage() # 重新转化成PIL图像格式
|
|
|
|
|
|
|
|
|
|
plt.ion()
|
|
|
|
|
plt.figure()
|
|
|
|
|
def imshow(tensor, title=None):
|
|
|
|
|
image = tensor.clone().cpu() # 克隆tensor,改变时不影响被克隆的tensor
|
|
|
|
|
image = image.view(3, 224, 224) # 转换维度
|
|
|
|
|
image = unloader(image)
|
|
|
|
|
plt.imshow(image)
|
|
|
|
|
if title is not None:
|
|
|
|
|
plt.title(title)
|
|
|
|
|
plt.pause(0.001) # 稍作停顿,以便更新图表
|
|
|
|
|
imshow(output, title='Output Image')
|
|
|
|
|
|
|
|
|
|
# 设置sphinx_gallery_thumbnail_number = 4
|
|
|
|
|
plt.ioff()
|
|
|
|
|
plt.show()
|