You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

35 lines
1.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

class ModelObj: # 网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
self.ObjID = ObjID # 图元号
self.ObjType = ObjType # 图元类别
self.ObjLable = ObjLable # 对象标签
self.ParaString = ParaString # 参数字符串
self.ObjX = ObjX # 对象位置x坐标
self.ObjY = ObjY # 对象位置y坐标
class Data_Class(ModelObj): # 数据集网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
super().__init__(ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY)
# self.LoadData = self.load_data # 基本操作函数 -------------------------
self.SetDataPara = self.SetLoadData # 参数设置函数
def SetLoadData(self):# 定义加载数据集的参数SetLoadData()
# 设置数据集路径信息
# 训练集文件夹的位置
train_imgPath = input("请输入训练集文件夹的位置:") # 'data_classification/train/'
# 测试集文件夹的位置
test_imgPath = input("请输入测试集文件夹的位置:") # 'data_classification/test/'
img_width = int(input("请输入图片宽度:")) # 48
img_height = int(input("请输入图片高度:")) # 48
# 设置每批次读入图片的数量
batch_size = int(input("请输入每批次读入图片的数量:")) # 批次大小 32
# 返回DataPara参数这里用一个字典来存储
DataPara = {"train_imgPath": train_imgPath,
"test_imgPath": test_imgPath,
"img_width": img_width,
"img_height": img_height,
"batch_size": batch_size}
return DataPara
if __name__ == '__main__':
DataSet = Data_Class("DataSet1", 1, "数据集1", [], 120, 330)
# setload_data()函数,获取加载数据集的参数
DataPara = DataSet.SetDataPara()