You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
4.4 KiB
104 lines
4.4 KiB
import numpy as np
|
|
import os
|
|
import paddle as paddle
|
|
import paddle.fluid as fluid
|
|
from PIL import Image
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
from multiprocessing import cpu_count
|
|
from paddle.fluid.dygraph import Pool2D,Conv2D,Linear
|
|
# Build the train/test sample lists from the per-class image folders.
#
# Every sub-folder of ``data_path`` is treated as one character class. Each
# output line has the form "<image path>\t<numeric label>\n".
data_path = './data_set'
train_data = "./train_data.list"
test_data = './test_data.list'

# Entries under data_path that are not character classes.
_IGNORED_ENTRIES = {'.DS_Store', '.ipynb_checkpoints', 'data23617'}

# NOTE: os.listdir returns entries in arbitrary order, so the folder -> label
# assignment follows whatever order the filesystem reports.
character_folders = os.listdir(data_path)
label = 0
# Maps the stringified numeric label to the folder (class) name it stands for.
label_temp = {}

# Mode 'w' truncates any list left over from a previous run, so both files
# are opened once here instead of being deleted up front and re-opened in
# append mode for every folder.
with open(train_data, 'w') as file_train, open(test_data, 'w') as file_test:
    for character_folder in character_folders:
        folder_path = data_path + "/" + character_folder
        # Skip known non-class entries and any stray plain file; a skipped
        # entry does not consume a label, so labels stay contiguous.
        if character_folder in _IGNORED_ENTRIES or not os.path.isdir(folder_path):
            continue
        label_temp[str(label)] = character_folder
        character_imgs = os.listdir(folder_path)
        for i, img_name in enumerate(character_imgs):
            record = folder_path + "/" + img_name + "\t" + str(label) + '\n'
            # Every tenth image (index 0, 10, 20, ...) goes to the test list,
            # the remaining ones to the training list.
            if i % 10 == 0:
                file_test.write(record)
            else:
                file_train.write(record)
        label = label + 1
print('图像列表已生成')
|
|
|
|
# Define the readers for the character training and test sets, built from the image lists
|
|
|
|
def data_mapper(sample):
    """Turn one ``(image path, label)`` sample into ``(image array, label)``.

    Args:
        sample: A two-element sequence ``(imgpath, label)``.

    Returns:
        A tuple ``(img, label)``. ``img`` is the image loaded with
        ``is_color=False``, converted to float32, divided by 255.0 and
        reshaped to ``(1, 20, 20)``; ``label`` is passed through unchanged.
    """
    imgpath, label = sample
    loaded = paddle.dataset.image.load_image(file=imgpath, is_color=False)
    scaled = np.array(loaded).astype('float32') / 255.0
    # Log the path together with the shape as loaded, i.e. before the reshape.
    print(imgpath, scaled.shape)
    # The reshape raises if the loaded image does not hold exactly 400 values.
    return scaled.reshape(1, 20, 20), label
|
|
def data_reader(data_list_path):
    """Create a reader over the samples listed in ``data_list_path``.

    Args:
        data_list_path: Path of a list file whose lines look like
            "<image path>\\t<integer label>".

    Returns:
        The reader built by ``paddle.reader.xmap_readers``, which applies
        ``data_mapper`` to every ``(image path, label)`` pair yielded by the
        inner generator. ``cpu_count()`` and ``1024`` are passed as its third
        and fourth arguments.
    """
    def reader():
        # Yield (image path, int label) for each line of the list file.
        with open(data_list_path, 'r') as list_file:
            for record in list_file.readlines():
                img_path, label_text = record.split('\t')
                # int() tolerates the trailing newline left on label_text.
                yield img_path, int(label_text)

    return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)
|
|
|
|
# Data provider for training: samples pass through a shuffle buffer of 512
# entries and are then grouped into batches of 64.
train_reader = paddle.batch(
    reader=paddle.reader.shuffle(
        reader=data_reader('./train_data.list'),
        buf_size=512,
    ),
    batch_size=64,
)

# Data provider for testing: unshuffled, in batches of 10.
test_reader = paddle.batch(
    reader=data_reader('./test_data.list'),
    batch_size=10,
)
|
|
|
|
class LeNet_Model(fluid.dygraph.Layer):
    """Small convolutional classifier with a 64-way softmax output.

    Layers, in the order they are applied:
        hidden1_1: Conv2D, 1 -> 28 channels, 5x5 filter, stride 1
        hidden1_2: Pool2D, 2x2 max pooling, stride 1
        hidden2_1: Conv2D, 28 -> 32 channels, 3x3 filter, stride 1
        hidden2_2: Pool2D, 2x2 max pooling, stride 1
        hidden3:   Conv2D, 32 -> 32 channels, 3x3 filter, stride 1
        hidden4:   Linear, 32*10*10 -> 64, softmax activation

    NOTE(review): the 32*10*10 flattened size matches a 1x20x20 input
    (spatial size 20 -> 16 -> 15 -> 13 -> 12 -> 10), assuming the Conv2D and
    Pool2D layers apply no padding by default -- confirm against the Paddle
    version in use.
    """

    # Number of features per sample after the last convolution is flattened.
    _FLAT_FEATURES = 32 * 10 * 10

    def __init__(self):
        super(LeNet_Model, self).__init__()
        self.hidden1_1 = Conv2D(num_channels=1, num_filters=28, filter_size=5, stride=1)
        self.hidden1_2 = Pool2D(pool_size=2, pool_type='max', pool_stride=1)
        self.hidden2_1 = Conv2D(num_channels=28, num_filters=32, filter_size=3, stride=1)
        self.hidden2_2 = Pool2D(pool_size=2, pool_type='max', pool_stride=1)
        self.hidden3 = Conv2D(num_channels=32, num_filters=32, filter_size=3, stride=1)
        self.hidden4 = Linear(self._FLAT_FEATURES, 64, act='softmax')

    def forward(self, input):
        """Run ``input`` through the network and return the softmax output."""
        feature_layers = (
            self.hidden1_1,
            self.hidden1_2,
            self.hidden2_1,
            self.hidden2_2,
            self.hidden3,
        )
        x = input
        for layer in feature_layers:
            x = layer(x)
        # Flatten each sample's feature maps into one vector for the linear layer.
        flat = fluid.layers.reshape(x, [-1, self._FLAT_FEATURES])
        return self.hidden4(flat)
|
|
# Train the model in dygraph mode, then save its parameters.
# Fix: the context manager is `fluid.dygraph.guard`; the previous spelling
# `gard` raised AttributeError before any training could start.
with fluid.dygraph.guard():
    model = LeNet_Model()  # instantiate the network
    model.train()  # switch the model to training mode
    opt = fluid.optimizer.SGDOptimizer(learning_rate=0.001,
                                       parameter_list=model.parameters())
    epochs_num = 2
    for pass_num in range(epochs_num):
        for batch_id, data in enumerate(train_reader()):
            # Each element of `data` is an (image, label) pair.
            images = np.array([x[0] for x in data], np.float32)
            labels = np.array([x[1] for x in data]).astype('int64')
            # Add a trailing axis so the labels have shape (batch, 1).
            labels = labels[:, np.newaxis]
            image = fluid.dygraph.to_variable(images)
            label = fluid.dygraph.to_variable(labels)

            predict = model(image)
            loss = fluid.layers.cross_entropy(predict, label)
            avg_loss = fluid.layers.mean(loss)  # mean loss over the batch
            acc = fluid.layers.accuracy(predict, label)

            # Report progress every 50 batches (batch 0 is skipped).
            if batch_id != 0 and batch_id % 50 == 0:
                print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(
                    pass_num, batch_id, avg_loss.numpy(), acc.numpy()))

            avg_loss.backward()
            opt.minimize(avg_loss)
            # Reset gradients so they do not carry over into the next batch.
            model.clear_gradients()
    # Save the trained parameters once all epochs have finished.
    fluid.save_dygraph(model.state_dict(), "letnet_model")
|