# -*- coding: utf-8 -*-
"""
@File : situation1.py
@Author: csc
@Date : 2022/6/23
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
from utils import *
from models import *
import numpy as np
import cv2  # needed for cv2.imwrite at the end of the script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'  # force CPU, overriding the CUDA device selected above
width = 512
style_img_path = '../smallWikiArt/train/'
content_img_path = '../smallCOCO2014/train2014/'
style_img_name = '30925'
content_img_name = 'COCO_train2014_000000505057'
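# read_image (from utils) is assumed to load a JPEG, resize it to target_width,
# and return a normalized tensor ready for VGG.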
style_img = read_image(style_img_path + style_img_name + '.jpg', target_width=width).to(device)
content_img = read_image(content_img_path + content_img_name + '.jpg', target_width=width).to(device)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
imshow(style_img, title='Style Image')
plt.subplot(1, 2, 2)
imshow(content_img, title='Content Image')
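# Build the feature extractor: load local VGG16 weights (pretrained=False avoids
# a download), then keep only features[:23], i.e. up through relu4_3. The VGG
# wrapper from models is assumed to return the intermediate activations used
# below as style/content features.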
vgg16 = models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
vgg16 = VGG(vgg16.features[:23]).to(device).eval()
style_features = vgg16(style_img)
content_features = vgg16(content_img)
# [x.shape for x in content_features]
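# The Gram matrix G = F·Fᵀ / (C·H·W) of a (B, C, H, W) feature map captures
# channel-wise correlations; matching these statistics is the style objective.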
def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)        # flatten each channel's spatial dims
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)  # normalize by feature-map size
    return gram
style_grams = [gram_matrix(x) for x in style_features]
# [x.shape for x in style_grams]
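# Optimize the image pixels directly: start from a copy of the content image
# (starting from noise also works, but typically converges more slowly) and
# register it as the single parameter of the L-BFGS optimizer.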
input_img = content_img.clone()
optimizer = optim.LBFGS([input_img.requires_grad_()])
style_weight = 1e6
content_weight = 1
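# The style term is weighted ~1e6:1 against the content term, a typical ratio
# for this formulation; increasing style_weight gives a more heavily stylized
# result.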
run = [0]
while run[0] <= 300:
    def f():
        # Closure for L-BFGS: the optimizer may re-evaluate it several times
        # per step, so run[0] counts closure evaluations, not optimizer steps.
        optimizer.zero_grad()
        features = vgg16(input_img)
        # Content loss on the third returned feature map (relu3_3 in the usual
        # four-layer VGG16 split; depends on the VGG wrapper in models).
        content_loss = F.mse_loss(features[2], content_features[2]) * content_weight
        style_loss = 0
        grams = [gram_matrix(x) for x in features]
        for a, b in zip(grams, style_grams):
            style_loss += F.mse_loss(a, b) * style_weight
        loss = style_loss + content_loss
        if run[0] % 50 == 0:
            print('Step {}: Style Loss: {:4f} Content Loss: {:4f}'.format(
                run[0], style_loss.item(), content_loss.item()))
        run[0] += 1
        loss.backward()
        return loss
    optimizer.step(f)
# recover_image (from utils) is assumed to undo normalization and return a
# uint8 array; note that cv2.imwrite expects BGR channel order.
input_img = recover_image(input_img)
cv2.imwrite('./images/out_' + style_img_name + '_' + content_img_name + '.jpg', input_img)