You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
183 lines
7.5 KiB
183 lines
7.5 KiB
|
|
import os
|
|
import sys
|
|
import pickle
|
|
import cv2
|
|
from skimage import io
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from PIL import Image
|
|
import torch.nn.functional as F
|
|
import torchvision.transforms as transforms
|
|
import pandas as pd
|
|
from skimage.transform import rotate
|
|
from utils import random_click
|
|
import random
|
|
from monai.transforms import LoadImaged, Randomizable,LoadImage
|
|
|
|
|
|
class ISIC2016(Dataset):
    """ISIC 2016 Part 3 lesion-segmentation dataset with a click prompt.

    Two on-disk layouts are supported, chosen from ``data_path``:

    * ``ISIC2016_Part3`` layout -- ``data_path`` contains ``ISIC2016_Part3``
      but not ``ISIC2016_Part3B``. Images and masks live in
      ``ISBI2016_ISIC_Part3_<mode>_Data``. In ``'Test'`` mode the ground-truth
      CSV is read from a sibling ``ISIC`` directory, masks are not loaded, and
      an all-zero placeholder mask plus a uniformly random click are returned.
    * Any other path -- the original ``ISBI2016_ISIC_Part3B_<mode>_*`` names.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        """Read the sample names for ``mode`` from the ground-truth CSV.

        Args:
            args: config object; only ``args.image_size`` is read.
            data_path: dataset root directory.
            transform: optional callable applied to the RGB PIL image.
            transform_msk: optional callable applied to the resized PIL mask.
            mode: split name spliced into file/folder names
                (e.g. 'Training', 'Test').
            prompt: prompt type; only 'click' is implemented.
            plane: accepted for interface compatibility; unused.
        """
        csv_path = self._csv_path(data_path, mode)
        # No header row; the first column holds the file stems of the samples.
        df = pd.read_csv(csv_path, encoding='gbk', header=None)
        self.name_list = df.iloc[:, 0].tolist()
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size

        self.transform = transform
        self.transform_msk = transform_msk

    @staticmethod
    def _is_part3_layout(data_path):
        """Return True if ``data_path`` refers to ``ISIC2016_Part3`` (not ``Part3B``)."""
        # A bare substring test would also match 'ISIC2016_Part3B'; drop those
        # occurrences first so Part3B paths keep the original Part3B naming.
        return 'ISIC2016_Part3' in data_path.replace('ISIC2016_Part3B', '')

    @classmethod
    def _csv_path(cls, data_path, mode):
        """Return the ground-truth CSV path for split ``mode``."""
        if not cls._is_part3_layout(data_path):
            return os.path.join(data_path, 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv')
        if mode == 'Test':
            # Test labels sit in a sibling 'ISIC' directory. normpath strips a
            # trailing separator; without it dirname('x/ISIC2016_Part3/') would
            # return the directory itself rather than its parent.
            parent = os.path.dirname(os.path.normpath(data_path))
            return os.path.join(parent, 'ISIC', 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv')
        return os.path.join(data_path, 'ISBI2016_ISIC_Part3_' + mode + '_GroundTruth.csv')

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Return one sample dict.

        Keys: 'image', 'label' (the mask), 'p_label' (always 1), 'pt' (a
        one-element array of click coordinates) and 'image_meta_dict'.

        Raises:
            ValueError: if ``prompt`` is not 'click' (no click could be built).
        """
        if self.prompt != 'click':
            raise ValueError(f"ISIC2016: unsupported prompt {self.prompt!r}; only 'click' is implemented")

        inout = 1
        point_label = 1

        name = self.name_list[index]
        part3 = self._is_part3_layout(self.data_path)
        prefix = 'ISBI2016_ISIC_Part3_' if part3 else 'ISBI2016_ISIC_Part3B_'
        data_dir = prefix + self.mode + '_Data'

        img_path = os.path.join(self.data_path, data_dir, name + '.jpg')
        img = Image.open(img_path).convert('RGB')

        # The Part3 test split is used without masks: substitute an all-zero mask.
        no_mask = part3 and self.mode == 'Test'
        if no_mask:
            mask = Image.new('L', (self.img_size, self.img_size), 0)
        else:
            msk_path = os.path.join(self.data_path, data_dir, name + '_Segmentation.png')
            mask = Image.open(msk_path).convert('L')

        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if no_mask:
            # random_click is not used on the placeholder mask (presumably it
            # needs foreground pixels); pick a uniformly random location instead.
            pt = np.array([[random.randint(0, self.img_size - 1),
                            random.randint(0, self.img_size - 1)]])
        else:
            pt = np.array([random_click(np.array(mask) / 255, point_label, inout)])

        if self.transform:
            state = torch.get_rng_state()
            img = self.transform(img)
            # Rewind the torch RNG so transform_msk starts from the same state
            # as the image transform (keeps random augmentations aligned).
            torch.set_rng_state(state)

        if self.transform_msk:
            mask = self.transform_msk(mask)

        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }
|
|
|
|
class REFUGE(Dataset):
    """REFUGE multi-rater optic-cup / optic-disc segmentation dataset.

    Every sample is a sub-folder of ``<data_path>/<mode>-400`` holding
    ``<name>.jpg`` plus ``NUM_RATERS`` cup and disc annotations named
    ``<name>_seg_cup_<i>.png`` / ``<name>_seg_disc_<i>.png``
    (i = 1..NUM_RATERS).
    """

    # Annotators per image; rater files are numbered 1..NUM_RATERS.
    NUM_RATERS = 7

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        """Index the per-sample sub-folders of ``<data_path>/<mode>-400``.

        Args:
            args: config object; ``args.image_size`` (size used for click
                generation) and ``args.out_size`` (averaged-mask size) are read.
            data_path: dataset root directory.
            transform: callable applied to the image and to every rater mask;
                ``__getitem__`` requires it and expects tensor output.
            transform_msk: stored but not used by this class.
            mode: split prefix, e.g. 'Training' -> 'Training-400'.
            prompt: prompt type; only 'click' is implemented.
            plane: accepted for interface compatibility; unused.
        """
        self.data_path = data_path
        # os.scandir yields entries in arbitrary order; sort so the
        # index -> sample mapping is reproducible across machines.
        with os.scandir(os.path.join(data_path, mode + '-400')) as entries:
            self.subfolders = sorted(
                entry.path for entry in entries
                if entry.is_dir() and '.ipynb_checkpoints' not in entry.path
            )
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.mask_size = args.out_size

        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.subfolders)

    def _load_raters(self, subfolder, name, structure):
        """Open the ``<name>_seg_<structure>_<i>.png`` rater masks as grayscale PIL images."""
        return [
            Image.open(os.path.join(subfolder, name + '_seg_' + structure + '_' + str(i) + '.png')).convert('L')
            for i in range(1, self.NUM_RATERS + 1)
        ]

    def _agreement_click(self, raters, point_label, inout):
        """Run random_click on the per-pixel rater mean (scaled by 1/255) at image_size."""
        newsize = (self.img_size, self.img_size)
        resized = np.stack([np.array(rater.resize(newsize)) for rater in raters])
        return random_click(np.mean(resized, axis=0) / 255, point_label, inout)

    def _transform_raters(self, raters, state):
        """Binarise each rater mask through ``self.transform`` and average them.

        The torch RNG is reset to ``state`` before every mask so any random
        augmentation matches the one applied to the image.

        Returns:
            (stacked, mean_mask): the binarised masks stacked along a new dim 0,
            and their mean after bilinear resizing to (out_size, out_size).
            Assumes ``transform`` returns (C, H, W) tensors -- TODO confirm.
        """
        binary = []
        for rater in raters:
            torch.set_rng_state(state)
            binary.append((self.transform(rater) > 0.5).float())
        stacked = torch.stack(binary, dim=0)
        # transform to mask size (out_size) for the averaged mask
        mean_mask = F.interpolate(stacked, size=(self.mask_size, self.mask_size),
                                  mode='bilinear', align_corners=False).mean(dim=0)
        return stacked, mean_mask

    def __getitem__(self, index):
        """Return one sample dict.

        Keys: 'image', 'multi_rater_cup' / 'multi_rater_disc' (stacked binarised
        rater masks), 'mask_cup' / 'mask_disc' (rater mean resized to out_size),
        'label' (same as 'mask_disc'), 'p_label', 'pt_cup' / 'pt_disc' / 'pt'
        (same as 'pt_disc'), 'selected_rater' and 'image_meta_dict'.

        Raises:
            ValueError: if ``prompt`` is not 'click' or no ``transform`` is set,
                since neither the clicks nor the mask tensors could be built.
        """
        if self.prompt != 'click':
            raise ValueError(f"REFUGE: unsupported prompt {self.prompt!r}; only 'click' is implemented")
        if not self.transform:
            raise ValueError("REFUGE: a transform returning tensors is required to build the masks")

        inout = 1
        point_label = 1

        subfolder = self.subfolders[index]
        # basename() instead of split('/') so OS-native separators produced by
        # os.scandir (e.g. '\\' on Windows) are handled.
        name = os.path.basename(subfolder)

        img = Image.open(os.path.join(subfolder, name + '.jpg')).convert('RGB')
        multi_rater_cup = self._load_raters(subfolder, name, 'cup')
        multi_rater_disc = self._load_raters(subfolder, name, 'disc')

        # First click targets the agreement among all raters.
        pt_cup = self._agreement_click(multi_rater_cup, point_label, inout)
        pt_disc = self._agreement_click(multi_rater_disc, point_label, inout)

        state = torch.get_rng_state()
        img = self.transform(img)
        multi_rater_cup, mask_cup = self._transform_raters(multi_rater_cup, state)
        multi_rater_disc, mask_disc = self._transform_raters(multi_rater_disc, state)
        # Restore the RNG state captured before the image transform.
        torch.set_rng_state(state)

        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'multi_rater_cup': multi_rater_cup,
            'multi_rater_disc': multi_rater_disc,
            'mask_cup': mask_cup,
            'mask_disc': mask_disc,
            'label': mask_disc,
            'p_label': point_label,
            'pt_cup': pt_cup,
            'pt_disc': pt_disc,
            'pt': pt_disc,
            'selected_rater': torch.tensor(np.arange(self.NUM_RATERS)),
            'image_meta_dict': image_meta_dict,
        }
|
|
|
|
|