import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from utils import random_click
import random
import bisect
from monai.transforms import LoadImaged, Randomizable, LoadImage


class ISIC2016(Dataset):
    """ISIC 2016 (Part 3B) skin-lesion segmentation dataset.

    Reads ``ISBI2016_ISIC_Part3B_<mode>_GroundTruth.csv`` from ``data_path``;
    column 1 holds image paths and column 2 holds mask paths, both relative
    to ``data_path``.

    Args:
        args: namespace providing ``image_size`` (side length masks are
            resized to before the click prompt is sampled).
        data_path: dataset root directory.
        transform: optional transform applied to the RGB image.
        transform_msk: optional transform applied to the grayscale mask.
        mode: 'Training' or 'Test' — selects which CSV is loaded.
        prompt: 'click' samples a point prompt via ``random_click``.
        plane: unused; kept for interface compatibility.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        df = pd.read_csv(
            os.path.join(data_path, 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv'),
            encoding='gbk')
        self.name_list = df.iloc[:, 1].tolist()
        self.label_list = df.iloc[:, 2].tolist()
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size

        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Return a dict with 'image', 'label', 'p_label', 'pt', 'image_meta_dict'."""
        inout = 1
        point_label = 1

        # Resolve image / mask paths for this sample.
        name = self.name_list[index]
        img_path = os.path.join(self.data_path, name)

        mask_name = self.label_list[index]
        msk_path = os.path.join(self.data_path, mask_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(msk_path).convert('L')

        # The mask is resized before sampling the click so the point lies in
        # model-input coordinates.
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if self.prompt == 'click':
            pt = random_click(np.array(mask) / 255, point_label, inout)
        else:
            # BUGFIX: `pt` was unbound for any prompt other than 'click',
            # raising UnboundLocalError at the return below. Use a neutral
            # placeholder point instead.
            pt = np.zeros(2, dtype=np.int64)

        if self.transform:
            # Save/restore the RNG state so a random `transform_msk` draws the
            # same random numbers as `transform` (keeps img/mask aligned).
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)

        if self.transform_msk:
            mask = self.transform_msk(mask)

        # Strip directory and '.jpg' extension for the metadata record.
        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }


class PolypDataset(Dataset):
    """Polyp segmentation dataset (CVC-300, CVC-ClinicDB, CVC-ColonDB,
    ETIS-LaribPolypDB, Kvasir, ...).

    Expected layout::

        data_path/
            images/
                xxx.png
            masks/
                xxx.png          # same filename as the image

    The file list is split 80% / 20% into training / test by position.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.out_size = args.out_size

        self.transform = transform
        self.transform_msk = transform_msk

        # Collect image filenames (sorted for a deterministic split).
        img_dir = os.path.join(data_path, 'images')
        self.img_list = sorted(
            [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])

        # Positional 80/20 train/test split.
        split_idx = int(len(self.img_list) * 0.8)
        if mode == 'Training':
            self.img_list = self.img_list[:split_idx]
        else:
            self.img_list = self.img_list[split_idx:]

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        """Return a dict with 'image', 'label', 'p_label', 'pt', 'image_meta_dict'."""
        point_label = 1
        inout = 1

        # Mask shares the image's filename under masks/.
        img_name = self.img_list[index]
        img_path = os.path.join(self.data_path, 'images', img_name)
        mask_path = os.path.join(self.data_path, 'masks', img_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        # Resize a copy of the mask only for click sampling; the returned
        # label is produced from the original-size mask via `transform_msk`.
        newsize = (self.img_size, self.img_size)
        mask_resized = mask.resize(newsize)

        if self.prompt == 'click':
            pt = random_click(np.array(mask_resized) / 255, point_label, inout)
        else:
            # BUGFIX: `pt` was unbound for any prompt other than 'click',
            # raising UnboundLocalError at the return below.
            pt = np.zeros(2, dtype=np.int64)

        if self.transform:
            # Preserve RNG state so image/mask transforms stay in sync.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)

        if self.transform_msk:
            mask = self.transform_msk(mask)

        name = img_name.split('.')[0]
        image_meta_dict = {'filename_or_obj': name}

        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }


class CombinedPolypDataset(Dataset):
    """Concatenation of several polyp datasets for joint training.

    Scans ``data_path`` for the known benchmark sub-directories and exposes
    them as one index space; sample ``i`` is routed to the sub-dataset whose
    cumulative-size interval contains ``i``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.datasets = []

        # Supported benchmark sub-directories.
        dataset_dirs = ['CVC-300', 'CVC-ClinicDB', 'CVC-ColonDB',
                        'ETIS-LaribPolypDB', 'Kvasir']

        for dataset_dir in dataset_dirs:
            full_path = os.path.join(data_path, dataset_dir)
            if os.path.exists(full_path):
                ds = PolypDataset(args, full_path, transform, transform_msk,
                                  mode, prompt, plane)
                if len(ds) > 0:
                    self.datasets.append(ds)
                    print(f"Loaded {dataset_dir}: {len(ds)} samples ({mode})")

        # cumulative_sizes[i] = total samples in datasets[0..i].
        self.cumulative_sizes = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative_sizes.append(total)

        print(f"Total {mode} samples: {total}")

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, index):
        # bisect_right finds the first dataset whose cumulative size exceeds
        # `index` — same routing as the original linear scan, O(log n).
        if not (0 <= index < len(self)):
            raise IndexError("Index out of range")
        i = bisect.bisect_right(self.cumulative_sizes, index)
        offset = self.cumulative_sizes[i - 1] if i > 0 else 0
        return self.datasets[i][index - offset]


class REFUGE(Dataset):
    """REFUGE fundus dataset with 7 raters per sample.

    Each sample folder ``<mode>-400/<name>/`` contains ``<name>.jpg`` plus
    ``<name>_seg_cup_1..7.png`` and ``<name>_seg_disc_1..7.png``. The label
    is the mean of the binarized rater masks, resized to ``args.out_size``.

    NOTE: a ``transform`` is required — the rater tensors and soft masks are
    produced inside the transform branch.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.data_path = data_path
        self.subfolders = [f.path for f in
                           os.scandir(os.path.join(data_path, mode + '-400'))
                           if f.is_dir()]
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.mask_size = args.out_size

        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.subfolders)

    def __getitem__(self, index):
        """Return image, per-rater masks, soft cup/disc masks and click prompts."""
        inout = 1
        point_label = 1

        subfolder = self.subfolders[index]
        # BUGFIX: was `subfolder.split('/')[-1]`, which breaks on Windows
        # path separators; basename is equivalent on POSIX paths.
        name = os.path.basename(subfolder)

        # Raw image and the seven rater annotations per structure.
        img_path = os.path.join(subfolder, name + '.jpg')
        multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png')
                                for i in range(1, 8)]
        multi_rater_disc_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png')
                                 for i in range(1, 8)]

        img = Image.open(img_path).convert('RGB')
        multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path]
        multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path]

        # Resize rater masks for sampling the initial point click.
        newsize = (self.img_size, self.img_size)
        multi_rater_cup_np = [np.array(r.resize(newsize)) for r in multi_rater_cup]
        multi_rater_disc_np = [np.array(r.resize(newsize)) for r in multi_rater_disc]

        # The first click targets the agreement (mean) across all raters.
        if self.prompt == 'click':
            pt_cup = random_click(np.mean(np.stack(multi_rater_cup_np), axis=0) / 255,
                                  point_label, inout)
            pt_disc = random_click(np.mean(np.stack(multi_rater_disc_np), axis=0) / 255,
                                   point_label, inout)
        else:
            # BUGFIX: pt_cup/pt_disc were unbound for non-'click' prompts,
            # raising UnboundLocalError at the return below.
            pt_cup = np.zeros(2, dtype=np.int64)
            pt_disc = np.zeros(2, dtype=np.int64)

        if self.transform is None:
            # BUGFIX: the original only defined multi_rater_cup/disc tensors
            # and mask_cup/mask_disc inside this branch, so transform=None
            # crashed with NameError at the return. Fail fast with a clear
            # message instead.
            raise ValueError("REFUGE requires a `transform` to build rater "
                             "tensors and soft masks")

        # Share one RNG state so every rater mask receives the same random
        # augmentation as the image.
        state = torch.get_rng_state()
        img = self.transform(img)
        multi_rater_cup = [torch.as_tensor((self.transform(r) > 0.5).float(),
                                           dtype=torch.float32)
                           for r in multi_rater_cup]
        multi_rater_cup = torch.stack(multi_rater_cup, dim=0)
        # Soft (mean-of-raters) mask at the model's output resolution.
        mask_cup = F.interpolate(multi_rater_cup,
                                 size=(self.mask_size, self.mask_size),
                                 mode='bilinear', align_corners=False).mean(dim=0)

        multi_rater_disc = [torch.as_tensor((self.transform(r) > 0.5).float(),
                                            dtype=torch.float32)
                            for r in multi_rater_disc]
        multi_rater_disc = torch.stack(multi_rater_disc, dim=0)
        mask_disc = F.interpolate(multi_rater_disc,
                                  size=(self.mask_size, self.mask_size),
                                  mode='bilinear', align_corners=False).mean(dim=0)
        torch.set_rng_state(state)

        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'multi_rater_cup': multi_rater_cup,
            'multi_rater_disc': multi_rater_disc,
            'mask_cup': mask_cup,
            'mask_disc': mask_disc,
            'label': mask_disc,
            'p_label': point_label,
            'pt_cup': pt_cup,
            'pt_disc': pt_disc,
            'pt': pt_disc,
            'selected_rater': torch.tensor(np.arange(7)),
            'image_meta_dict': image_meta_dict,
        }