diff --git a/styleTransfer.py b/styleTransfer.py new file mode 100644 index 0000000..9d32f18 --- /dev/null +++ b/styleTransfer.py @@ -0,0 +1,168 @@ +#--------------------------导包 +import torch +from torch.autograd import Variable +import torchvision +from torchvision import transforms, models +import copy +from PIL import Image +import matplotlib.pyplot as plt + +#--------------------------1.数据预处理,加载数据 +transform = transforms.Compose([transforms.Resize([224,224]), + transforms.ToTensor()]) + +def loadimg(path = None): + img = Image.open(path) + img = transform(img) + img = img.unsqueeze(0) + return img + +content_img = loadimg('images/1.jpg') #入参是自己存放图片的位置 +content_img = Variable(content_img).cuda() +style_img = loadimg('images/2.jpg') +style_img = Variable(style_img).cuda() + +#--------------------------2.定义内容损失和风格损失 +class Content_loss(torch.nn.Module): + def __init__(self,weight,target): + super(Content_loss,self).__init__() + self.weight = weight + self.target = target.detach()*weight + self.loss_fn = torch.nn.MSELoss() + + def forward(self,in_put): + self.loss = self.loss_fn(in_put*self.weight,self.target) + return in_put + + def backward(self): + self.loss.backward(retain_graph = True) + return self.loss + +class Gram_matrix(torch.nn.Module): + def forward(self,in_put): + a,b,c,d = in_put.size() + feature = in_put.view(a*b,c*d) + gram = torch.mm(feature,feature.t()) + return gram.div(a*b*c*d) + +class Style_loss(torch.nn.Module): + def __init__(self,weight,target): + super(Style_loss,self).__init__() + self.weight = weight + self.target = target.detach()*weight + self.loss_fn = torch.nn.MSELoss() + self.gram = Gram_matrix() + + def forward(self,in_put): + self.Gram = self.gram(in_put.clone()) + self.Gram.mul_(self.weight) + self.loss = self.loss_fn(self.Gram,self.target) + return in_put + + def backward(self): + self.loss.backward(retain_graph = True) + return self.loss + +#--------------------------3.模型搭建 +cnn = models.vgg16(pretrained = True).features #迁移VGG16架构的特征提取部分 +#指定整个卷积过程中分别在哪一层提取内容和风格 +content_layer = ["Conv_3"] +style_layer = ["Conv_1","Conv_2","Conv_3","Conv_4"] +#定义保存内容损失和风格损失的列表 +content_losses = [] +style_losses = [] +#指定内容损失和风格损失对最后得到的融合图片的影响权重 +content_weight = 1 +style_weight = 1000 + +#搭建图像风格迁移模型的代码如下: +new_model = torch.nn.Sequential() #建立空的模型 +model = copy.deepcopy(cnn) +#deepcopy深复制,将被复制对象完全再复制一遍作为独立的新个体单独存在,改变原有被复制对象不会对已经复制出来的新对象产生影响。 +#copy浅复制,并不会产生一个独立的对象单独存在,他只是将原有的数据块打上一个新标签 +#所以当其中一个标签被改变的时候,数据块就会发生变化,另一个标签也会随之改变。 +gram = Gram_matrix() + +use_gpu = torch.cuda.is_available() +if use_gpu: + model = model.cuda() + new_model = new_model.cuda() + gram = gram.cuda() + +index = 1 +#只使用迁移模型特征提取部分的前8层 +for layer in list(model)[:8]: + if isinstance(layer,torch.nn.Conv2d): + name = "Conv_" + str(index) + #使用add_module方法向空的模型加入指定的层次模块 + new_model.add_module(name,layer) + if name in content_layer: + target = new_model(content_img).clone() + content_loss = Content_loss(content_weight,target) + new_model.add_module("content_loss_"+str(index),content_loss) + content_losses.append(content_loss) + + if name in style_layer: + target = new_model(style_img).clone() + target = gram(target) + style_loss = Style_loss(style_weight,target) + new_model.add_module("style_loss_"+str(index),style_loss) + style_losses.append(style_loss) + + if isinstance(layer,torch.nn.ReLU): + name = "ReLU_"+str(index) + new_model.add_module(name,layer) + index = index + 1 + + if isinstance(layer,torch.nn.MaxPool2d): + name = "MaxPool_"+str(index) + new_model.add_module(name,layer) +#构造优化器 +input_img = content_img.clone() +parameter = torch.nn.Parameter(input_img.data) +optimizer = torch.optim.LBFGS([parameter]) + +#--------------------------4.模型训练和参数优化 +epoch_n = 300 +epoch = [0] +while epoch[0] <= epoch_n: + + def closure(): + optimizer.zero_grad() + style_score = 0 + content_score = 0 + parameter.data.clamp_(0,1) + new_model(parameter) + for sl in style_losses: + style_score += sl.backward() + + for cl in content_losses: + content_score += cl.backward() + + epoch[0] += 1 + if epoch[0] % 50 == 0: + print('Epoch:{} Style_loss: {:4f} Content_loss: {:.4f}'.format(epoch[0], style_score.data.item(), + content_score.data.item())) + return style_score + content_score + + optimizer.step(closure) + +#--------------------------5.对风格迁移图片输出 +output = parameter.data +unloader = transforms.ToPILImage() # 重新转化成PIL图像格式 + +plt.ion() +plt.figure() +def imshow(tensor, title=None): + image = tensor.clone().cpu() # 克隆tensor,改变时不影响被克隆的tensor + image = image.view(3, 224, 224) # 转换维度 + image = unloader(image) + plt.imshow(image) + if title is not None: + plt.title(title) + plt.pause(0.001) # 稍作停顿,以便更新图表 +imshow(output, title='Output Image') + +# 设置sphinx_gallery_thumbnail_number = 4 +plt.ioff() +plt.show() \ No newline at end of file