|
|
import os
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
import tkinter as tk
|
|
|
from tkinter import filedialog
|
|
|
from read_data import *
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
class ModelObj: # 网络对象
|
|
|
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
|
|
|
self.ObjID = ObjID # 图元号
|
|
|
self.ObjType = ObjType # 图元类别
|
|
|
self.ObjLable = ObjLable # 对象标签
|
|
|
self.ParaString = ParaString # 参数字符串
|
|
|
self.ObjX = ObjX # 对象位置x坐标
|
|
|
self.ObjY = ObjY # 对象位置y坐标
|
|
|
def image_to_array(path, height, width):
|
|
|
img = Image.open(path).convert("L") # 转换为灰度模式
|
|
|
img = img.resize((height, width)) # 将图片缩放为height*width的固定大小
|
|
|
data = np.array(img) # 将图片转换为数组格式
|
|
|
data = data / 255.0 # 将数组中的数值归一化,除以255
|
|
|
return data
|
|
|
class Data_Class(ModelObj): # 数据集网络对象
|
|
|
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
|
|
|
super().__init__(ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY)
|
|
|
self.LoadData = self.load_data # 基本操作函数
|
|
|
self.SetDataPara = self.SetLoadData # 参数设置函数
|
|
|
def load_data(self, DataPara): # 定义加载数据集load_data()
|
|
|
global SubFolders
|
|
|
listimages = [] # 存储图片的列表
|
|
|
list_path = []
|
|
|
# read_folders方法读取文件夹中子文件夹中数据集
|
|
|
SubFolders, train = read_folders(DataPara["train_imgPath"])
|
|
|
list_path.append(train)
|
|
|
_, path_list = read_folders(DataPara["test_imgPath"])
|
|
|
list_path.append(path_list)
|
|
|
for data_path in list_path:
|
|
|
images = [] # 存储图片的列表
|
|
|
for path in data_path:
|
|
|
img = image_to_array(path, DataPara["img_width"],
|
|
|
DataPara["img_height"]) # 读取图片数据
|
|
|
img = img.T # 转置,图像的行和列将互换位置。
|
|
|
images.append(img) # 将图片数组添加到训练集图片列表中
|
|
|
listimages.append(np.array(images)) # 返回转换后的数组
|
|
|
return listimages[0], listimages[1]
|
|
|
def SetLoadData(self):# 定义加载数据集的参数SetLoadData()
|
|
|
# 设置数据集路径信息
|
|
|
# 训练集文件夹的位置
|
|
|
train_imgPath = input("请输入训练集文件夹的位置:") # 'data_classification/train/'
|
|
|
# 测试集文件夹的位置
|
|
|
test_imgPath = input("请输入测试集文件夹的位置:") # 'data_classification/test/'
|
|
|
img_width = int(input("请输入图片宽度:")) # 48
|
|
|
img_height = int(input("请输入图片高度:")) # 48
|
|
|
# 设置每批次读入图片的数量
|
|
|
batch_size = int(input("请输入每批次读入图片的数量:")) # 批次大小 32
|
|
|
# 返回DataPara参数,这里用一个字典来存储
|
|
|
DataPara = {"train_imgPath": train_imgPath,
|
|
|
"test_imgPath": test_imgPath,
|
|
|
"img_width": img_width,
|
|
|
"img_height": img_height,
|
|
|
"batch_size": batch_size}
|
|
|
return DataPara
|
|
|
if __name__ == '__main__':
|
|
|
DataSet = Data_Class("DataSet1", 1, "数据集1", [], 120, 330)
|
|
|
# setload_data()函数,获取加载数据集的参数
|
|
|
DataPara = DataSet.SetDataPara()
|
|
|
train_images, test_images = DataSet.LoadData(DataPara)
|