# -------------------------- Imports
import torch
from torch.autograd import Variable
import torchvision
from torchvision import transforms, models
import copy
from PIL import Image
import matplotlib.pyplot as plt

# -------------------------- 1. Data preprocessing and loading
# Every image is resized to 224x224 and converted to a tensor.
transform = transforms.Compose([transforms.Resize([224, 224]),
                                transforms.ToTensor()])

# Use the GPU when one is available, otherwise stay on the CPU, so the
# script does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def loadimg(path=None):
    """Load an image file and return it as a batched tensor.

    Args:
        path: Location of the image file on disk.

    Returns:
        A tensor of shape (1, 3, 224, 224): the image is forced to three
        RGB channels, passed through ``transform`` and given a leading
        batch axis.
    """
    # The context manager closes the file handle; convert("RGB") makes
    # grayscale / RGBA / palette images match the 3-channel network input.
    with Image.open(path) as raw:
        img = raw.convert("RGB")
    img = transform(img)
    img = img.unsqueeze(0)
    return img


content_img = loadimg('images/1.jpg').to(device)  # path to your own content image
style_img = loadimg('images/2.jpg').to(device)  # path to your own style image


# -------------------------- 2. Content loss and style loss
class Content_loss(torch.nn.Module):
    """Pass-through layer that records a weighted MSE content loss.

    Args:
        weight: Scalar multiplier applied to both the input and the target.
        target: Feature map of the content image at this depth.
    """

    def __init__(self, weight, target):
        super(Content_loss, self).__init__()
        self.weight = weight
        # detach(): the target is a fixed reference, not part of the graph.
        self.target = target.detach() * weight
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, in_put):
        """Store the loss for ``in_put`` and return ``in_put`` unchanged."""
        self.loss = self.loss_fn(in_put * self.weight, self.target)
        return in_put

    def backward(self):
        """Back-propagate the stored loss and return it.

        ``retain_graph=True`` keeps the graph alive so other loss layers
        can back-propagate through it as well.
        """
        self.loss.backward(retain_graph=True)
        return self.loss


class Gram_matrix(torch.nn.Module):
    """Compute the normalised Gram matrix of a 4-D feature map."""

    def forward(self, in_put):
        """Return ``F @ F.T / (a*b*c*d)`` where F is ``in_put`` flattened
        to ``(a*b, c*d)`` and ``(a, b, c, d)`` is the input size."""
        a, b, c, d = in_put.size()
        feature = in_put.view(a * b, c * d)
        gram = torch.mm(feature, feature.t())
        return gram.div(a * b * c * d)


class Style_loss(torch.nn.Module):
    """Pass-through layer that records a weighted MSE style loss.

    Args:
        weight: Scalar multiplier applied to both Gram matrices.
        target: Gram matrix of the style image at this depth.
    """

    def __init__(self, weight, target):
        super(Style_loss, self).__init__()
        self.weight = weight
        self.target = target.detach() * weight
        self.loss_fn = torch.nn.MSELoss()
        self.gram = Gram_matrix()

    def forward(self, in_put):
        """Store the loss for ``in_put`` and return ``in_put`` unchanged."""
        # The Gram matrix is built from a copy of the activation.
        # NOTE(review): presumably so that a later in-place op on the
        # activation cannot corrupt what autograd saved - confirm.
        # The weight is applied out of place to avoid mutating an
        # autograd result.
        self.Gram = self.gram(in_put.clone()) * self.weight
        self.loss = self.loss_fn(self.Gram, self.target)
        return in_put

    def backward(self):
        """Back-propagate the stored loss (graph retained) and return it."""
        self.loss.backward(retain_graph=True)
        return self.loss


# -------------------------- 3. Model construction
# Reuse the feature-extraction part of the pretrained VGG16 architecture.
cnn = models.vgg16(pretrained=True).features

# Layers of the convolutional stack at which content and style are sampled.
content_layer = ["Conv_3"]
style_layer = ["Conv_1", "Conv_2", "Conv_3", "Conv_4"]

# Lists that collect the content-loss and style-loss layers.
content_losses = []
style_losses = []

# Relative influence of content and style on the blended result.
content_weight = 1
style_weight = 1000

# The code that assembles the style-transfer model follows.
new_model = torch.nn.Sequential()  # start from an empty model
# deepcopy creates a fully independent copy: later changes to the original
# do not affect it. A shallow copy would only attach a new label to the
# same data, so a change through one label would show through the other.
model = copy.deepcopy(cnn)
gram = Gram_matrix()

use_gpu = torch.cuda.is_available()
if use_gpu:
    model = model.cuda()
    new_model = new_model.cuda()
    gram = gram.cuda()

# Only the image is optimised. Freezing the copied network weights stops
# every backward pass from computing and accumulating gradients for them
# (optimizer.zero_grad() below only resets the image's gradient).
for frozen in model.parameters():
    frozen.requires_grad_(False)

index = 1
# Only the first 8 layers of the feature extractor are used.
for layer in list(model)[:8]:
    if isinstance(layer, torch.nn.Conv2d):
        name = "Conv_" + str(index)
        # add_module appends the layer to the initially empty model.
        new_model.add_module(name, layer)

        if name in content_layer:
            # Content target: activation of the model built so far.
            target = new_model(content_img).clone()
            content_loss = Content_loss(content_weight, target)
            new_model.add_module("content_loss_" + str(index), content_loss)
            content_losses.append(content_loss)

        if name in style_layer:
            # Style target: Gram matrix of the style image's activation.
            target = new_model(style_img).clone()
            target = gram(target)
            style_loss = Style_loss(style_weight, target)
            new_model.add_module("style_loss_" + str(index), style_loss)
            style_losses.append(style_loss)

    elif isinstance(layer, torch.nn.ReLU):
        name = "ReLU_" + str(index)
        new_model.add_module(name, layer)
        # A ReLU closes a conv/ReLU pair, so the shared index advances here.
        index = index + 1

    elif isinstance(layer, torch.nn.MaxPool2d):
        name = "MaxPool_" + str(index)
        new_model.add_module(name, layer)

# Build the optimizer: the pixels of a copy of the content image are the
# only parameters being optimised.
input_img = content_img.clone()
parameter = torch.nn.Parameter(input_img.data)
optimizer = torch.optim.LBFGS([parameter])

# -------------------------- 4. Training / image optimisation
epoch_n = 300
epoch = [0]  # one-element list so closure() can update the counter in place


def closure():
    """Evaluate the total loss once for the LBFGS optimizer.

    Clamps the image to [0, 1], runs it through ``new_model`` so every
    loss layer records its loss, back-propagates each of them, and
    advances the evaluation counter.

    Returns:
        The sum of all style and content losses.
    """
    optimizer.zero_grad()
    style_score = 0
    content_score = 0

    parameter.data.clamp_(0, 1)
    new_model(parameter)
    for sl in style_losses:
        style_score += sl.backward()
    for cl in content_losses:
        content_score += cl.backward()

    epoch[0] += 1
    if epoch[0] % 50 == 0:
        # float() also handles the plain-int 0 left by an empty loss list.
        print('Epoch:{} Style_loss: {:.4f} Content_loss: {:.4f}'.format(
            epoch[0], float(style_score), float(content_score)))
    return style_score + content_score


# The counter is advanced inside closure(), i.e. once per loss evaluation.
while epoch[0] <= epoch_n:
    optimizer.step(closure)

# closure() clamps before each evaluation, so the update made by the very
# last step is not clamped yet; bring the final image back into [0, 1].
parameter.data.clamp_(0, 1)

# -------------------------- 5. Output of the style-transferred image
output = parameter.data
unloader = transforms.ToPILImage()  # converts a tensor back to a PIL image
plt.ion()
plt.figure()


def imshow(tensor, title=None):
    """Show an image tensor in the current matplotlib figure.

    Args:
        tensor: Image tensor, either with a leading batch axis of size 1
            or already without a batch axis.
        title: Optional title drawn above the image.
    """
    # Work on a CPU copy so the caller's tensor is left untouched.
    image = tensor.clone().cpu()
    # Drop the batch axis instead of hard-coding a 3x224x224 view, so
    # images of any resolution can be displayed.
    if image.dim() == 4:
        image = image.squeeze(0)
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so the figure gets refreshed


imshow(output, title='Output Image')
# sets sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()