You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

108 lines
3.9 KiB

from typing import Text
from PIL import Image, ImageFile
import torchvision.transforms as transforms
import torch.utils.data as data
from .image_folder import make_dataset
from util import task
import random
class CreateDataset(data.Dataset):
    """Inpainting dataset: each item is ``{'img', 'img_path', 'mask'}``.

    Images come from ``opt.img_file``; the mask type for every sample is drawn
    at random from ``opt.mask_type`` (0 center, 1 random regular, 2 random
    irregular, 3 external mask files listed under ``opt.mask_file``).
    """

    def __init__(self, opt):
        self.opt = opt
        # NOTE(review): make_dataset presumably returns (list_of_paths, count);
        # confirm against .image_folder.
        self.img_paths, self.img_size = make_dataset(opt.img_file)
        # External mask files are optional; they are only needed for mask type 3.
        if opt.mask_file != 'none':
            self.mask_paths, self.mask_size = make_dataset(opt.mask_file)
        else:
            self.mask_paths, self.mask_size = [], 0
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        img, img_path = self.load_img(index)
        mask = self.load_mask(img, index)
        return {'img': img, 'img_path': img_path, 'mask': mask}

    def __len__(self):
        return self.img_size

    def name(self):
        return "inpainting dataset"

    def load_img(self, index):
        """Load image ``index`` (wrapping past the end) and return (tensor, path)."""
        # Set on every call (not once at construction) so the flag is in effect
        # in whichever process actually performs the load.
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        img_path = self.img_paths[index % self.img_size]
        # The context manager guarantees the underlying file handle is released.
        with Image.open(img_path) as src:
            img_pil = src.convert('RGB')
        img = self.transform(img_pil)
        img_pil.close()
        return img, img_path

    def load_mask(self, img, index):
        """Return a hole mask for ``img``; the mask type is sampled per call.

        Types: 0 center, 1 random regular, 2 random irregular (helpers in
        ``util.task``), 3 external mask file from "Image Inpainting for
        Irregular Holes Using Partial Convolutions (ECCV18)".

        Raises:
            ValueError: ``opt.mask_type`` is empty or contains an unknown type,
                or type 3 is requested but no mask files were loaded.
        """
        mask_types = self.opt.mask_type
        if not mask_types:
            raise ValueError("opt.mask_type is empty; no mask type to sample")
        mask_type = random.choice(mask_types)
        if mask_type == 0:
            return task.center_mask(img)
        if mask_type == 1:
            return task.random_regular_mask(img)
        if mask_type == 2:
            return task.random_irregular_mask(img)
        if mask_type == 3:
            return self._load_external_mask(index)
        # Previously fell through and returned None, failing later in collation.
        raise ValueError(
            "unknown mask type {!r}; expected 0, 1, 2 or 3".format(mask_type))

    def _mask_index(self, index):
        """Choose which external mask file to pair with sample ``index``."""
        if self.opt.isTrain:
            # Training: any mask, uniformly at random.
            return random.randint(0, self.mask_size - 1)
        # Testing: deterministic pairing, wrapping around like load_img does so
        # more images than masks no longer raises IndexError.
        return index % self.mask_size

    def _load_external_mask(self, index):
        """Load, augment and binarize one external mask file into a float tensor."""
        if self.mask_size == 0:
            raise ValueError(
                "mask type 3 requires opt.mask_file to point at mask images")
        with Image.open(self.mask_paths[self._mask_index(index)]) as src:
            mask_pil = src.convert('RGB')
        # Side of the largest centered square that fits in the (width, height) image.
        size = min(mask_pil.size)
        mask_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.CenterCrop([size, size]),
            transforms.Resize(self.opt.fineSize),
            transforms.ToTensor(),
        ])
        # Element-wise: values exactly 0 after ToTensor become 1.0, all others 0.0.
        mask = (mask_transform(mask_pil) == 0).float()
        mask_pil.close()
        return mask
def dataloader(opt):
    """Wrap ``CreateDataset(opt)`` in a torch ``DataLoader``.

    Batch size, shuffling (disabled by ``opt.no_shuffle``) and worker count
    (``opt.nThreads``) come from ``opt``; a final incomplete batch is dropped.
    """
    inpainting_set = CreateDataset(opt)
    loader = data.DataLoader(
        inpainting_set,
        batch_size=opt.batchSize,
        shuffle=not opt.no_shuffle,
        num_workers=int(opt.nThreads),
        drop_last=True,
    )
    return loader
def get_transform(opt):
    """Build the PIL-image -> tensor pipeline used for dataset images.

    Training: optional resize to ``opt.loadSize`` plus random crop to
    ``opt.fineSize`` (per ``opt.resize_or_crop``), then the optional
    colour-jitter / horizontal-flip / small-rotation steps.
    Testing: a plain resize to ``opt.fineSize``.
    Both end with ``ToTensor``.
    """
    load_hw = [opt.loadSize[0], opt.loadSize[1]]
    fine_hw = [opt.fineSize[0], opt.fineSize[1]]

    if not opt.isTrain:
        # Evaluation: deterministic resize straight to the network input size.
        steps = [transforms.Resize(fine_hw)]
    else:
        steps = []
        mode = opt.resize_or_crop
        if mode == 'resize_and_crop':
            steps.extend([transforms.Resize(load_hw), transforms.RandomCrop(fine_hw)])
        elif mode == 'crop':
            steps.append(transforms.RandomCrop(fine_hw))
        if not opt.no_augment:
            # NOTE(review): all-zero jitter factors look like a no-op colour
            # augmentation; confirm whether this is intentional.
            steps.append(transforms.ColorJitter(0.0, 0.0, 0.0, 0.0))
        if not opt.no_flip:
            steps.append(transforms.RandomHorizontalFlip())
        if not opt.no_rotation:
            steps.append(transforms.RandomRotation(3))

    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)