You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.4 KiB
66 lines
2.4 KiB
import torch
|
|
from torch.utils.data import Dataset
|
|
from torchvision import transforms
|
|
import os
|
|
from PIL import Image
|
|
|
|
class LungXrayDataset(Dataset):
    """Chest X-ray classification dataset (Covid / Normal / Viral Pneumonia).

    Images are read from ``<root_dir>/train/<class_name>/`` when ``is_train``
    is True and from ``<root_dir>/noisy_test/<class_name>/`` otherwise. Each
    item is ``(image, label)`` where ``label`` is the index of the image's
    class name in ``self.classes``.

    For the training split, Gaussian noise is added to every image after the
    transform has been applied (fresh noise on every access).
    """

    # File extensions (compared case-insensitively) that are treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, is_train=True, transform=None, *,
                 noise_std=0.1, clamp_range=None):
        """
        Args:
            root_dir (str): Dataset root directory.
            is_train (bool): If True, load the ``train`` split and add noise
                to each sample; otherwise load the ``noisy_test`` split.
            transform (callable, optional): Preprocessing applied to the
                grayscale PIL image. Defaults to Resize(256x256) -> Grayscale
                -> ToTensor -> Normalize(mean=0.5, std=0.5), whose output lies
                in [-1, 1].
            noise_std (float): Standard deviation of the Gaussian noise added
                to training images.
            clamp_range (tuple[float, float], optional): ``(low, high)``
                bounds applied to the noisy training image. Defaults to
                (-1.0, 1.0) for the built-in transform and (0.0, 1.0) when a
                custom ``transform`` is supplied.

        Raises:
            FileNotFoundError: If a class directory for the split is missing.
        """
        self.root_dir = root_dir
        self.is_train = is_train
        self.transform = transform
        self.classes = ['Covid', 'Normal', 'Viral Pneumonia']
        self.noise_std = noise_std

        # Default preprocessing pipeline.
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                # Images are already converted to 'L' in __getitem__; this
                # guarantees a single channel for the pipeline regardless.
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),                        # -> [0, 1]
                transforms.Normalize(mean=[0.5], std=[0.5]),  # -> [-1, 1]
            ])
            default_range = (-1.0, 1.0)
        else:
            # The value range of a caller-supplied transform is unknown;
            # keep the historical [0, 1] clamp unless told otherwise.
            default_range = (0.0, 1.0)
        self.clamp_range = (default_range if clamp_range is None
                            else tuple(clamp_range))

        # Collect (path, label) records for every image in the split.
        split_dir = 'train' if is_train else 'noisy_test'
        self.data_info = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(root_dir, split_dir, class_name)
            # os.listdir order is arbitrary; sorting makes sample indices
            # reproducible across runs and filesystems.
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    self.data_info.append({
                        'path': os.path.join(class_path, img_name),
                        'label': class_idx,
                    })

    def add_gaussian_noise(self, image):
        """Return ``image`` plus zero-mean Gaussian noise, clamped to ``self.clamp_range``.

        Args:
            image (torch.Tensor): Transformed image tensor.

        Returns:
            torch.Tensor: A new tensor of the same shape; ``image`` itself is
            not modified.
        """
        noise = torch.randn_like(image) * self.noise_std
        low, high = self.clamp_range
        # Clamp to the range the transform actually produces. A fixed [0, 1]
        # clamp would wipe out every pixel darker than mid-grey once
        # Normalize(0.5, 0.5) has mapped the data into [-1, 1].
        return torch.clamp(image + noise, low, high)

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load, transform and (for the training split) add noise to sample ``idx``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the output of
            ``self.transform`` (with noise added when ``is_train``) and
            ``label`` is the integer class index.
        """
        record = self.data_info[idx]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(record['path']) as img:
            image = img.convert('L')  # load directly as grayscale
        image = self.transform(image)

        # Training samples get fresh random noise on every access.
        if self.is_train:
            image = self.add_gaussian_noise(image)
        return image, record['label']
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the noisy test split (is_train=False) and inspect
    # the first sample.
    test_dataset = LungXrayDataset(root_dir="dataset", is_train=False)
    print(f"{len(test_dataset)} images found")
    if len(test_dataset) == 0:
        raise SystemExit("No images found under dataset/noisy_test/")
    image, label = test_dataset[0]
    print(image.shape, label)