You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

70 lines
3.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import cv2
import numpy as np
import tkinter as tk
from tkinter import filedialog
from read_data import *
from PIL import Image
class ModelObj: # 网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
self.ObjID = ObjID # 图元号
self.ObjType = ObjType # 图元类别
self.ObjLable = ObjLable # 对象标签
self.ParaString = ParaString # 参数字符串
self.ObjX = ObjX # 对象位置x坐标
self.ObjY = ObjY # 对象位置y坐标
def image_to_array(path, height, width):
img = Image.open(path).convert("L") # 转换为灰度模式
img = img.resize((height, width)) # 将图片缩放为height*width的固定大小
data = np.array(img) # 将图片转换为数组格式
data = data / 255.0 # 将数组中的数值归一化除以255
return data
class Data_Class(ModelObj): # 数据集网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
super().__init__(ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY)
self.LoadData = self.load_data # 基本操作函数
self.SetDataPara = self.SetLoadData # 参数设置函数
def load_data(self, DataPara): # 定义加载数据集load_data()
global SubFolders
listimages = [] # 存储图片的列表
list_path = []
# read_folders方法读取文件夹中子文件夹中数据集
SubFolders, train = read_folders(DataPara["train_imgPath"])
list_path.append(train)
_, path_list = read_folders(DataPara["test_imgPath"])
list_path.append(path_list)
for data_path in list_path:
images = [] # 存储图片的列表
for path in data_path:
img = image_to_array(path, DataPara["img_width"],
DataPara["img_height"]) # 读取图片数据
img = img.T # 转置,图像的行和列将互换位置。
images.append(img) # 将图片数组添加到训练集图片列表中
listimages.append(np.array(images)) # 返回转换后的数组
return listimages[0], listimages[1]
def SetLoadData(self):# 定义加载数据集的参数SetLoadData()
# 设置数据集路径信息
# 训练集文件夹的位置
train_imgPath = input("请输入训练集文件夹的位置:") # 'data_classification/train/'
# 测试集文件夹的位置
test_imgPath = input("请输入测试集文件夹的位置:") # 'data_classification/test/'
img_width = int(input("请输入图片宽度:")) # 48
img_height = int(input("请输入图片高度:")) # 48
# 设置每批次读入图片的数量
batch_size = int(input("请输入每批次读入图片的数量:")) # 批次大小 32
# 返回DataPara参数这里用一个字典来存储
DataPara = {"train_imgPath": train_imgPath,
"test_imgPath": test_imgPath,
"img_width": img_width,
"img_height": img_height,
"batch_size": batch_size}
return DataPara
if __name__ == '__main__':
DataSet = Data_Class("DataSet1", 1, "数据集1", [], 120, 330)
# setload_data()函数,获取加载数据集的参数
DataPara = DataSet.SetDataPara()
train_images, test_images = DataSet.LoadData(DataPara)