You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
2.1 KiB

# -*- coding: utf-8 -*-
"""
Created on 2023/6/15
@author:shawn-sheep
图像格式转换的工具;
"""
from __future__ import print_function
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from PIL import Image
tensor2pil = transforms.ToPILImage() # 转回 PIL 图像
class Tool(object):
def __init__(self):
self.use_cuda, self.dtype, self.imsize = self.config()
def config(self):
"""根据是否有gpu,返回相关数据;"""
use_cuda = torch.cuda.is_available() # 是否有可用的gpu
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor # 根据cpu/gpu选择数据类型
imsize = 512 if use_cuda else 128 # 如果没有 GPU 则使用小尺寸
return use_cuda, dtype, imsize
def image_loader(self, fn):
"""读入图片并转为tensor;
param:
fn: str,图像路径;
return:
img: torch.Tensor,图像的张量, shape=([1, 3, 512, 512]);"""
img = Image.open(fn)
if img.size[0] == img.size[1]:
pass
else:
size = min(img.size)
img = img.resize((size, size))
img = Variable(self.img2tensor(img))
# 由于神经网络输入的需要, 添加 batch 的维度
# 使其shape变为(1,N_channel,height,weight)
img = img.unsqueeze(0)
return img
def img2tensor(self, img):
"""将PIL.Image.Image转为tensor;
param:
img: PIL.Image.Image,图像;
return:
torch.Tensor;"""
trans = transforms.Compose([
transforms.Resize(self.imsize), # 缩放图像
transforms.ToTensor()]) # 将其转化为 torch 张量
return trans(img)
def batch_tensor2pil(self, batch_tensor):
"""将(batch_size,n_channel,height,weight)格式的tensor
转为PIL.Image.Image
param:
tensor: torch.Tensor,
return:
img: PIL.Image.Image;"""
arr = batch_tensor.clone().cpu()
arr = arr.view(3, self.imsize, self.imsize)
img = tensor2pil(arr)
return img