You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.1 KiB
68 lines
2.1 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on 2023/6/15
|
|
@author:shawn-sheep
|
|
图像格式转换的工具;
|
|
"""
|
|
from __future__ import print_function
|
|
|
|
import torch
|
|
from torch.autograd import Variable
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
|
|
# Module-wide converter from a (C, H, W) tensor back to a PIL image.
# mode=None is ToPILImage's default, so the image mode is inferred from the input.
tensor2pil = transforms.ToPILImage(mode=None)
|
|
|
|
|
|
class Tool(object):
    """Helpers for converting images between PIL and torch tensor formats.

    Attributes (set from ``config()`` in ``__init__``):
        use_cuda: bool, whether ``torch.cuda.is_available()`` reported a GPU;
        dtype: legacy tensor type, ``torch.cuda.FloatTensor`` when a GPU is
            available, otherwise ``torch.FloatTensor``;
        imsize: int, target side length for loaded images
            (512 with a GPU, 128 without).
    """

    def __init__(self):
        self.use_cuda, self.dtype, self.imsize = self.config()

    def config(self):
        """Choose settings depending on whether a GPU is available.

        return:
            tuple (use_cuda, dtype, imsize);
        """
        use_cuda = torch.cuda.is_available()  # is a usable GPU present?
        # Pick the tensor type that matches the device (GPU or CPU).
        dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
        imsize = 512 if use_cuda else 128  # use a smaller size when there is no GPU
        return use_cuda, dtype, imsize

    def image_loader(self, fn):
        """Read an image file and convert it to a batched tensor.

        The image is converted to RGB. A non-square image is then resized
        (not cropped) to a square whose side is its shorter edge, which
        changes its aspect ratio. Finally it is scaled to ``self.imsize``.

        param:
            fn: str, path to the image file;
        return:
            img: torch.Tensor of shape (1, 3, imsize, imsize);
        """
        # Using the file as a context manager closes the handle.
        # convert('RGB') loads the pixel data into an independent image, so
        # it is still usable after the file is closed. It also makes RGBA,
        # grayscale and palette files yield exactly 3 channels, which the
        # docstring promises and batch_tensor2pil relies on.
        with Image.open(fn) as src:
            img = src.convert('RGB')

        width, height = img.size
        if width != height:
            side = min(width, height)
            img = img.resize((side, side))

        # Add the batch dimension the network expects: (1, C, H, W).
        # (torch.autograd.Variable is no longer needed; plain tensors suffice.)
        return self.img2tensor(img).unsqueeze(0)

    def img2tensor(self, img):
        """Convert a PIL.Image.Image to a (C, H, W) float tensor.

        param:
            img: PIL.Image.Image, the image;
        return:
            torch.Tensor;
        """
        trans = transforms.Compose([
            # With an int size, torchvision matches the shorter edge to
            # self.imsize, so a square input becomes imsize x imsize.
            transforms.Resize(self.imsize),
            transforms.ToTensor(),  # convert to a torch tensor
        ])
        return trans(img)

    def batch_tensor2pil(self, batch_tensor):
        """Convert a (batch_size, n_channel, height, width) tensor to a PIL image.

        Only one image can be converted. The tensor must contain exactly
        ``3 * imsize * imsize`` elements, e.g. shape (1, 3, imsize, imsize).

        param:
            batch_tensor: torch.Tensor;
        return:
            img: PIL.Image.Image;
        raises:
            RuntimeError: if the element count does not match
                3 * imsize * imsize (raised by ``reshape``).
        """
        return tensor2pil(self._batch_to_chw(batch_tensor))

    def _batch_to_chw(self, batch_tensor):
        """Return a detached CPU copy of ``batch_tensor`` shaped (3, imsize, imsize)."""
        # detach() cuts the copy from any autograd graph (e.g. an image being
        # optimized). clone() keeps the result from aliasing the caller's tensor.
        arr = batch_tensor.detach().clone().cpu()
        # reshape gives the same result as view whenever view works, and also
        # accepts non-contiguous strides.
        return arr.reshape(3, self.imsize, self.imsize)
|