You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

168 lines
5.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
@File : utils.py
@Author: csc
@Date : 2022/6/23
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import shutil
# Per-channel mean/std (RGB order) used to normalize input images; these
# match the commonly used ImageNet statistics. recover_image and
# recover_tensor apply the inverse transform with the same values.
cnn_normalization_mean = [0.485, 0.456, 0.406]
cnn_normalization_std = [0.229, 0.224, 0.225]
tensor_normalizer = transforms.Normalize(cnn_normalization_mean, cnn_normalization_std)

# Small constant added to variances before taking a square root (see mean_std).
epsilon = 1e-5
def preprocess_image(image, target_width=None):
    """Convert a PIL.Image into a normalized 4-D tensor.

    Args:
        image: a PIL.Image object.
        target_width: when truthy, the image is first resized with
            ``transforms.Resize(target_width)`` and then center-cropped to a
            ``target_width`` x ``target_width`` square. When None (or 0) the
            image keeps its original size.

    Returns:
        The ToTensor + ``tensor_normalizer`` result with a leading batch
        dimension of size 1 added via ``unsqueeze(0)``.
    """
    steps = []
    if target_width:
        steps.append(transforms.Resize(target_width))
        steps.append(transforms.CenterCrop(target_width))
    steps.extend([transforms.ToTensor(), tensor_normalizer])
    pipeline = transforms.Compose(steps)
    return pipeline(image).unsqueeze(0)
def image_to_tensor(image, target_width=None):
    """Convert an OpenCV image into a normalized 4-D tensor.

    Args:
        image: an OpenCV image (values 0-255, BGR channel order).
        target_width: forwarded unchanged to ``preprocess_image``.

    Returns:
        The normalized tensor produced by ``preprocess_image``.
    """
    # OpenCV stores channels as BGR; PIL and the normalizer expect RGB.
    rgb_array = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return preprocess_image(Image.fromarray(rgb_array), target_width)
def read_image(path, target_width=None):
    """Load an image file and return it as a normalized 4-D tensor.

    Args:
        path: path of the image file to open.
        target_width: forwarded unchanged to ``preprocess_image``.

    Returns:
        The normalized tensor produced by ``preprocess_image``.
    """
    # Image.open is lazy and holds the file handle open; the context manager
    # closes it once convert() has produced an independent RGB copy.
    with Image.open(path) as img:
        image = img.convert('RGB')
    return preprocess_image(image, target_width)
def recover_image(tensor):
    """Turn a normalized 4-D tensor back into a displayable image array.

    The tensor is detached and moved to the CPU, so it may live on any
    device. The normalization applied with ``cnn_normalization_mean`` /
    ``cnn_normalization_std`` is undone and values are scaled to 0-255.

    Args:
        tensor: a 4-D tensor; assumes (N, 3, H, W) layout since the
            constants are reshaped to (1, 3, 1, 1) — TODO confirm.

    Returns:
        A uint8 numpy array of shape (H, W, 3) (RGB order, clipped to
        0-255) for the FIRST image of the batch only.
    """
    arr = tensor.detach().cpu().numpy()
    std = np.array(cnn_normalization_std).reshape((1, 3, 1, 1))
    mean = np.array(cnn_normalization_mean).reshape((1, 3, 1, 1))
    arr = arr * std + mean
    # NCHW -> NHWC, rescale to 0-255, then keep only batch element 0.
    nhwc = np.transpose(arr, (0, 2, 3, 1))
    return np.clip(nhwc * 255., 0, 255).astype(np.uint8)[0]
def recover_tensor(tensor):
    """Undo input normalization on a tensor and clamp it to [0, 1].

    The mean/std constants are viewed as (1, 3, 1, 1), so ``tensor`` is
    assumed to be 4-D with 3 channels in dim 1 — TODO confirm. The
    constants are moved to ``tensor.device`` so this works on any device.

    Returns:
        ``tensor * std + mean`` clamped to the range [0, 1].
    """
    device = tensor.device
    mean = torch.tensor(cnn_normalization_mean).view(1, 3, 1, 1).to(device)
    std = torch.tensor(cnn_normalization_std).view(1, 3, 1, 1).to(device)
    return (tensor * std + mean).clamp(0, 1)
def imshow(tensor, title=None):
    """Draw a normalized 4-D tensor with matplotlib.

    The tensor is converted with ``recover_image`` (only the first image of
    the batch is shown), its shape is printed, and it is drawn with
    ``plt.imshow``. ``plt.show()`` is not called here.

    Args:
        tensor: tensor accepted by ``recover_image``.
        title: optional plot title.
    """
    picture = recover_image(tensor)
    print(picture.shape)
    plt.imshow(picture)
    if title is None:
        return
    plt.title(title)
def mean_std(features, eps=None):
    """Return the per-channel mean and standard deviation of feature maps.

    Per the original author's note, the input is the four feature maps
    computed by VGG16 and the output length is 1920.

    Args:
        features: iterable of tensors shaped (N, C, ...); the first two dims
            are kept and every remaining dim is flattened together. Tensors
            need not be contiguous.
        eps: value added to the (unbiased) variance before the square root.
            Defaults to the module-level ``epsilon``.

    Returns:
        Tensor of shape (N, 2 * sum(C_i)). For each input, in order, the
        values are interleaved per channel as [mean_0, std_0, mean_1,
        std_1, ...].
    """
    if eps is None:
        eps = epsilon
    stats = []
    for feat in features:
        # reshape (not view): view raises on non-contiguous feature maps.
        flat = feat.reshape(*feat.shape[:2], -1)
        mean = flat.mean(-1)
        std = torch.sqrt(flat.var(-1) + eps)
        # (N, C, 2) -> (N, 2C): interleave mean and std channel by channel.
        stats.append(torch.stack([mean, std], dim=-1).flatten(1))
    return torch.cat(stats, dim=-1)
class Smooth:
    """Running average over the most recent ``windowsize`` pushed values.

    Push values with ``s += value``; ``float(s)`` and ``format(s, spec)``
    report the mean of the window. The very first push fills every slot,
    so the average equals that first value instead of being diluted by the
    initial zeros.
    """

    def __init__(self, windowsize=100):
        self.window_size = windowsize
        # One float32 slot per window position, stored as a column.
        self.data = np.zeros((windowsize, 1), dtype=np.float32)
        self.index = 0  # total number of values pushed so far

    def __iadd__(self, x):
        if not self.index:
            # Seed the whole window with the first observation.
            self.data[:] = x
        slot = self.index % self.window_size  # ring-buffer position
        self.data[slot] = x
        self.index = self.index + 1
        return self

    def __float__(self):
        return float(np.mean(self.data))

    def __format__(self, f):
        return format(float(self), f)
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    Args:
        y: 4-D tensor unpacked as (b, ch, h, w); it need not be contiguous.

    Returns:
        Tensor of shape (b, ch, ch) holding the channel-by-channel inner
        products over all h * w positions, divided by ch * h * w.
    """
    b, ch, h, w = y.size()
    # reshape (not view): view raises on non-contiguous inputs such as
    # transposed or sliced activations.
    features = y.reshape(b, ch, h * w)
    gram = features.bmm(features.transpose(1, 2))
    return gram / (ch * h * w)
def move_file(src, dst, num=None, type='move', ext='.jpg'):
    """Move or copy files with a given extension from ``src`` into ``dst``.

    Args:
        src: source directory; its entries are listed with ``os.listdir``.
        dst: destination directory. It is not created here, so it is
            expected to exist already.
        num: number of directory entries to consider. None means every
            entry. Otherwise the listing is shuffled with a fixed seed (42)
            and the first ``num`` entries are taken; entries without
            ``ext`` among them are skipped, so fewer than ``num`` files may
            be transferred. Values above the directory size are clamped.
            NOTE: the shuffle starts from ``os.listdir`` order, which is
            filesystem-dependent.
        type: 'move' (``shutil.move``) or 'copy' (``shutil.copy``).
        ext: only entries whose extension equals this string exactly
            (case-sensitive) are transferred.

    Raises:
        ValueError: if ``type`` is neither 'move' nor 'copy'.
    """
    transfers = {'move': shutil.move, 'copy': shutil.copy}
    if type not in transfers:
        raise ValueError("type must be 'move' or 'copy', got {!r}".format(type))
    transfer = transfers[type]

    entries = np.array(os.listdir(src))
    if num is not None:
        # Private RNG: same permutation as np.random.seed(42) followed by
        # np.random.shuffle, without resetting the caller's global state.
        np.random.RandomState(42).shuffle(entries)
        # Slicing clamps oversize counts; max() keeps negatives selecting none.
        entries = entries[:max(0, num)]
    for name in entries:
        if os.path.splitext(name)[1] == ext:
            # os.path.join instead of a hard-coded '\\' so POSIX works too.
            transfer(os.path.join(src, name), os.path.join(dst, name))
if __name__ == '__main__':
    # Build 500-image subsets of the content (COCO2014) and style (WikiArt)
    # datasets. Each subset directory is wiped and recreated on every run.
    jobs = [
        ("../COCO2014/train2014", "../COCO2014_500/", "train2014"),
        ("../WikiArt/train", "../WikiArt_500/", "train"),
    ]
    for source_dir, subset_root, subset_class in jobs:
        target_dir = subset_root + subset_class
        # Remove a previous subset if present; makedirs then also creates
        # the subset root on the first run (and works if only the root exists).
        if os.path.isdir(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir)
        move_file(source_dir, target_dir, num=500, type='copy')