# -*- coding: utf-8 -*-
"""
@File  : situation1.py
@Author: csc
@Date  : 2022/6/23

Optimization-based neural style transfer (Gatys et al.): starting from a
copy of the content image, LBFGS directly optimizes the pixels to match
the content image's VGG features and the style image's Gram matrices.
The result is written to ./images/.
"""
import cv2  # was used by cv2.imwrite below but never imported — explicit import added
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Star-import order preserved: names from `models` shadow any duplicates
# from `utils` (e.g. VGG). read_image / imshow / recover_image come from utils.
from utils import *
from models import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this line deliberately overrides the CUDA check above and
# forces CPU — looks like a debugging leftover; confirm before removing.
device = 'cpu'

width = 512  # images are resized to this width before feature extraction

style_img_path = '../smallWikiArt/train/'
content_img_path = '../smallCOCO2014/train2014/'
style_img_name = '30925'
content_img_name = 'COCO_train2014_000000505057'

style_weight = 1e6   # weight of the style (Gram-matrix) loss
content_weight = 1   # weight of the content (feature) loss


def gram_matrix(y):
    """Return the batched (b, ch, ch) Gram matrix of a feature map *y*.

    *y* has shape (b, ch, h, w).  The product is normalized by ch * h * w
    so its magnitude is independent of the feature-map size, as in the
    Gatys et al. style-transfer formulation.
    """
    b, ch, h, w = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    return features.bmm(features_t) / (ch * h * w)


def main():
    """Load the images, run 300+ LBFGS steps, and save the stylized output."""
    style_img = read_image(style_img_path + style_img_name + '.jpg',
                           target_width=width).to(device)
    content_img = read_image(content_img_path + content_img_name + '.jpg',
                             target_width=width).to(device)

    # Show the two inputs side by side.
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    imshow(style_img, title='Style Image')
    plt.subplot(1, 2, 2)
    imshow(content_img, title='Content Image')

    # Load VGG16 weights from a local checkpoint instead of downloading,
    # then keep only the first 23 feature layers (through relu4_3 in the
    # torchvision layout).  eval() freezes batch-norm/dropout behavior.
    vgg16 = models.vgg16(pretrained=False)
    vgg16.load_state_dict(torch.load('./models/vgg16-397923af.pth'))
    vgg16 = VGG(vgg16.features[:23]).to(device).eval()

    # Targets: raw features for content, Gram matrices for style.
    style_features = vgg16(style_img)
    content_features = vgg16(content_img)
    style_grams = [gram_matrix(x) for x in style_features]

    # Optimize the pixels of a copy of the content image.
    input_img = content_img.clone()
    optimizer = optim.LBFGS([input_img.requires_grad_()])

    run = [0]  # mutable step counter shared with the closure below
    while run[0] <= 300:
        def f():
            """LBFGS closure: recompute loss and gradients for one step."""
            optimizer.zero_grad()
            features = vgg16(input_img)
            # Content loss on features[2] — presumably one of the mid-level
            # relu outputs returned by the VGG wrapper; confirm in models.py.
            content_loss = F.mse_loss(features[2], content_features[2]) * content_weight
            style_loss = 0
            grams = [gram_matrix(x) for x in features]
            for a, b in zip(grams, style_grams):
                style_loss += F.mse_loss(a, b) * style_weight
            loss = style_loss + content_loss
            if run[0] % 50 == 0:
                print('Step {}: Style Loss: {:4f} Content Loss: {:4f}'.format(
                    run[0], style_loss.item(), content_loss.item()))
            run[0] += 1
            loss.backward()
            return loss

        optimizer.step(f)

    # Convert the optimized tensor back to an image array and save it.
    result = recover_image(input_img)
    cv2.imwrite('./images/out_' + style_img_name + '_' + content_img_name + '.jpg',
                result)


if __name__ == '__main__':
    main()