You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

183 lines
7.5 KiB

import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from utils import random_click
import random
from monai.transforms import LoadImaged, Randomizable,LoadImage
class ISIC2016(Dataset):
    """ISIC 2016 lesion segmentation dataset (Part 3 / Part 3B file layouts).

    Sample names come from the first column of a header-less ground-truth CSV.
    For each name the image is ``<name>.jpg`` and, when available, the mask is
    ``<name>_Segmentation.png``; both live in ``<data_path>/<prefix><mode>_Data``.

    The layout is chosen by whether ``data_path`` names an ``ISIC2016_Part3``
    directory (an ``ISIC2016_Part3B`` directory does not count):

    * Part3 layout: files use the ``ISBI2016_ISIC_Part3_`` prefix. In ``'Test'``
      mode the CSV ``ISBI2016_ISIC_Part3B_Test_GroundTruth.csv`` is read from an
      ``ISIC`` directory next to ``data_path``, and no mask is loaded (an all-zero
      placeholder is returned instead).
    * Otherwise: files use the ``ISBI2016_ISIC_Part3B_`` prefix.

    Args:
        args: Config object; only ``args.image_size`` is read.
        data_path: Dataset root directory.
        transform: Optional callable applied to the RGB PIL image.
        transform_msk: Optional callable applied to the single-channel PIL mask.
            When ``transform`` is also given, the torch RNG state is restored to
            what it was before ``transform`` ran, so augmentations drawing from
            the torch RNG are repeated identically on the mask.
        mode: Split name spliced into file names (e.g. ``'Training'``,
            ``'Test'``); only ``'Test'`` is special-cased.
        prompt: Prompt type; only ``'click'`` is supported.
        plane: Unused; accepted so the signature matches ``REFUGE``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        csv_path = self._resolve_csv_path(data_path, mode)
        df = pd.read_csv(csv_path, encoding='gbk', header=None)
        self.name_list = df.iloc[:, 0].tolist()
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    @staticmethod
    def _is_part3_layout(path):
        """Return True if *path* refers to the ``ISIC2016_Part3`` (not ``Part3B``) layout."""
        # A plain substring test would also match 'ISIC2016_Part3B', so strip that first.
        return 'ISIC2016_Part3' in path.replace('ISIC2016_Part3B', '')

    @staticmethod
    def _resolve_csv_path(data_path, mode):
        """Return the ground-truth CSV path for *data_path* and *mode*."""
        if ISIC2016._is_part3_layout(data_path):
            if mode == 'Test':
                # The Part3 layout keeps its test CSV (Part3B naming) in a sibling 'ISIC' directory.
                return os.path.join(os.path.dirname(data_path), 'ISIC',
                                    'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv')
            return os.path.join(data_path, 'ISBI2016_ISIC_Part3_' + mode + '_GroundTruth.csv')
        return os.path.join(data_path, 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv')

    def _data_dir(self):
        """Return the name of the directory holding this split's images and masks."""
        if self._is_part3_layout(self.data_path):
            return 'ISBI2016_ISIC_Part3_' + self.mode + '_Data'
        return 'ISBI2016_ISIC_Part3B_' + self.mode + '_Data'

    def _has_mask(self):
        """Return False for the Part3 test split, which ships without masks."""
        return not (self.mode == 'Test' and self._is_part3_layout(self.data_path))

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Load sample *index*.

        Returns:
            dict with keys
            ``'image'``: the RGB image (after ``transform`` if given);
            ``'label'``: the mask resized to ``(image_size, image_size)`` (after
            ``transform_msk`` if given);
            ``'p_label'``: the point label, always ``1``;
            ``'pt'``: a one-row array holding the click point -- sampled from the
            mask by ``utils.random_click`` (presumably ``(row, col)``; confirm
            there), or uniformly at random when no mask exists;
            ``'image_meta_dict'``: ``{'filename_or_obj': <sample name>}``.

        Raises:
            ValueError: if ``self.prompt`` is not ``'click'`` (no point would be produced).
        """
        if self.prompt != 'click':
            raise ValueError(
                f"ISIC2016: unsupported prompt {self.prompt!r}; only 'click' is implemented")

        inout = 1
        point_label = 1
        name = self.name_list[index]
        newsize = (self.img_size, self.img_size)
        data_dir = self._data_dir()

        img_path = os.path.join(self.data_path, data_dir, name + '.jpg')
        with Image.open(img_path) as im:
            img = im.convert('RGB')

        if self._has_mask():
            msk_path = os.path.join(self.data_path, data_dir, name + '_Segmentation.png')
            with Image.open(msk_path) as im:
                mask = im.convert('L').resize(newsize)
            pt = np.array([random_click(np.array(mask) / 255, point_label, inout)])
        else:
            # No ground truth: return an empty placeholder mask and a random point,
            # since random_click cannot sample from an all-zero mask.
            mask = Image.new('L', newsize, 0)
            pt = np.array([[random.randint(0, self.img_size - 1),
                            random.randint(0, self.img_size - 1)]])

        if self.transform:
            state = torch.get_rng_state()
            img = self.transform(img)
            # Replay the same RNG draws for the mask so paired augmentations match.
            torch.set_rng_state(state)
        if self.transform_msk:
            mask = self.transform_msk(mask)

        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }
class REFUGE(Dataset):
    """REFUGE optic cup / disc segmentation dataset with several raters per image.

    Each sample is a sub-directory of ``<data_path>/<mode>-400`` named ``<name>``
    containing ``<name>.jpg`` plus ``<name>_seg_cup_<i>.png`` and
    ``<name>_seg_disc_<i>.png`` for raters ``i = 1 .. 7``.

    Args:
        args: Config object; ``args.image_size`` (size used when sampling the
            initial click) and ``args.out_size`` (side length of the averaged
            ``mask_cup`` / ``mask_disc``) are read.
        data_path: Dataset root directory.
        transform: Callable applied to the RGB image and to every rater mask;
            required by ``__getitem__``. Its outputs for masks are thresholded at
            0.5 and stacked, so it presumably returns ``(C, H, W)`` tensors -- confirm.
        transform_msk: Unused; accepted so the signature matches ``ISIC2016``.
        mode: Split prefix of the sample directory (``<mode>-400``).
        prompt: Prompt type; only ``'click'`` is supported.
        plane: Unused; accepted so the signature matches ``ISIC2016``.
    """

    _NUM_RATERS = 7

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.data_path = data_path
        self.subfolders = self._list_subfolders(os.path.join(data_path, mode + '-400'))
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.mask_size = args.out_size
        self.transform = transform
        self.transform_msk = transform_msk

    @staticmethod
    def _list_subfolders(root):
        """Return the sorted paths of sample directories directly under *root*.

        Jupyter ``.ipynb_checkpoints`` directories are skipped. Sorting makes the
        index -> sample mapping independent of filesystem enumeration order.
        """
        with os.scandir(root) as entries:
            return sorted(
                entry.path for entry in entries
                if entry.is_dir() and '.ipynb_checkpoints' not in entry.name
            )

    @staticmethod
    def _load_image(path, mode):
        """Open *path* with PIL, convert it to *mode*, and release the file handle."""
        with Image.open(path) as im:
            return im.convert(mode)

    @staticmethod
    def _agreement_click(masks, size, point_label, inout):
        """Sample a click from the pixel-wise mean of *masks* resized to *size*.

        The 8-bit mean is divided by 255; for binary {0, 255} masks this is the
        fraction of raters marking each pixel (assumption -- confirm mask values).
        """
        resized = np.stack([np.array(m.resize(size)) for m in masks])
        return random_click(resized.mean(axis=0) / 255, point_label, inout)

    def _stack_raters(self, masks, rng_state):
        """Transform, binarize and stack rater masks; also return their averaged downsample.

        The torch RNG is reset to *rng_state* before each mask so random
        augmentations match the ones applied to the image.

        Returns:
            ``(stacked, averaged)`` where ``stacked`` stacks the binarized masks
            along a new dim 0 and ``averaged`` is their bilinear resize to
            ``(mask_size, mask_size)`` averaged over raters.
        """
        binarized = []
        for mask in masks:
            torch.set_rng_state(rng_state)
            binarized.append((self.transform(mask) > 0.5).float())
        stacked = torch.stack(binarized, dim=0)
        averaged = F.interpolate(stacked, size=(self.mask_size, self.mask_size),
                                 mode='bilinear', align_corners=False).mean(dim=0)
        return stacked, averaged

    def __len__(self):
        return len(self.subfolders)

    def __getitem__(self, index):
        """Load sample *index* with all rater masks.

        Returns:
            dict with the transformed ``'image'``; stacked binarized rater masks
            ``'multi_rater_cup'`` / ``'multi_rater_disc'``; their rater-averaged
            ``out_size`` versions ``'mask_cup'`` / ``'mask_disc'`` (``'label'`` is
            ``mask_disc``); clicks ``'pt_cup'`` / ``'pt_disc'`` (``'pt'`` is
            ``pt_disc``); ``'p_label'`` (always 1); ``'selected_rater'``
            (indices 0..6); and ``'image_meta_dict'``.

        Raises:
            ValueError: if ``prompt`` is not ``'click'`` or ``transform`` is not set,
                since neither the clicks nor the tensor masks could be produced.
        """
        if self.prompt != 'click':
            raise ValueError(
                f"REFUGE: unsupported prompt {self.prompt!r}; only 'click' is implemented")
        if not self.transform:
            raise ValueError(
                "REFUGE: a transform is required to convert images and rater masks to tensors")

        inout = 1
        point_label = 1
        subfolder = self.subfolders[index]
        name = os.path.basename(subfolder)
        raters = range(1, self._NUM_RATERS + 1)

        img = self._load_image(os.path.join(subfolder, name + '.jpg'), 'RGB')
        multi_rater_cup = [
            self._load_image(os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png'), 'L')
            for i in raters
        ]
        multi_rater_disc = [
            self._load_image(os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png'), 'L')
            for i in raters
        ]

        # The first click targets the region of highest agreement among raters.
        newsize = (self.img_size, self.img_size)
        pt_cup = self._agreement_click(multi_rater_cup, newsize, point_label, inout)
        pt_disc = self._agreement_click(multi_rater_disc, newsize, point_label, inout)

        state = torch.get_rng_state()
        img = self.transform(img)
        multi_rater_cup, mask_cup = self._stack_raters(multi_rater_cup, state)
        multi_rater_disc, mask_disc = self._stack_raters(multi_rater_disc, state)
        torch.set_rng_state(state)

        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'multi_rater_cup': multi_rater_cup,
            'multi_rater_disc': multi_rater_disc,
            'mask_cup': mask_cup,
            'mask_disc': mask_disc,
            'label': mask_disc,
            'p_label': point_label,
            'pt_cup': pt_cup,
            'pt_disc': pt_disc,
            'pt': pt_disc,
            'selected_rater': torch.tensor(np.arange(self._NUM_RATERS)),
            'image_meta_dict': image_meta_dict,
        }