You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

35 lines
1.9 KiB

8 months ago
class ModelObj: # 网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
self.ObjID = ObjID # 图元号
self.ObjType = ObjType # 图元类别
self.ObjLable = ObjLable # 对象标签
self.ParaString = ParaString # 参数字符串
self.ObjX = ObjX # 对象位置x坐标
self.ObjY = ObjY # 对象位置y坐标
class Data_Class(ModelObj): # 数据集网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
super().__init__(ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY)
# self.LoadData = self.load_data # 基本操作函数 -------------------------
self.SetDataPara = self.SetLoadData # 参数设置函数
def SetLoadData(self):# 定义加载数据集的参数SetLoadData()
# 设置数据集路径信息
# 训练集文件夹的位置
train_imgPath = input("请输入训练集文件夹的位置:") # 'data_classification/train/'
# 测试集文件夹的位置
test_imgPath = input("请输入测试集文件夹的位置:") # 'data_classification/test/'
img_width = int(input("请输入图片宽度:")) # 48
img_height = int(input("请输入图片高度:")) # 48
# 设置每批次读入图片的数量
batch_size = int(input("请输入每批次读入图片的数量:")) # 批次大小 32
# 返回DataPara参数这里用一个字典来存储
DataPara = {"train_imgPath": train_imgPath,
"test_imgPath": test_imgPath,
"img_width": img_width,
"img_height": img_height,
"batch_size": batch_size}
return DataPara
if __name__ == '__main__':
DataSet = Data_Class("DataSet1", 1, "数据集1", [], 120, 330)
# setload_data()函数,获取加载数据集的参数
DataPara = DataSet.SetDataPara()