You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.4 KiB

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image
class LungXrayDataset(Dataset):
    """Chest X-ray classification dataset (Covid / Normal / Viral Pneumonia).

    Expected directory layout::

        root_dir/
            train/<class_name>/<images>
            noisy_test/<class_name>/<images>

    where ``<class_name>`` is one of ``self.classes`` and images end in
    ``.png``, ``.jpg`` or ``.jpeg`` (any case). Each item is returned as
    ``(image, label)``, with ``label`` the index of the class in
    ``self.classes``. Training items get fresh random Gaussian noise on every
    access. Test items are read from ``noisy_test`` and are not noised further.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, is_train=True, transform=None,
                 noise_std=0.1, clamp_range=None):
        """
        Args:
            root_dir (str): Dataset root directory.
            is_train (bool): If True, read the ``train`` split and add Gaussian
                noise in ``__getitem__``. Otherwise read the ``noisy_test``
                split.
            transform (callable, optional): Preprocessing applied to the PIL
                image. Defaults to: resize to 256x256, grayscale, ToTensor, and
                Normalize(mean=0.5, std=0.5). Together these map pixels to
                [-1, 1].
            noise_std (float): Standard deviation of the training noise.
            clamp_range (tuple[float, float], optional): ``(min, max)`` bounds
                for the noisy image. If None, uses (-1.0, 1.0) with the default
                transform (its output range) and (0.0, 1.0) with a custom
                transform.

        Raises:
            FileNotFoundError: If a split/class directory does not exist.
        """
        self.root_dir = root_dir
        self.is_train = is_train
        self.transform = transform
        self.noise_std = noise_std
        self.classes = ['Covid', 'Normal', 'Viral Pneumonia']

        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                # Images are already loaded in mode 'L' by __getitem__;
                # this keeps the output single-channel regardless.
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),  # uint8 [0, 255] -> float [0, 1]
                # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ])
            # Match the clamp to the normalized value range.
            default_range = (-1.0, 1.0)
        else:
            # Output range of a custom transform is unknown here. Keep the
            # historical [0, 1] clamp.
            default_range = (0.0, 1.0)
        self.clamp_range = (
            tuple(clamp_range) if clamp_range is not None else default_range
        )

        self.data_info = self._collect_samples()

    def _collect_samples(self):
        """Return ``[{'path': str, 'label': int}, ...]`` for the selected split.

        Files are sorted so sample indices are reproducible (``os.listdir``
        order is arbitrary). Extensions are matched case-insensitively so
        files such as ``x.PNG`` are not silently dropped.
        """
        split_dir = 'train' if self.is_train else 'noisy_test'
        samples = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(self.root_dir, split_dir, class_name)
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    samples.append({
                        'path': os.path.join(class_path, img_name),
                        'label': class_idx,
                    })
        return samples

    def add_gaussian_noise(self, image):
        """Return ``image`` plus zero-mean Gaussian noise.

        The noise has standard deviation ``self.noise_std``. The result is
        clamped to ``self.clamp_range``.

        Args:
            image (torch.Tensor): Image tensor produced by ``self.transform``.

        Returns:
            torch.Tensor: Noisy tensor with the same shape as ``image``.
        """
        noise = torch.randn_like(image) * self.noise_std
        low, high = self.clamp_range
        return torch.clamp(image + noise, low, high)

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label)``."""
        info = self.data_info[idx]
        # Close the file handle as soon as the pixels are decoded.
        # convert() returns an in-memory copy, so it stays valid after the
        # file is closed.
        with Image.open(info['path']) as img:
            image = img.convert('L')  # load directly as single-channel grayscale
        image = self.transform(image)
        # Training-time augmentation only.
        if self.is_train:
            image = self.add_gaussian_noise(image)
        return image, info['label']
if __name__ == "__main__":
    # Smoke test: build the test split and print the first image tensor's shape.
    # With the default transform this should be torch.Size([1, 256, 256]).
    test_dataset = LungXrayDataset(root_dir="dataset", is_train=False)
    first_image, _ = test_dataset[0]
    print(first_image.shape)