|
|
|
@ -13,6 +13,7 @@ from tools import Tool
|
|
|
|
|
|
|
|
|
|
tl = Tool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ContentLoss(nn.Module):
|
|
|
|
|
"""内容损失"""
|
|
|
|
|
|
|
|
|
@ -68,24 +69,25 @@ class StyleLoss(ContentLoss):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Transfer(object):
|
|
|
|
|
def __init__(self, fn_content, fn_style, model_path=r'weights/squeezenet1_0-a815701f.pth'):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, fn_content, fn_style, model_path=r"weights/squeezenet1_0-a815701f.pth"
|
|
|
|
|
):
|
|
|
|
|
"""usage:
|
|
|
|
|
net = Transfer('picasso.jpg','dancing.jpg')
|
|
|
|
|
dt, img = net.fit()
|
|
|
|
|
"""
|
|
|
|
|
net = Transfer('picasso.jpg','dancing.jpg')
|
|
|
|
|
dt, img = net.fit()
|
|
|
|
|
"""
|
|
|
|
|
self.use_cuda, dtype, imsize = tl.config()
|
|
|
|
|
|
|
|
|
|
self.content_img = tl.image_loader(fn_content).type(dtype)
|
|
|
|
|
self.style_img = tl.image_loader(fn_style).type(dtype)
|
|
|
|
|
self.input_img = self.content_img.clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
get_style_model_and_losses函数是针对vgg模型的,
|
|
|
|
|
要应用到其他模型,需要改写该函数;"""
|
|
|
|
|
if 'vgg19' in model_path:
|
|
|
|
|
if "vgg19" in model_path:
|
|
|
|
|
self.seq = self.load_vgg19(model_path)
|
|
|
|
|
elif 'resnet18' in model_path:
|
|
|
|
|
elif "resnet18" in model_path:
|
|
|
|
|
self.seq = self.load_resnet18(model_path)
|
|
|
|
|
elif "squeezenet1_0" in model_path:
|
|
|
|
|
self.seq = self.load_squeezenet(model_path, "1_0")
|
|
|
|
@ -104,7 +106,6 @@ class Transfer(object):
|
|
|
|
|
cnn.load_state_dict(torch.load(model_path))
|
|
|
|
|
return cnn.features[:23]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_squeezenet(self, model_path, version):
|
|
|
|
|
"""加载SqueezeNet1.0预训练模型;"""
|
|
|
|
|
model = models.SqueezeNet(version=version)
|
|
|
|
@ -113,10 +114,12 @@ class Transfer(object):
|
|
|
|
|
|
|
|
|
|
def load_densenet(self, model_path):
|
|
|
|
|
"""加载densenet121预训练模型;"""
|
|
|
|
|
model = models.DenseNet(num_init_features=64, growth_rate=32,
|
|
|
|
|
block_config=(6, 12, 24, 16))
|
|
|
|
|
model = models.DenseNet(
|
|
|
|
|
num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
|
|
|
|
|
)
|
|
|
|
|
pattern = re.compile(
|
|
|
|
|
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
|
|
|
|
|
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
|
|
|
|
|
)
|
|
|
|
|
state_dict = torch.load(model_path)
|
|
|
|
|
for key in list(state_dict.keys()):
|
|
|
|
|
res = pattern.match(key)
|
|
|
|
@ -146,19 +149,33 @@ class Transfer(object):
|
|
|
|
|
outout_img:PIL.Image.Image;
|
|
|
|
|
style_weight需要远远大于content_weight;"""
|
|
|
|
|
t0 = time.time()
|
|
|
|
|
cnn, tensor = self.rebuild(self.seq, self.content_img,
|
|
|
|
|
self.style_img, self.input_img,
|
|
|
|
|
num_steps, content_weight,
|
|
|
|
|
style_weight)
|
|
|
|
|
cnn, tensor = self.rebuild(
|
|
|
|
|
self.seq,
|
|
|
|
|
self.content_img,
|
|
|
|
|
self.style_img,
|
|
|
|
|
self.input_img,
|
|
|
|
|
num_steps,
|
|
|
|
|
content_weight,
|
|
|
|
|
style_weight,
|
|
|
|
|
)
|
|
|
|
|
output_img = tl.batch_tensor2pil(tensor)
|
|
|
|
|
dt = time.time() - t0
|
|
|
|
|
return dt, output_img
|
|
|
|
|
|
|
|
|
|
def rebuild(self, cnn, content_img, style_img, input_img, num_steps,
|
|
|
|
|
content_weight, style_weight):
|
|
|
|
|
def rebuild(
|
|
|
|
|
self,
|
|
|
|
|
cnn,
|
|
|
|
|
content_img,
|
|
|
|
|
style_img,
|
|
|
|
|
input_img,
|
|
|
|
|
num_steps,
|
|
|
|
|
content_weight,
|
|
|
|
|
style_weight,
|
|
|
|
|
):
|
|
|
|
|
"""run the style transfer."""
|
|
|
|
|
model, style_losses, content_losses = self.get_losses(cnn, style_img, content_img,
|
|
|
|
|
style_weight, content_weight)
|
|
|
|
|
model, style_losses, content_losses = self.get_losses(
|
|
|
|
|
cnn, style_img, content_img, style_weight, content_weight
|
|
|
|
|
)
|
|
|
|
|
input_param, optimizer = self.get_optimizer(input_img)
|
|
|
|
|
|
|
|
|
|
run = [0]
|
|
|
|
@ -180,8 +197,11 @@ class Transfer(object):
|
|
|
|
|
|
|
|
|
|
run[0] += 1
|
|
|
|
|
if run[0] % 50 == 0:
|
|
|
|
|
print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_score, content_score))
|
|
|
|
|
|
|
|
|
|
print(
|
|
|
|
|
"Style Loss : {:4f} Content Loss: {:4f}".format(
|
|
|
|
|
style_score, content_score
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return style_score + content_score
|
|
|
|
|
|
|
|
|
@ -192,10 +212,16 @@ class Transfer(object):
|
|
|
|
|
|
|
|
|
|
return model, input_param.data
|
|
|
|
|
|
|
|
|
|
def get_losses(self, cnn, style_img, content_img,
|
|
|
|
|
style_weight, content_weight,
|
|
|
|
|
content_layers=['conv_4'],
|
|
|
|
|
style_layers=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']):
|
|
|
|
|
def get_losses(
|
|
|
|
|
self,
|
|
|
|
|
cnn,
|
|
|
|
|
style_img,
|
|
|
|
|
content_img,
|
|
|
|
|
style_weight,
|
|
|
|
|
content_weight,
|
|
|
|
|
content_layers=["conv_4"],
|
|
|
|
|
style_layers=["conv_1", "conv_2", "conv_3", "conv_4", "conv_5"],
|
|
|
|
|
):
|
|
|
|
|
cnn = copy.deepcopy(cnn)
|
|
|
|
|
|
|
|
|
|
# 仅为了有一个可迭代的列表 内容/风格 损失
|
|
|
|
@ -264,21 +290,22 @@ class Transfer(object):
|
|
|
|
|
input_param = nn.Parameter(input_img.data)
|
|
|
|
|
optimizer = optim.LBFGS([input_param])
|
|
|
|
|
return input_param, optimizer
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
# model = models.vgg19(pretrained=True)
|
|
|
|
|
ph = 'weights/vgg19-dcbb9e9d.pth'
|
|
|
|
|
ph = "weights/vgg19-dcbb9e9d.pth"
|
|
|
|
|
|
|
|
|
|
transfer = Transfer('ori/2.jpg', 'art/2.jpg',ph)
|
|
|
|
|
transfer = Transfer("./ori/2.jpg", "./art/2.jpg", ph)
|
|
|
|
|
|
|
|
|
|
t = time.time()
|
|
|
|
|
dt, img = transfer.fit()
|
|
|
|
|
|
|
|
|
|
print(time.time()-t)
|
|
|
|
|
print(time.time() - t)
|
|
|
|
|
# print(dt,img)
|
|
|
|
|
|
|
|
|
|
img = np.array(img)[:,:,::-1]
|
|
|
|
|
img = np.array(img)[:, :, ::-1]
|
|
|
|
|
|
|
|
|
|
cv2.imwrite('1.jpg',img)
|
|
|
|
|
cv2.imshow('1',img)
|
|
|
|
|
cv2.imwrite("1.jpg", img)
|
|
|
|
|
cv2.imshow("1", img)
|
|
|
|
|
cv2.waitKey(0)
|