parent
2d558d497c
commit
a5e2e49351
@ -0,0 +1,173 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
import torch.utils.data as data
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
|
||||
|
||||
class PolypDataset(data.Dataset):
|
||||
"""
|
||||
dataloader for polyp segmentation tasks
|
||||
"""
|
||||
def __init__(self, image_root, gt_root, trainsize, augmentations):
|
||||
self.trainsize = trainsize
|
||||
self.augmentations = augmentations
|
||||
print(self.augmentations)
|
||||
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
|
||||
self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg')]
|
||||
self.images = sorted(self.images)
|
||||
self.gts = sorted(self.gts)
|
||||
self.filter_files()
|
||||
self.size = len(self.images)
|
||||
if self.augmentations == 'True':
|
||||
print('Using RandomRotation, RandomFlip')
|
||||
self.img_transform = transforms.Compose([
|
||||
transforms.RandomRotation(90, resample=False, expand=False, center=None, fill=None),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.Resize((self.trainsize, self.trainsize)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406],
|
||||
[0.229, 0.224, 0.225])])
|
||||
self.gt_transform = transforms.Compose([
|
||||
transforms.RandomRotation(90, resample=False, expand=False, center=None, fill=None),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.Resize((self.trainsize, self.trainsize)),
|
||||
transforms.ToTensor()])
|
||||
|
||||
else:
|
||||
print('no augmentation')
|
||||
self.img_transform = transforms.Compose([
|
||||
transforms.Resize((self.trainsize, self.trainsize)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406],
|
||||
[0.229, 0.224, 0.225])])
|
||||
|
||||
self.gt_transform = transforms.Compose([
|
||||
transforms.Resize((self.trainsize, self.trainsize)),
|
||||
transforms.ToTensor()])
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
image = self.rgb_loader(self.images[index])
|
||||
gt = self.binary_loader(self.gts[index])
|
||||
|
||||
seed = np.random.randint(2147483647) # make a seed with numpy generator
|
||||
random.seed(seed) # apply this seed to img tranfsorms
|
||||
torch.manual_seed(seed) # needed for torchvision 0.7
|
||||
if self.img_transform is not None:
|
||||
image = self.img_transform(image)
|
||||
|
||||
random.seed(seed) # apply this seed to img tranfsorms
|
||||
torch.manual_seed(seed) # needed for torchvision 0.7
|
||||
if self.gt_transform is not None:
|
||||
gt = self.gt_transform(gt)
|
||||
return image, gt
|
||||
|
||||
def filter_files(self):
|
||||
assert len(self.images) == len(self.gts)
|
||||
images = []
|
||||
gts = []
|
||||
for img_path, gt_path in zip(self.images, self.gts):
|
||||
img = Image.open(img_path)
|
||||
gt = Image.open(gt_path)
|
||||
if img.size == gt.size:
|
||||
images.append(img_path)
|
||||
gts.append(gt_path)
|
||||
self.images = images
|
||||
self.gts = gts
|
||||
|
||||
def rgb_loader(self, path):
|
||||
with open(path, 'rb') as f:
|
||||
img = Image.open(f)
|
||||
return img.convert('RGB')
|
||||
|
||||
def binary_loader(self, path):
|
||||
with open(path, 'rb') as f:
|
||||
img = Image.open(f)
|
||||
# return img.convert('1')
|
||||
return img.convert('L')
|
||||
|
||||
def convert2polar(self, img, gt):
|
||||
|
||||
center = polar_transformations.centroid(gt)
|
||||
img = polar_transformations.to_polar(img, center)
|
||||
gt = polar_transformations.to_polar(gt, center)
|
||||
|
||||
return img, gt
|
||||
#center_max_shift = 0.05 * LesionDataset.height
|
||||
#center = np.array(center)
|
||||
#center = (
|
||||
#center[0] + np.random.uniform(-center_max_shift, center_max_shift),
|
||||
#center[1] + np.random.uniform(-center_max_shift, center_max_shift))
|
||||
## to PyTorch expected format
|
||||
#input = input.transpose(2, 0, 1)
|
||||
#label = np.expand_dims(label, axis=-1)
|
||||
#label = label.transpose(2, 0, 1)
|
||||
|
||||
#input_tensor = torch.from_numpy(input)
|
||||
|
||||
def resize(self, img, gt):
|
||||
assert img.size == gt.size
|
||||
w, h = img.size
|
||||
if h < self.trainsize or w < self.trainsize:
|
||||
h = max(h, self.trainsize)
|
||||
w = max(w, self.trainsize)
|
||||
return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
|
||||
else:
|
||||
return img, gt
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
|
||||
def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=False, num_workers=4, pin_memory=True, augmentation=False): #shuffle=True
|
||||
|
||||
dataset = PolypDataset(image_root, gt_root, trainsize, augmentation)
|
||||
data_loader = data.DataLoader(dataset=dataset,
|
||||
batch_size=batchsize,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory)
|
||||
return data_loader
|
||||
|
||||
|
||||
class test_dataset:
|
||||
def __init__(self, image_root, gt_root, testsize):
|
||||
self.testsize = testsize
|
||||
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
|
||||
self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png' or f.endswith('.jpg'))]
|
||||
self.images = sorted(self.images)
|
||||
self.gts = sorted(self.gts)
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize((self.testsize, self.testsize)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406],
|
||||
[0.229, 0.224, 0.225])])
|
||||
self.gt_transform = transforms.ToTensor()
|
||||
self.size = len(self.images)
|
||||
self.index = 0
|
||||
|
||||
def load_data(self):
|
||||
image = self.rgb_loader(self.images[self.index])
|
||||
image = self.transform(image).unsqueeze(0)
|
||||
gt = self.binary_loader(self.gts[self.index])
|
||||
name = self.images[self.index].split('/')[-1]
|
||||
if name.endswith('.jpg'):
|
||||
name = name.split('.jpg')[0] + '.png'
|
||||
self.index += 1
|
||||
return image, gt, name
|
||||
|
||||
def rgb_loader(self, path):
|
||||
with open(path, 'rb') as f:
|
||||
img = Image.open(f)
|
||||
return img.convert('RGB')
|
||||
|
||||
def binary_loader(self, path):
|
||||
with open(path, 'rb') as f:
|
||||
img = Image.open(f)
|
||||
return img.convert('L')
|
||||
Loading…
Reference in new issue