parent
972e7eda9c
commit
43a8e334fe
@ -0,0 +1,52 @@
|
||||
from 手势识别.generate_data import data_reader
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.dygraph import Linear
|
||||
import paddle as paddle
|
||||
import numpy as np
|
||||
train_reader = paddle.batch(reader=paddle.reader.shuffle(reader=data_reader('./train_data.list'), buf_size=256), batch_size=32)
|
||||
class FullNet_Model(fluid.dygraph.Layer):
|
||||
def __init__(self):
|
||||
super(FullNet_Model,self).__init__()
|
||||
self.hidden1 = Linear(input_dim=100,output_dim=100,act='relu')
|
||||
self.hidden2 = Linear(input_dim=100,output_dim=100,act='relu')
|
||||
self.hidden3 = Linear(input_dim=100,output_dim=100,act='relu')
|
||||
self.hidden4 = Linear(input_dim=3*100*100,output_dim=10,act='softmax')
|
||||
def forward(self, input):
|
||||
x = self.hidden1(input)
|
||||
x= self.hidden2(x)
|
||||
x = self.hidden3(x)
|
||||
x= fluid.layers.reshape(x,shape=[-1,3*100*100])
|
||||
y =self.hidden4(x)
|
||||
return y
|
||||
'''
|
||||
#用动态图进行训练
|
||||
with fluid.dygraph.guard():
|
||||
model = FullNet_Model()
|
||||
model.train()
|
||||
opt =fluid.optimizer.SGDOptimizer(learning_rate=0.001,parameter_list=model.parameters())
|
||||
|
||||
epochs_num = 20 #设置迭代次数
|
||||
for epoch in range(epochs_num):
|
||||
for batch_id,data in enumerate(train_reader()):
|
||||
images = np.array([x[0].reshape(3,100,100) for x in data],np.float32)
|
||||
labels = np.array([x[1] for x in data]).astype('int64')
|
||||
labels = labels[:,np.newaxis]
|
||||
image = fluid.dygraph.to_variable(images)
|
||||
label = fluid.dygraph.to_variable(labels)
|
||||
|
||||
predict = model(image)
|
||||
|
||||
loss= fluid.layers.cross_entropy(predict,label)
|
||||
avg_loss= fluid.layers.mean(loss)
|
||||
acc = fluid.layers.accuracy(predict,label)
|
||||
if batch_id != 0 and batch_id % 50 == 0:
|
||||
print(
|
||||
"train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(epoch, batch_id, avg_loss.numpy(),
|
||||
acc.numpy()))
|
||||
|
||||
avg_loss.backward()
|
||||
opt.minimize(avg_loss)
|
||||
model.clear_gradients()
|
||||
|
||||
fluid.save_dygraph(model.state_dict(), 'FullNet_Model') # 保存模型
|
||||
'''
|
Loading…
Reference in new issue