parent
4dc09d8d10
commit
5e82645407
@ -0,0 +1,103 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import paddle as paddle
|
||||
import paddle.fluid as fluid
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
from multiprocessing import cpu_count
|
||||
from paddle.fluid.dygraph import Pool2D,Conv2D,Linear
|
||||
data_path ='./data_set'
|
||||
train_data = "./train_data.list"
|
||||
test_data = './test_data.list'
|
||||
character_folders = os.listdir(data_path)
|
||||
label = 0
|
||||
label_temp = {}
|
||||
if(os.path.exists(train_data)):
|
||||
os.remove(train_data)
|
||||
if(os.path.exists(test_data)):
|
||||
os.remove(test_data)
|
||||
for character_folder in character_folders:
|
||||
with open(train_data,'a') as file_train:
|
||||
with open(test_data,'a') as file_test:
|
||||
if character_folder == '.DS_Store' or character_folder == '.ipynb_checkpoints' or character_folder == 'data23617':
|
||||
continue
|
||||
# print(character_folder+str(label))
|
||||
label_temp[str(label)] =character_folder
|
||||
character_imgs = os.listdir(os.path.join(data_path,character_folder))
|
||||
for i in range(len(character_imgs)):
|
||||
filePath = data_path + "/" + character_folder+"/"+character_imgs[i]
|
||||
if i % 10 == 0:
|
||||
file_test.write(filePath+ "\t" + str(label) + '\n')
|
||||
else:
|
||||
file_train.write(filePath + "\t" + str(label) + '\n')
|
||||
label = label + 1
|
||||
print('图像列表已生成')
|
||||
|
||||
#定义图像列表定义字符训练集和测试集的reader
|
||||
|
||||
def data_mapper(sample):
|
||||
imgpath, label = sample
|
||||
img = paddle.dataset.image.load_image(file=imgpath, is_color=False)
|
||||
img = np.array(img).astype('float32')/ 255.0
|
||||
print(imgpath,img.shape)
|
||||
img =img.reshape(1,20,20)
|
||||
return img, label
|
||||
def data_reader(data_list_path):
|
||||
def reader():
|
||||
with open(data_list_path,'r') as f:
|
||||
lines= f.readlines()
|
||||
for line in lines:
|
||||
img,label = line.split('\t')
|
||||
yield img,int(label)
|
||||
return paddle.reader.xmap_readers(data_mapper,reader,cpu_count(),1024)
|
||||
|
||||
# 用于训练的数据提供器
|
||||
train_reader = paddle.batch(reader=paddle.reader.shuffle(reader=data_reader('./train_data.list'), buf_size=512), batch_size=64)
|
||||
# 用于测试的数据提供器
|
||||
test_reader = paddle.batch(reader=data_reader('./test_data.list'), batch_size=10)
|
||||
|
||||
class LeNet_Model(fluid.dygraph.Layer):
|
||||
def __init__(self):
|
||||
super(LeNet_Model,self).__init__()
|
||||
self.hidden1_1 = Conv2D(num_channels=1,num_filters=28,filter_size=5,stride=1)
|
||||
self.hidden1_2 = Pool2D(pool_size=2,pool_type='max',pool_stride=1)
|
||||
self.hidden2_1 = Conv2D(num_channels=28,num_filters=32,filter_size=3,stride=1)
|
||||
self.hidden2_2 = Pool2D(pool_size=2,pool_type='max',pool_stride=1)
|
||||
self.hidden3=Conv2D(num_channels=32,num_filters=32,filter_size=3,stride=1)
|
||||
self.hidden4=Linear(32*10*10,64,act='softmax')
|
||||
def forward(self, input):
|
||||
x= self.hidden1_1(input)
|
||||
x =self.hidden1_2(x)
|
||||
x = self.hidden2_1(x)
|
||||
x=self.hidden2_2(x)
|
||||
x =self.hidden3(x)
|
||||
x= fluid.layers.reshape(x,[-1,32*10*10])
|
||||
y= self.hidden4(x)
|
||||
return y
|
||||
#开启训练模式
|
||||
with fluid.dygraph.gard():
|
||||
model = LeNet_Model() #实例化模型
|
||||
model.train() #开启训练模式
|
||||
opt = fluid.optimizer.SGDOptimizer(learning_rate=0.001,parameter_list=model.parameters())
|
||||
eporchs_num = 2
|
||||
for pass_num in range(eporchs_num):
|
||||
for batch_id,data in enumerate(train_reader()):
|
||||
images = np.array([x[0] for x in data],np.float32)
|
||||
labels = np.array([x[1] for x in data]).astype('int64')
|
||||
labels = labels[:, np.newaxis]
|
||||
image = fluid.dygraph.to_variable(images)
|
||||
label = fluid.dygraph.to_variable(labels)
|
||||
|
||||
predict = model(image)
|
||||
loss = fluid.layers.cross_entropy(predict,label)
|
||||
avg_loss= fluid.layers.mean(loss) #获取精度的平均值
|
||||
|
||||
acc = fluid.layers.accuracy(predict,label)
|
||||
if batch_id!=0 and batch_id%50==0:
|
||||
print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(pass_num, batch_id, avg_loss.numpy(),
|
||||
acc.numpy()))
|
||||
avg_loss.backward()
|
||||
opt.minimize(avg_loss)
|
||||
model.clear_gradients()
|
||||
fluid.save_dygraph(model.state_dict(),"letnet_model")
|
Loading…
Reference in new issue