You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

57 lines
2.1 KiB

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random
import matplotlib.pyplot as plt
class ISBI_Loader(Dataset):
    """Dataset of grayscale image / label pairs read from disk with OpenCV.

    Every ``*.png`` file directly under ``data_path`` is one sample. The label
    path of a sample is its image path with the substring ``'normal_jpg'``
    replaced by ``'jpg'``.
    """

    def __init__(self, data_path):
        """Collect the paths of all PNG images directly under ``data_path``.

        Args:
            data_path: Directory that holds the training images.
        """
        self.data_path = data_path
        # glob returns matches in arbitrary order; sort so that a given index
        # refers to the same file on every run.
        self.imgs_path = sorted(glob.glob(os.path.join(data_path, '*.png')))

    def augment(self, image, flipCode):
        """Flip ``image`` with ``cv2.flip``.

        Args:
            image: Array in OpenCV layout (rows, cols[, channels]).
            flipCode: 1 flips horizontally, 0 flips vertically, -1 does both.

        Returns:
            The flipped array.
        """
        return cv2.flip(image, flipCode)

    def __getitem__(self, index):
        """Load the sample at ``index``, convert it and randomly flip it.

        Args:
            index: Position of the sample in ``self.imgs_path``.

        Returns:
            A tuple ``(image, label)``; both are reshaped to ``(1, H, W)``.
            ``label`` is divided by 255 when its maximum value exceeds 1.

        Raises:
            FileNotFoundError: If the image or its label cannot be read.
        """
        image_path = self.imgs_path[index]
        # The label path is derived from the image path by substring replace.
        label_path = image_path.replace('normal_jpg', 'jpg')

        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError('cannot read image: {}'.format(image_path))
        if label is None:
            raise FileNotFoundError('cannot read label: {}'.format(label_path))

        # Collapse both to a single channel.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)

        # Rescale labels stored as 0/255 down to 0/1.
        if label.max() > 1:
            label = label / 255

        # Pick one of the three flips at random; 2 means "leave unchanged".
        # The flip must run while the arrays are still 2-D (H, W): cv2.flip
        # reads a 3-D array as (rows, cols, channels), so flipping a
        # (1, H, W) array would act on the wrong axes.
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)

        # Add the leading channel axis: (H, W) -> (1, H, W).
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        return image, label

    def __len__(self):
        """Return the number of image paths found under ``data_path``."""
        return len(self.imgs_path)
if __name__ == "__main__":
    # Smoke test: build the dataset from a local folder, report its size and
    # print the shape of the image tensor of every shuffled batch of two.
    dataset_root = r'F:/unet_program/H/normal_jpg/'
    samples = ISBI_Loader(dataset_root)
    print("数据个数:", len(samples))

    batches = torch.utils.data.DataLoader(
        dataset=samples,
        batch_size=2,
        shuffle=True,
    )
    for batch_images, _batch_labels in batches:
        print(batch_images.shape)