You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
268 lines
9.7 KiB
268 lines
9.7 KiB
|
|
import os
|
|
import sys
|
|
import pickle
|
|
import cv2
|
|
from skimage import io
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from PIL import Image
|
|
import torch.nn.functional as F
|
|
import torchvision.transforms as transforms
|
|
import pandas as pd
|
|
from skimage.transform import rotate
|
|
from utils import random_click
|
|
import random
|
|
from monai.transforms import LoadImaged, Randomizable,LoadImage
|
|
|
|
|
|
class ISIC2016(Dataset):
    """ISIC 2016 lesion segmentation dataset returning image, mask and a click prompt.

    Samples are listed in ``ISBI2016_ISIC_Part3B_<mode>_GroundTruth.csv``
    under ``data_path``: column 1 holds the image path and column 2 the mask
    path, both relative to ``data_path``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', prompt='click', plane=False):
        """
        Args:
            args: config namespace; only ``args.image_size`` is read.
            data_path: root directory holding the CSV and the referenced files.
            transform: optional callable applied to the RGB PIL image.
            transform_msk: optional callable applied to the resized grayscale mask.
            mode: split name embedded in the CSV file name (default ``'Training'``).
            prompt: prompt type; ``__getitem__`` only implements ``'click'``.
            plane: accepted for signature compatibility with the other datasets; unused.
        """
        csv_name = 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv'
        df = pd.read_csv(os.path.join(data_path, csv_name), encoding='gbk')
        self.name_list = df.iloc[:, 1].tolist()
        self.label_list = df.iloc[:, 2].tolist()
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size

        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with ``'image'`` (``transform(img)`` or the RGB PIL image),
            ``'label'`` (``transform_msk(mask)`` or the resized PIL mask),
            ``'p_label'`` (always 1), ``'pt'`` (the value returned by
            ``random_click``) and ``'image_meta_dict'``
            (``{'filename_or_obj': <image name without directory and '.jpg'>}``).

        Raises:
            ValueError: if ``self.prompt`` is not ``'click'``; no other prompt
                type is generated, so ``'pt'`` would otherwise be undefined.
        """
        if self.prompt != 'click':
            raise ValueError(f"Unsupported prompt type {self.prompt!r}; only 'click' is implemented")

        inout = 1
        point_label = 1

        # Resolve the image and mask paths listed in the CSV.
        name = self.name_list[index]
        img_path = os.path.join(self.data_path, name)
        mask_name = self.label_list[index]
        msk_path = os.path.join(self.data_path, mask_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(msk_path).convert('L')

        # The mask is resized to image_size x image_size before the click is sampled.
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)
        # 'L' masks are 8-bit, so /255 maps them into [0, 1].
        pt = random_click(np.array(mask) / 255, point_label, inout)

        if self.transform:
            # Snapshot/restore the torch RNG so transform_msk starts from the
            # same RNG state the image transform used.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)

        if self.transform_msk:
            mask = self.transform_msk(mask)

        # Strip the directory part and the '.jpg' suffix for the metadata name.
        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }
|
|
|
|
class PolypDataset(Dataset):
    """Polyp segmentation dataset (e.g. CVC-300, CVC-ClinicDB, CVC-ColonDB,
    ETIS-LaribPolypDB, Kvasir).

    Expected directory layout::

        data_path/
            images/
                xxx.png
            masks/
                xxx.png

    The mask for ``images/<file>`` is read from ``masks/<file>`` (same file name).
    """

    # Image file extensions accepted from ``images/`` (matched case-insensitively).
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', prompt='click', plane=False):
        """
        Args:
            args: config namespace; ``args.image_size`` and ``args.out_size`` are read.
            data_path: dataset root containing ``images/`` and ``masks/``.
            transform: optional callable applied to the RGB PIL image.
            transform_msk: optional callable applied to the original-size grayscale mask.
            mode: ``'Training'`` selects the first 80% of the sorted file list;
                any other value selects the remaining 20%.
            prompt: prompt type; ``__getitem__`` only implements ``'click'``.
            plane: accepted for signature compatibility; unused.
        """
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.out_size = args.out_size

        self.transform = transform
        self.transform_msk = transform_msk

        # Collect image files; the extension check ignores case so '.PNG'/'.JPG' are kept.
        img_dir = os.path.join(data_path, 'images')
        self.img_list = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith(self.IMG_EXTENSIONS)
        )

        # Deterministic 80/20 train/test split over the sorted file names (no shuffling).
        split_idx = int(len(self.img_list) * 0.8)
        if mode == 'Training':
            self.img_list = self.img_list[:split_idx]
        else:
            self.img_list = self.img_list[split_idx:]

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with ``'image'``, ``'label'`` (the mask), ``'p_label'``
            (always 1), ``'pt'`` (the value returned by ``random_click``) and
            ``'image_meta_dict'`` (``{'filename_or_obj': <file name without extension>}``).

        Raises:
            ValueError: if ``self.prompt`` is not ``'click'``; no other prompt
                type is generated, so ``'pt'`` would otherwise be undefined.
        """
        if self.prompt != 'click':
            raise ValueError(f"Unsupported prompt type {self.prompt!r}; only 'click' is implemented")

        point_label = 1
        inout = 1

        # Image and mask share the same file name.
        img_name = self.img_list[index]
        img_path = os.path.join(self.data_path, 'images', img_name)
        mask_path = os.path.join(self.data_path, 'masks', img_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        # A resized copy of the mask is used only to sample the click prompt;
        # the original-size mask is what transform_msk receives.
        newsize = (self.img_size, self.img_size)
        mask_resized = mask.resize(newsize)
        # 'L' masks are 8-bit, so /255 maps them into [0, 1].
        pt = random_click(np.array(mask_resized) / 255, point_label, inout)

        if self.transform:
            # Snapshot/restore the torch RNG so transform_msk starts from the
            # same RNG state the image transform used.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)

        if self.transform_msk:
            mask = self.transform_msk(mask)

        # splitext keeps names containing extra dots intact (only the extension is dropped).
        name = os.path.splitext(img_name)[0]
        image_meta_dict = {'filename_or_obj': name}

        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }
|
|
|
|
|
|
class CombinedPolypDataset(Dataset):
    """Concatenation of several polyp datasets, used for training.

    Each known sub-directory of ``data_path`` that exists and yields a
    non-empty ``PolypDataset`` is included; global indices run through the
    datasets in the order of ``dataset_dirs``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', prompt='click', plane=False):
        self.datasets = []

        # Supported dataset directories, concatenated in this order.
        dataset_dirs = ['CVC-300', 'CVC-ClinicDB', 'CVC-ColonDB', 'ETIS-LaribPolypDB', 'Kvasir']

        for dataset_dir in dataset_dirs:
            full_path = os.path.join(data_path, dataset_dir)
            if os.path.exists(full_path):
                ds = PolypDataset(args, full_path, transform, transform_msk, mode, prompt, plane)
                if len(ds) > 0:
                    self.datasets.append(ds)
                    print(f"Loaded {dataset_dir}: {len(ds)} samples ({mode})")

        # cumulative_sizes[i] is the number of samples in datasets[0..i] inclusive.
        self.cumulative_sizes = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative_sizes.append(total)

        print(f"Total {mode} samples: {total}")

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, index):
        """Return the sample at global ``index`` (negative indices count from the end).

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Normalise negative indices against the combined length; without this,
        # index -1 would be resolved inside the first dataset.
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError("Index out of range")

        offset = 0
        for ds, cumsize in zip(self.datasets, self.cumulative_sizes):
            if index < cumsize:
                return ds[index - offset]
            offset = cumsize
        raise IndexError("Index out of range")
|
|
|
|
|
|
class REFUGE(Dataset):
    """REFUGE fundus dataset with optic-cup and optic-disc masks from several raters.

    Each sample is a sub-folder ``<name>`` of ``<data_path>/<mode>-400`` holding
    ``<name>.jpg`` plus ``<name>_seg_cup_<k>.png`` and ``<name>_seg_disc_<k>.png``
    for k = 1..NUM_RATERS.
    """

    # Number of per-rater masks loaded for each structure (cup / disc).
    NUM_RATERS = 7

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', prompt='click', plane=False):
        """
        Args:
            args: config namespace; ``args.image_size`` and ``args.out_size`` are read.
            data_path: root containing the ``<mode>-400`` split directory.
            transform: callable applied to the image and to every rater mask;
                required by ``__getitem__``.
            transform_msk: stored but not used by this class.
            mode: split prefix (default ``'Training'`` -> ``Training-400``).
            prompt: prompt type; ``__getitem__`` only implements ``'click'``.
            plane: accepted for signature compatibility; unused.
        """
        self.data_path = data_path
        split_dir = os.path.join(data_path, mode + '-400')
        # os.scandir order is filesystem-dependent; sort so index -> sample is reproducible.
        self.subfolders = sorted(f.path for f in os.scandir(split_dir) if f.is_dir())
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.mask_size = args.out_size

        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.subfolders)

    def _stack_rater_masks(self, raters, rng_state):
        """Transform each rater mask, binarise at 0.5 and stack along a new dim 0.

        The torch RNG is reset to ``rng_state`` before every call to
        ``self.transform``, so each mask starts from the RNG state that was
        used for the image (keeping random augmentations aligned).
        """
        masks = []
        for rater in raters:
            torch.set_rng_state(rng_state)
            masks.append((self.transform(rater) > 0.5).float())
        return torch.stack(masks, dim=0)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with the transformed ``'image'``, stacked binary rater masks
            ``'multi_rater_cup'`` / ``'multi_rater_disc'``, their rater means
            resized to ``out_size`` (``'mask_cup'`` / ``'mask_disc'``, the disc
            mean also as ``'label'``), ``'p_label'`` (always 1), click points
            ``'pt_cup'`` / ``'pt_disc'`` (``'pt'`` is the disc click),
            ``'selected_rater'`` and ``'image_meta_dict'``.

        Raises:
            ValueError: if ``self.prompt`` is not ``'click'`` or no transform
                was given (both outputs would otherwise be undefined).
        """
        if self.prompt != 'click':
            raise ValueError(f"Unsupported prompt type {self.prompt!r}; only 'click' is implemented")
        if not self.transform:
            raise ValueError("REFUGE requires a transform that converts PIL images to tensors")

        inout = 1
        point_label = 1

        subfolder = self.subfolders[index]
        # basename is separator-agnostic (scandir paths use os.sep, e.g. '\\' on Windows).
        name = os.path.basename(subfolder)

        # Raw image and per-rater mask paths.
        rater_ids = range(1, self.NUM_RATERS + 1)
        img_path = os.path.join(subfolder, name + '.jpg')
        multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png') for i in rater_ids]
        multi_rater_disc_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png') for i in rater_ids]

        img = Image.open(img_path).convert('RGB')
        multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path]
        multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path]

        # The click is sampled on the mean of all raters' masks (resized to
        # image_size), i.e. on the agreement among raters. /255 maps 8-bit 'L' to [0, 1].
        newsize = (self.img_size, self.img_size)
        cup_stack = np.stack([np.array(r.resize(newsize)) for r in multi_rater_cup])
        disc_stack = np.stack([np.array(r.resize(newsize)) for r in multi_rater_disc])
        pt_cup = random_click(np.mean(cup_stack, axis=0) / 255, point_label, inout)
        pt_disc = random_click(np.mean(disc_stack, axis=0) / 255, point_label, inout)

        state = torch.get_rng_state()
        img = self.transform(img)

        # Stack is (raters, C, H, W); resize to out_size, then average over raters.
        out_hw = (self.mask_size, self.mask_size)
        multi_rater_cup = self._stack_rater_masks(multi_rater_cup, state)
        mask_cup = F.interpolate(multi_rater_cup, size=out_hw, mode='bilinear', align_corners=False).mean(dim=0)

        multi_rater_disc = self._stack_rater_masks(multi_rater_disc, state)
        mask_disc = F.interpolate(multi_rater_disc, size=out_hw, mode='bilinear', align_corners=False).mean(dim=0)

        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'multi_rater_cup': multi_rater_cup,
            'multi_rater_disc': multi_rater_disc,
            'mask_cup': mask_cup,
            'mask_disc': mask_disc,
            'label': mask_disc,
            'p_label': point_label,
            'pt_cup': pt_cup,
            'pt_disc': pt_disc,
            'pt': pt_disc,
            'selected_rater': torch.tensor(np.arange(self.NUM_RATERS)),
            'image_meta_dict': image_meta_dict,
        }
|
|
|
|
|