You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

117 lines
3.5 KiB

import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from flask import app
from PIL import Image
from torchvision.utils import save_image

from utils import *
from models import *
from output import out_path
# %matplotlib inline
# %config InlineBackend.figure_format = 'retina'
def style_transfer(decide, cnt):
    """Run neural style transfer and save the stylized result.

    Applies the preset style ``decide`` to the content image identified by
    ``cnt`` (resolved via ``out_path``), optimizing the input image with
    LBFGS against VGG16 content/style losses (Gatys et al.).

    Parameters:
        decide: preset style index; forwarded to ``set_style``.
        cnt: content-image counter; input is read from ``out_path(cnt)`` and
             the result is written to ``out_path(cnt + 1)``.

    Side effects: rewrites ./img_src/style.jpg, resizes images on disk,
    writes ./img_src/transfer.jpg and out_path(cnt + 1).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_style(decide)
    de_resize(cnt)
    width = 512
    style_img = read_image('./img_src/style.jpg', target_width=width).to(device)
    content_img = read_image(out_path(cnt), target_width=width).to(device)

    # First 23 layers of VGG16 (up to relu4_3); eval mode freezes batch-norm/
    # dropout behavior. VGG is a project wrapper returning intermediate features.
    vgg16 = models.vgg16(pretrained=True)
    vgg16 = VGG(vgg16.features[:23]).to(device).eval()
    style_features = vgg16(style_img)
    content_features = vgg16(content_img)

    def gram_matrix(y):
        """Return the (b, ch, ch) Gram matrix of feature maps, normalized by ch*h*w."""
        b, ch, h, w = y.size()
        features = y.view(b, ch, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (ch * h * w)
        return gram

    style_grams = [gram_matrix(x) for x in style_features]

    # Optimize a copy of the content image directly (the "input" is the variable).
    input_img = content_img.clone()
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    style_weight = 1e7
    content_weight = 1

    # Mutable cell so the closure below can count LBFGS closure evaluations.
    run = [0]

    def f():
        """LBFGS closure: compute total loss, backprop, and return the loss."""
        optimizer.zero_grad()
        features = vgg16(input_img)
        # features[2] ~ relu3_3 activations used for the content term.
        content_loss = F.mse_loss(features[2], content_features[2]) * content_weight
        style_loss = 0
        grams = [gram_matrix(x) for x in features]
        for a, b in zip(grams, style_grams):
            style_loss += F.mse_loss(a, b) * style_weight
        loss = style_loss + content_loss
        if run[0] % 10 == 0:
            print('Step {}: Style Loss: {:4f} Content Loss: {:4f}'.format(
                int(run[0]/10), style_loss.item()/style_weight, content_loss.item()/content_weight))
        run[0] += 1
        loss.backward()
        return loss

    # LBFGS may evaluate the closure several times per .step(); the cap is on
    # total closure evaluations, matching the original behavior.
    while run[0] <= 300:
        optimizer.step(f)

    # Side-by-side comparison figure for inspection.
    plt.figure(figsize=(18, 6))
    plt.subplot(1, 3, 1)
    imshow(style_img, title='Style Image')
    plt.subplot(1, 3, 2)
    imshow(content_img, title='Content Image')
    plt.subplot(1, 3, 3)
    imshow(input_img, title='Output Image')
    plt.savefig('./img_src/transfer.jpg')

    # Borderless render of just the output image, saved as the next content step.
    plt.figure(figsize=(8, 8))
    plt.subplot(1, 1, 1)
    plt.axis('off')
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    imshow(input_img)
    plt.savefig(out_path(cnt+1))

    # Release figures so repeated calls (e.g. from the web app) don't leak memory.
    plt.close('all')
def de_resize(cnt):
    """Resize the content image ``out_path(cnt)`` and the active style image
    to 612x612 in place.

    Fixed to use PIL (already imported in this file) instead of ``cv2``,
    which was never imported and raised NameError at call time. Bilinear
    resampling matches the original ``cv2.INTER_LINEAR`` intent.

    Parameters:
        cnt: content-image counter; forwarded to ``out_path`` to locate the file.
    """
    for path in (out_path(cnt), './img_src/style.jpg'):
        img = Image.open(path)
        img = img.resize((612, 612), Image.BILINEAR)
        img.save(path)
def set_style(decide):
    """Activate preset style number ``decide`` as the current style image.

    Copies ./img_src/styles/style<decide>.jpg to ./img_src/style.jpg when the
    preset exists; otherwise does nothing (the previous style stays active).

    Fixed to use ``shutil.copyfile`` instead of a cv2 read/write round-trip:
    ``cv2`` was never imported in this file (NameError at call time), and a
    byte-exact copy also avoids recompressing the JPEG.

    Parameters:
        decide: preset index; anything accepted by ``int()`` (e.g. int or
                numeric string).
    """
    file_name = './img_src/styles/style' + str(int(decide)) + '.jpg'
    if os.path.exists(file_name):
        shutil.copyfile(file_name, './img_src/style.jpg')