# -*- coding: utf-8 -*-
"""
@File : situation1.py
@Author: csc
@Date : 2022/6/23
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import cv2  # needed for the cv2.imwrite call at the end of the script
import numpy as np

# read_image, imshow, recover_image and the VGG feature-extractor wrapper
# come from the local helper modules
from utils import *
from models import *
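
# Device selection: prefer CUDA when available; the string assignment below
# overrides it and pins everything to the CPU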
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
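
# Inputs: a style exemplar from a WikiArt subset and a content photo from
# COCO train2014, both resized to the same target width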
width = 512
style_img_path = '../smallWikiArt/train/'
content_img_path = '../smallCOCO2014/train2014/'
style_img_name = '30925'
content_img_name = 'COCO_train2014_000000505057'

style_img = read_image(style_img_path + style_img_name + '.jpg', target_width=width).to(device)
content_img = read_image(content_img_path + content_img_name + '.jpg', target_width=width).to(device)
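
# Preview the style and content images side by side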
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
imshow(style_img, title='Style Image')
plt.subplot(1, 2, 2)
imshow(content_img, title='Content Image')
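
# VGG16 with ImageNet weights loaded from a local checkpoint;
# features[:23] keeps the layers up to and including relu4_3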
vgg16 = models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
vgg16 = VGG(vgg16.features[:23]).to(device).eval()
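
# Features of the fixed style and content images are extracted once, up front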
style_features = vgg16(style_img)
content_features = vgg16(content_img)
# [x.shape for x in content_features]
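

# The Gram matrix of a batch of feature maps captures channel-to-channel
# correlations (normalized by ch * h * w); matching it across layers is the
# style term of Gatys et al. (2015)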
def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram


style_grams = [gram_matrix(x) for x in style_features]
# [x.shape for x in style_grams]
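
# Optimize the pixels of a copy of the content image directly with L-BFGS;
# style_weight >> content_weight biases the result toward the style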
input_img = content_img.clone()
optimizer = optim.LBFGS([input_img.requires_grad_()])
style_weight = 1e6
content_weight = 1
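
# L-BFGS may evaluate the closure f several times per .step() call,
# so run[0] counts closure evaluations rather than optimizer steps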
run = [0]
while run[0] <= 300:
    def f():
        optimizer.zero_grad()
        features = vgg16(input_img)

        # content loss: match the third extracted feature map of the content image
        content_loss = F.mse_loss(features[2], content_features[2]) * content_weight

        # style loss: match the Gram matrix at every extracted layer
        style_loss = 0
        grams = [gram_matrix(x) for x in features]
        for a, b in zip(grams, style_grams):
            style_loss += F.mse_loss(a, b) * style_weight

        loss = style_loss + content_loss

        if run[0] % 50 == 0:
            print('Step {}: Style Loss: {:.4f} Content Loss: {:.4f}'.format(
                run[0], style_loss.item(), content_loss.item()))
        run[0] += 1

        loss.backward()
        return loss

    optimizer.step(f)
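
# recover_image (from the local helpers) is presumably mapping the optimized
# tensor back to an HxWxC uint8 array that cv2.imwrite can save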
input_img = recover_image(input_img)
cv2.imwrite('./images/out_' + style_img_name + '_' + content_img_name + '.jpg', input_img)