import os
import sys
import pickle
import bisect
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from utils import random_click
import random
from monai.transforms import LoadImaged, Randomizable, LoadImage


class ISIC2016(Dataset):
    """ISIC 2016 skin-lesion segmentation dataset.

    Reads image/mask path pairs from the official per-split ground-truth CSV
    and optionally generates a single click prompt from the resized mask.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        # CSV layout: column 1 = image path, column 2 = mask path
        # (both relative to data_path). encoding='gbk' kept from the original.
        df = pd.read_csv(
            os.path.join(data_path, 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv'),
            encoding='gbk')

        self.name_list = df.iloc[:, 1].tolist()
        self.label_list = df.iloc[:, 2].tolist()
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        inout = 1
        point_label = 1
        # FIX: pt was previously undefined (NameError) whenever
        # self.prompt != 'click'; provide a neutral default.
        pt = np.array([0, 0])

        """Get the images"""
        name = self.name_list[index]
        img_path = os.path.join(self.data_path, name)
        mask_name = self.label_list[index]
        msk_path = os.path.join(self.data_path, mask_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(msk_path).convert('L')

        # Resize mask to the model input size before sampling the click.
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if self.prompt == 'click':
            # Sample a click inside the foreground of the binarized mask.
            pt = random_click(np.array(mask) / 255, point_label, inout)

        if self.transform:
            # Save/restore the RNG state so the mask transform sees the same
            # random draws as the image transform (keeps augmentations aligned).
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)

            if self.transform_msk:
                mask = self.transform_msk(mask)

        # Strip directory and ".jpg" suffix to get a bare sample id.
        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }


class PolypDataset(Dataset):
    """Polyp segmentation dataset.

    Supports CVC-300, CVC-ClinicDB, CVC-ColonDB, ETIS-LaribPolypDB, Kvasir, etc.

    Expected directory layout::

        data_path/
            images/
                xxx.png
            masks/
                xxx.png
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.out_size = args.out_size
        self.transform = transform
        self.transform_msk = transform_msk

        # Collect the image list (sorted so the split below is deterministic).
        img_dir = os.path.join(data_path, 'images')
        self.img_list = sorted([f for f in os.listdir(img_dir)
                                if f.endswith(('.png', '.jpg', '.jpeg'))])

        # Deterministic 80/20 train/test split over the sorted file list.
        split_idx = int(len(self.img_list) * 0.8)
        if mode == 'Training':
            self.img_list = self.img_list[:split_idx]
        else:
            self.img_list = self.img_list[split_idx:]

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        point_label = 1
        inout = 1
        # FIX: pt was previously undefined (NameError) whenever
        # self.prompt != 'click'; provide a neutral default.
        pt = np.array([0, 0])

        # Image and mask share the same file name in their respective folders.
        img_name = self.img_list[index]
        img_path = os.path.join(self.data_path, 'images', img_name)
        mask_path = os.path.join(self.data_path, 'masks', img_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        # Resize a copy of the mask to the model input size for click sampling;
        # the returned 'label' stays at the mask's original resolution unless
        # transform_msk resizes it.
        newsize = (self.img_size, self.img_size)
        mask_resized = mask.resize(newsize)

        if self.prompt == 'click':
            pt = random_click(np.array(mask_resized) / 255, point_label, inout)

        if self.transform:
            # Align random augmentation between image and mask transforms.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)

            if self.transform_msk:
                mask = self.transform_msk(mask)

        name = img_name.split('.')[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }


class CombinedPolypDataset(Dataset):
    """Concatenation of several polyp datasets for joint training.

    Scans ``data_path`` for the known benchmark sub-directories, wraps each
    existing one in a :class:`PolypDataset`, and exposes them as a single
    indexable dataset via cumulative-size bookkeeping.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.datasets = []

        # Supported benchmark directories.
        dataset_dirs = ['CVC-300', 'CVC-ClinicDB', 'CVC-ColonDB',
                        'ETIS-LaribPolypDB', 'Kvasir']

        for dataset_dir in dataset_dirs:
            full_path = os.path.join(data_path, dataset_dir)
            if os.path.exists(full_path):
                ds = PolypDataset(args, full_path, transform, transform_msk,
                                  mode, prompt, plane)
                if len(ds) > 0:
                    self.datasets.append(ds)
                    print(f"Loaded {dataset_dir}: {len(ds)} samples ({mode})")

        # cumulative_sizes[i] == total number of samples in datasets[0..i].
        self.cumulative_sizes = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative_sizes.append(total)

        print(f"Total {mode} samples: {total}")

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, index):
        # Binary search for the owning dataset (was a linear scan).
        ds_idx = bisect.bisect_right(self.cumulative_sizes, index)
        if ds_idx >= len(self.datasets):
            raise IndexError("Index out of range")
        offset = self.cumulative_sizes[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][index - offset]


class REFUGE(Dataset):
    """REFUGE fundus dataset with seven raters' cup/disc annotations.

    Each sample directory contains the raw image plus per-rater cup and disc
    segmentation masks; the click prompt targets the mean rater agreement.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.data_path = data_path
        # One sub-directory per sample under e.g. "Training-400".
        self.subfolders = [f.path for f in
                           os.scandir(os.path.join(data_path, mode + '-400'))
                           if f.is_dir()]
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.mask_size = args.out_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.subfolders)

    def __getitem__(self, index):
        inout = 1
        point_label = 1
        # FIX: pt_cup/pt_disc were previously undefined (NameError) whenever
        # self.prompt != 'click'; provide neutral defaults.
        pt_cup = np.array([0, 0])
        pt_disc = np.array([0, 0])

        """Get the images"""
        subfolder = self.subfolders[index]
        name = subfolder.split('/')[-1]

        # raw image and raters path (raters numbered 1..7)
        img_path = os.path.join(subfolder, name + '.jpg')
        multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png')
                                for i in range(1, 8)]
        multi_rater_disc_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png')
                                 for i in range(1, 8)]

        # raw image and raters images
        img = Image.open(img_path).convert('RGB')
        multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path]
        multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path]

        # resize raters images for generating initial point click
        newsize = (self.img_size, self.img_size)
        multi_rater_cup_np = [np.array(single_rater.resize(newsize))
                              for single_rater in multi_rater_cup]
        multi_rater_disc_np = [np.array(single_rater.resize(newsize))
                               for single_rater in multi_rater_disc]

        # first click is the target agreement among all raters
        if self.prompt == 'click':
            pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255,
                                  point_label, inout)
            pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255,
                                   point_label, inout)

        # NOTE(review): mask_cup/mask_disc (and thus 'label') are only defined
        # inside this branch, so a transform is effectively required — confirm
        # callers always pass one.
        # NOTE(review): unlike ISIC2016, the RNG state is restored only once at
        # the end, so each rater-mask transform consumes different random draws
        # than the image transform — if self.transform includes random spatial
        # augmentation, image and masks may be misaligned; verify intent.
        if self.transform:
            state = torch.get_rng_state()
            img = self.transform(img)

            # Binarize each transformed rater mask at 0.5, stack to (7, C, H, W).
            multi_rater_cup = [torch.as_tensor((self.transform(single_rater) > 0.5).float(),
                                               dtype=torch.float32)
                               for single_rater in multi_rater_cup]
            multi_rater_cup = torch.stack(multi_rater_cup, dim=0)
            # transform to mask size (out_size) for mask define; mean over the
            # 7 raters yields a soft agreement map.
            mask_cup = F.interpolate(multi_rater_cup,
                                     size=(self.mask_size, self.mask_size),
                                     mode='bilinear',
                                     align_corners=False).mean(dim=0)

            multi_rater_disc = [torch.as_tensor((self.transform(single_rater) > 0.5).float(),
                                                dtype=torch.float32)
                                for single_rater in multi_rater_disc]
            multi_rater_disc = torch.stack(multi_rater_disc, dim=0)
            mask_disc = F.interpolate(multi_rater_disc,
                                      size=(self.mask_size, self.mask_size),
                                      mode='bilinear',
                                      align_corners=False).mean(dim=0)
            torch.set_rng_state(state)

        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'multi_rater_cup': multi_rater_cup,
            'multi_rater_disc': multi_rater_disc,
            'mask_cup': mask_cup,
            'mask_disc': mask_disc,
            'label': mask_disc,
            'p_label': point_label,
            'pt_cup': pt_cup,
            'pt_disc': pt_disc,
            'pt': pt_disc,
            'selected_rater': torch.tensor(np.arange(7)),
            'image_meta_dict': image_meta_dict,
        }