import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
import random
import torch

# ImageNet channel statistics used to normalise the RGB inputs.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _list_files(root, extensions):
    """Return the sorted paths of files in ``root`` ending with one of ``extensions``.

    ``os.path.join`` yields the same path as ``root + name`` when ``root`` ends
    with a separator, and still works when it does not.
    """
    return sorted(os.path.join(root, f) for f in os.listdir(root)
                  if f.endswith(extensions))


def _augmentation_enabled(flag):
    """Return True when ``flag`` requests augmentation.

    Accepts the string ``'True'`` (e.g. straight from argparse) as well as the
    boolean ``True``; previously only the string enabled augmentation, so
    ``get_loader(..., augmentation=True)`` was silently ignored.
    """
    return flag is True or flag == 'True'


class PolypDataset(data.Dataset):
    """
    dataloader for polyp segmentation tasks

    Images and masks are paired by their position after sorting each
    directory's file paths. Pairs whose pixel sizes differ are discarded.
    """

    def __init__(self, image_root, gt_root, trainsize, augmentations):
        """
        Args:
            image_root: directory holding the input images (.jpg / .png).
            gt_root: directory holding the masks (.png / .jpg).
            trainsize: side length both image and mask are resized to.
            augmentations: ``True`` or ``'True'`` enables random rotation and
                flips; any other value disables them.
        """
        self.trainsize = trainsize
        self.augmentations = augmentations
        print(self.augmentations)
        self.images = _list_files(image_root, ('.jpg', '.png'))
        self.gts = _list_files(gt_root, ('.png', '.jpg'))
        self.filter_files()
        self.size = len(self.images)
        if _augmentation_enabled(self.augmentations):
            print('Using RandomRotation, RandomFlip')
            # RandomRotation(90) uses the defaults the original spelled out
            # explicitly (nearest interpolation, no expand, centre, fill 0)
            # without the ``resample`` keyword that newer torchvision rejects.
            self.img_transform = transforms.Compose([
                transforms.RandomRotation(90),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Resize((self.trainsize, self.trainsize)),
                transforms.ToTensor(),
                transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD)])
            # Must list the same random transforms in the same order as
            # img_transform so that replaying the seed reproduces them.
            self.gt_transform = transforms.Compose([
                transforms.RandomRotation(90),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Resize((self.trainsize, self.trainsize)),
                transforms.ToTensor()])
        else:
            print('no augmentation')
            self.img_transform = transforms.Compose([
                transforms.Resize((self.trainsize, self.trainsize)),
                transforms.ToTensor(),
                transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD)])
            self.gt_transform = transforms.Compose([
                transforms.Resize((self.trainsize, self.trainsize)),
                transforms.ToTensor()])

    def __getitem__(self, index):
        """Load pair ``index`` and return ``(image_tensor, mask_tensor)``."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        # Draw one seed and re-apply it before each pipeline so that the
        # random rotation/flips applied to the image are repeated on the mask.
        seed = np.random.randint(2147483647)
        random.seed(seed)        # Python RNG (used by older torchvision)
        torch.manual_seed(seed)  # torch RNG (used by newer torchvision)
        if self.img_transform is not None:
            image = self.img_transform(image)
        random.seed(seed)
        torch.manual_seed(seed)
        if self.gt_transform is not None:
            gt = self.gt_transform(gt)
        return image, gt

    def filter_files(self):
        """Keep only image/mask pairs whose pixel dimensions match."""
        assert len(self.images) == len(self.gts)
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            # Image.open is lazy and keeps the file open; the context
            # managers release both handles once the sizes are read.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        """Read ``path`` and return it as a 3-channel RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Read ``path`` and return it as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def convert2polar(self, img, gt):
        """Warp ``img`` and ``gt`` into polar coordinates around the mask centroid.

        NOTE(review): ``polar_transformations`` is not imported in this file;
        calling this method raises NameError unless that module is made
        available. It is not called from anywhere in this file.
        """
        center = polar_transformations.centroid(gt)
        img = polar_transformations.to_polar(img, center)
        gt = polar_transformations.to_polar(gt, center)
        return img, gt

    def resize(self, img, gt):
        """Upscale ``img``/``gt`` so neither side is smaller than ``trainsize``.

        The image is resampled bilinearly and the mask with nearest-neighbour
        interpolation. Pairs that are already large enough are returned as-is.
        """
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        return self.size


def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=False,
               num_workers=4, pin_memory=True, augmentation=False):
    """Build a ``DataLoader`` over a :class:`PolypDataset`.

    ``augmentation`` may be ``True`` or ``'True'`` to enable random
    rotation/flips; remaining arguments are forwarded to ``DataLoader``.
    """
    dataset = PolypDataset(image_root, gt_root, trainsize, augmentation)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batchsize,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory)
    return data_loader


class test_dataset:
    """Sequential iterator over test images and their masks (one per call)."""

    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        self.images = _list_files(image_root, ('.jpg', '.png'))
        # Previously ``f.endswith('.png' or f.endswith('.jpg'))`` collapsed to
        # ``f.endswith('.png')``, silently dropping .jpg masks.
        self.gts = _list_files(gt_root, ('.tif', '.png', '.jpg'))
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD)])
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        """Return ``(image_batch, gt_pil, name)`` for the next pair and advance.

        ``image_batch`` has a leading batch dimension of 1. ``gt`` is left as
        an untransformed 'L' PIL image. A ``.jpg`` image name is reported with
        a ``.png`` extension.
        """
        image = self.rgb_loader(self.images[self.index])
        image = self.transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        name = os.path.basename(self.images[self.index])
        if name.endswith('.jpg'):
            name = name.split('.jpg')[0] + '.png'
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        """Read ``path`` and return it as a 3-channel RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Read ``path`` and return it as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')