From 5e826454078ee49bda175cc7176a360f57eda4b7 Mon Sep 17 00:00:00 2001 From: pyhqos7bg Date: Wed, 29 May 2024 16:57:18 +0800 Subject: [PATCH] ADD file via upload --- train_model.py | 103 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 train_model.py diff --git a/train_model.py b/train_model.py new file mode 100644 index 0000000..77fc99e --- /dev/null +++ b/train_model.py @@ -0,0 +1,103 @@ +import numpy as np +import os +import paddle as paddle +import paddle.fluid as fluid +from PIL import Image +import cv2 +import matplotlib.pyplot as plt +from multiprocessing import cpu_count +from paddle.fluid.dygraph import Pool2D,Conv2D,Linear +data_path ='./data_set' +train_data = "./train_data.list" +test_data = './test_data.list' +character_folders = os.listdir(data_path) +label = 0 +label_temp = {} +if(os.path.exists(train_data)): + os.remove(train_data) +if(os.path.exists(test_data)): + os.remove(test_data) +for character_folder in character_folders: + with open(train_data,'a') as file_train: + with open(test_data,'a') as file_test: + if character_folder == '.DS_Store' or character_folder == '.ipynb_checkpoints' or character_folder == 'data23617': + continue + # print(character_folder+str(label)) + label_temp[str(label)] =character_folder + character_imgs = os.listdir(os.path.join(data_path,character_folder)) + for i in range(len(character_imgs)): + filePath = data_path + "/" + character_folder+"/"+character_imgs[i] + if i % 10 == 0: + file_test.write(filePath+ "\t" + str(label) + '\n') + else: + file_train.write(filePath + "\t" + str(label) + '\n') + label = label + 1 +print('图像列表已生成') + +#定义图像列表定义字符训练集和测试集的reader + +def data_mapper(sample): + imgpath, label = sample + img = paddle.dataset.image.load_image(file=imgpath, is_color=False) + img = np.array(img).astype('float32')/ 255.0 + print(imgpath,img.shape) + img =img.reshape(1,20,20) + return img, label +def data_reader(data_list_path): + def reader(): + with open(data_list_path,'r') as f: + lines= f.readlines() + for line in lines: + img,label = line.split('\t') + yield img,int(label) + return paddle.reader.xmap_readers(data_mapper,reader,cpu_count(),1024) + +# 用于训练的数据提供器 +train_reader = paddle.batch(reader=paddle.reader.shuffle(reader=data_reader('./train_data.list'), buf_size=512), batch_size=64) +# 用于测试的数据提供器 +test_reader = paddle.batch(reader=data_reader('./test_data.list'), batch_size=10) + +class LeNet_Model(fluid.dygraph.Layer): + def __init__(self): + super(LeNet_Model,self).__init__() + self.hidden1_1 = Conv2D(num_channels=1,num_filters=28,filter_size=5,stride=1) + self.hidden1_2 = Pool2D(pool_size=2,pool_type='max',pool_stride=1) + self.hidden2_1 = Conv2D(num_channels=28,num_filters=32,filter_size=3,stride=1) + self.hidden2_2 = Pool2D(pool_size=2,pool_type='max',pool_stride=1) + self.hidden3=Conv2D(num_channels=32,num_filters=32,filter_size=3,stride=1) + self.hidden4=Linear(32*10*10,64,act='softmax') + def forward(self, input): + x= self.hidden1_1(input) + x =self.hidden1_2(x) + x = self.hidden2_1(x) + x=self.hidden2_2(x) + x =self.hidden3(x) + x= fluid.layers.reshape(x,[-1,32*10*10]) + y= self.hidden4(x) + return y +#开启训练模式 +with fluid.dygraph.gard(): + model = LeNet_Model() #实例化模型 + model.train() #开启训练模式 + opt = fluid.optimizer.SGDOptimizer(learning_rate=0.001,parameter_list=model.parameters()) + eporchs_num = 2 + for pass_num in range(eporchs_num): + for batch_id,data in enumerate(train_reader()): + images = np.array([x[0] for x in data],np.float32) + labels = np.array([x[1] for x in data]).astype('int64') + labels = labels[:, np.newaxis] + image = fluid.dygraph.to_variable(images) + label = fluid.dygraph.to_variable(labels) + + predict = model(image) + loss = fluid.layers.cross_entropy(predict,label) + avg_loss= fluid.layers.mean(loss) #获取精度的平均值 + + acc = fluid.layers.accuracy(predict,label) + if batch_id!=0 and batch_id%50==0: + print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(pass_num, batch_id, avg_loss.numpy(), + acc.numpy())) + avg_loss.backward() + opt.minimize(avg_loss) + model.clear_gradients() + fluid.save_dygraph(model.state_dict(),"letnet_model")