# Dataset loaders for medical image segmentation:
# ISIC2016 skin lesions, polyp datasets (CVC/ETIS/Kvasir), and REFUGE fundus images.

import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from utils import random_click
import random
from monai.transforms import LoadImaged, Randomizable,LoadImage
class ISIC2016(Dataset):
    """ISIC 2016 skin-lesion segmentation dataset (ISBI 2016 Part 3B).

    Expects ``data_path`` to contain the ground-truth CSV
    ``ISBI2016_ISIC_Part3B_<mode>_GroundTruth.csv`` whose column 1 holds the
    image path and column 2 the mask path, both relative to ``data_path``.

    Args:
        args: namespace with at least ``image_size`` (int, side of the square
            resize applied to the mask before the click prompt is sampled).
        data_path: dataset root directory.
        transform: optional transform applied to the RGB image.
        transform_msk: optional transform applied to the mask (only used when
            ``transform`` is also set, matching the original behavior).
        mode: 'Training' or 'Test' — selects the CSV file.
        prompt: 'click' samples a point prompt from the mask via
            ``random_click``; any other value returns a placeholder point.
        plane: unused, kept for interface compatibility.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        df = pd.read_csv(
            os.path.join(data_path, 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv'),
            encoding='gbk')
        # Column 1: image file, column 2: mask file (relative paths).
        self.name_list = df.iloc[:, 1].tolist()
        self.label_list = df.iloc[:, 2].tolist()
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        inout = 1
        point_label = 1

        # Resolve image / mask paths from the CSV columns.
        name = self.name_list[index]
        img_path = os.path.join(self.data_path, name)
        mask_name = self.label_list[index]
        msk_path = os.path.join(self.data_path, mask_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(msk_path).convert('L')

        # Resize the mask so the click prompt is sampled in model coordinates.
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        # BUGFIX: 'pt' used to be assigned only when prompt == 'click',
        # producing a NameError for any other prompt mode. Provide a
        # placeholder point so the sample dict is always well-formed.
        pt = np.array([0, 0])
        if self.prompt == 'click':
            pt = random_click(np.array(mask) / 255, point_label, inout)

        if self.transform:
            # Save/restore the RNG state so image and mask receive the same
            # random augmentation parameters.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)
            if self.transform_msk:
                mask = self.transform_msk(mask)

        # Strip directory and '.jpg' suffix to get a bare sample identifier.
        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }
class PolypDataset(Dataset):
    """Polyp segmentation dataset (CVC-300, CVC-ClinicDB, CVC-ColonDB,
    ETIS-LaribPolypDB, Kvasir, ...).

    Expected directory layout::

        data_path/
            images/
                xxx.png
            masks/
                xxx.png

    Samples are split deterministically by sorted filename: the first 80%
    form the training set, the remaining 20% the test set.

    Args:
        args: namespace with ``image_size`` and ``out_size`` attributes.
        data_path: root directory containing ``images/`` and ``masks/``.
        transform: optional transform applied to the RGB image.
        transform_msk: optional transform applied to the mask (only applied
            when ``transform`` is set, matching the original behavior).
        mode: 'Training' selects the first 80% of files, anything else the
            remaining 20%.
        prompt: 'click' samples a point prompt from the resized mask.
        plane: unused, kept for interface compatibility.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.out_size = args.out_size
        self.transform = transform
        self.transform_msk = transform_msk

        # Collect image filenames (masks are assumed to share the same name).
        img_dir = os.path.join(data_path, 'images')
        self.img_list = sorted(
            [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])

        # Deterministic 80/20 train/test split over the sorted file list.
        split_idx = int(len(self.img_list) * 0.8)
        if mode == 'Training':
            self.img_list = self.img_list[:split_idx]
        else:
            self.img_list = self.img_list[split_idx:]

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        point_label = 1
        inout = 1

        # Image and mask share the same filename under different folders.
        img_name = self.img_list[index]
        img_path = os.path.join(self.data_path, 'images', img_name)
        mask_path = os.path.join(self.data_path, 'masks', img_name)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        # Resize a copy of the mask so the click prompt is sampled in model
        # coordinates; the returned 'label' keeps the original resolution
        # unless transform_msk resizes it.
        newsize = (self.img_size, self.img_size)
        mask_resized = mask.resize(newsize)

        # BUGFIX: 'pt' used to be assigned only when prompt == 'click',
        # producing a NameError for any other prompt mode.
        pt = np.array([0, 0])
        if self.prompt == 'click':
            pt = random_click(np.array(mask_resized) / 255, point_label, inout)

        if self.transform:
            # Same RNG state for image and mask so paired random
            # augmentations stay aligned.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)
            if self.transform_msk:
                mask = self.transform_msk(mask)

        name = img_name.split('.')[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }
class CombinedPolypDataset(Dataset):
    """Concatenation of several polyp datasets for joint training.

    Looks for the well-known polyp dataset directories under ``data_path``
    and concatenates every one that exists and is non-empty. Indexing maps a
    global index to the owning sub-dataset via cumulative sizes.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.datasets = []
        # Supported sub-dataset directories.
        dataset_dirs = ['CVC-300', 'CVC-ClinicDB', 'CVC-ColonDB',
                        'ETIS-LaribPolypDB', 'Kvasir']
        for dataset_dir in dataset_dirs:
            full_path = os.path.join(data_path, dataset_dir)
            if os.path.exists(full_path):
                ds = PolypDataset(args, full_path, transform, transform_msk,
                                  mode, prompt, plane)
                if len(ds) > 0:
                    self.datasets.append(ds)
                    print(f"Loaded {dataset_dir}: {len(ds)} samples ({mode})")

        # Running total of sample counts; cumulative_sizes[i] is the number
        # of samples in datasets[0..i] combined.
        self.cumulative_sizes = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative_sizes.append(total)
        print(f"Total {mode} samples: {total}")

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, index):
        # BUGFIX: negative indices previously passed the 'index < cumsize'
        # test and were resolved inside the FIRST sub-dataset instead of
        # counting from the end. Normalize them first.
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError("Index out of range")

        # Locate the owning sub-dataset and translate to a local index.
        prev = 0
        for i, cumsize in enumerate(self.cumulative_sizes):
            if index < cumsize:
                return self.datasets[i][index - prev]
            prev = cumsize
        raise IndexError("Index out of range")
class REFUGE(Dataset):
    """REFUGE fundus dataset with 7 raters' optic cup/disc annotations.

    Expects one folder per case under ``data_path/<mode>-400/``, each
    containing ``<name>.jpg`` plus ``<name>_seg_cup_1..7.png`` and
    ``<name>_seg_disc_1..7.png`` rater masks.

    Args:
        args: namespace with ``image_size`` (prompt/input resolution) and
            ``out_size`` (mask resolution) attributes.
        data_path: dataset root directory.
        transform: optional transform applied to the image and each rater
            mask (expected to yield tensors — TODO confirm against caller).
        transform_msk: unused in __getitem__, kept for interface parity.
        mode: 'Training', 'Test', ... — selects the '<mode>-400' folder.
        prompt: 'click' samples point prompts from the mean rater masks.
        plane: unused, kept for interface compatibility.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        self.data_path = data_path
        # One sample per case folder under e.g. 'Training-400/'.
        self.subfolders = [f.path for f in os.scandir(os.path.join(data_path, mode + '-400'))
                           if f.is_dir()]
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.mask_size = args.out_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.subfolders)

    def __getitem__(self, index):
        inout = 1
        point_label = 1

        subfolder = self.subfolders[index]
        # Portability fix: basename instead of splitting on '/', which broke
        # on Windows path separators.
        name = os.path.basename(subfolder)

        # Raw image and the 7 raters' cup/disc mask paths.
        img_path = os.path.join(subfolder, name + '.jpg')
        multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png')
                                for i in range(1, 8)]
        multi_rater_disc_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png')
                                 for i in range(1, 8)]

        img = Image.open(img_path).convert('RGB')
        multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path]
        multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path]

        # Resize rater masks to model resolution for click-prompt sampling.
        newsize = (self.img_size, self.img_size)
        multi_rater_cup_np = [np.array(r.resize(newsize)) for r in multi_rater_cup]
        multi_rater_disc_np = [np.array(r.resize(newsize)) for r in multi_rater_disc]

        # BUGFIX: pt_cup/pt_disc used to be assigned only when
        # prompt == 'click', raising NameError otherwise.
        pt_cup = np.array([0, 0])
        pt_disc = np.array([0, 0])
        if self.prompt == 'click':
            # First click targets the agreement (mean) among all raters.
            pt_cup = random_click(np.mean(np.stack(multi_rater_cup_np), axis=0) / 255,
                                  point_label, inout)
            pt_disc = random_click(np.mean(np.stack(multi_rater_disc_np), axis=0) / 255,
                                   point_label, inout)

        if self.transform:
            # Same RNG state for the image and every rater mask so paired
            # random augmentations stay aligned.
            state = torch.get_rng_state()
            img = self.transform(img)

            multi_rater_cup = [torch.as_tensor((self.transform(r) > 0.5).float(),
                                               dtype=torch.float32)
                               for r in multi_rater_cup]
            multi_rater_cup = torch.stack(multi_rater_cup, dim=0)
            # Downsample to out_size and average over raters -> soft mask.
            mask_cup = F.interpolate(multi_rater_cup,
                                     size=(self.mask_size, self.mask_size),
                                     mode='bilinear', align_corners=False).mean(dim=0)

            multi_rater_disc = [torch.as_tensor((self.transform(r) > 0.5).float(),
                                                dtype=torch.float32)
                                for r in multi_rater_disc]
            multi_rater_disc = torch.stack(multi_rater_disc, dim=0)
            mask_disc = F.interpolate(multi_rater_disc,
                                      size=(self.mask_size, self.mask_size),
                                      mode='bilinear', align_corners=False).mean(dim=0)
            torch.set_rng_state(state)
        else:
            # BUGFIX: without a transform, multi_rater_* stayed as PIL images
            # and mask_cup/mask_disc were never defined, so the return below
            # crashed with NameError. Build equivalent (7, 1, H, W) binary
            # tensors from the resized rater arrays instead.
            multi_rater_cup = (torch.from_numpy(
                np.stack(multi_rater_cup_np).astype(np.float32) / 255.0).unsqueeze(1) > 0.5).float()
            multi_rater_disc = (torch.from_numpy(
                np.stack(multi_rater_disc_np).astype(np.float32) / 255.0).unsqueeze(1) > 0.5).float()
            mask_cup = F.interpolate(multi_rater_cup,
                                     size=(self.mask_size, self.mask_size),
                                     mode='bilinear', align_corners=False).mean(dim=0)
            mask_disc = F.interpolate(multi_rater_disc,
                                      size=(self.mask_size, self.mask_size),
                                      mode='bilinear', align_corners=False).mean(dim=0)

        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'multi_rater_cup': multi_rater_cup,
            'multi_rater_disc': multi_rater_disc,
            'mask_cup': mask_cup,
            'mask_disc': mask_disc,
            'label': mask_disc,          # disc mask doubles as the default label
            'p_label': point_label,
            'pt_cup': pt_cup,
            'pt_disc': pt_disc,
            'pt': pt_disc,               # disc click doubles as the default prompt
            'selected_rater': torch.tensor(np.arange(7)),
            'image_meta_dict': image_meta_dict,
        }