You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
52 lines
2.2 KiB
52 lines
2.2 KiB
6 months ago
|
from 手势识别.generate_data import data_reader
|
||
|
import paddle.fluid as fluid
|
||
|
from paddle.fluid.dygraph import Linear
|
||
|
import paddle as paddle
|
||
|
import numpy as np
|
||
|
# Training input pipeline: samples produced by the project's data_reader for
# './train_data.list' are shuffled through a 256-sample buffer and then grouped
# into mini-batches of 32. Samples are consumed as (image, label) pairs by the
# training code in this file.
_shuffled_train_reader = paddle.reader.shuffle(
    reader=data_reader('./train_data.list'),
    buf_size=256,
)
train_reader = paddle.batch(reader=_shuffled_train_reader, batch_size=32)
|
||
|
class FullNet_Model(fluid.dygraph.Layer):
    """Dense (fully connected) classifier for image tensors.

    Three ReLU ``Linear`` layers are applied along the last (width) axis of the
    input; the result is flattened per sample and passed to a softmax
    ``Linear`` layer that outputs class probabilities.

    The defaults reproduce the original hard-coded configuration: inputs of
    shape ``(batch, 3, 100, 100)`` and 10 output classes.

    Args:
        num_classes: number of output classes of the final softmax layer.
        channels: number of input channels.
        height: input height.
        width: input width. The hidden layers map ``width -> width`` because
            ``Linear`` operates on the last axis, which keeps the input shape
            unchanged until the flatten step.
    """

    def __init__(self, num_classes=10, channels=3, height=100, width=100):
        super(FullNet_Model, self).__init__()
        # Size of one sample once flattened; shared by the final layer and the
        # reshape in forward() so the two can never disagree.
        self._flat_dim = channels * height * width
        self.hidden1 = Linear(input_dim=width, output_dim=width, act='relu')
        self.hidden2 = Linear(input_dim=width, output_dim=width, act='relu')
        self.hidden3 = Linear(input_dim=width, output_dim=width, act='relu')
        self.hidden4 = Linear(input_dim=self._flat_dim, output_dim=num_classes, act='softmax')

    def forward(self, input):
        """Return softmax class probabilities of shape ``(batch, num_classes)``.

        Args:
            input: image tensor, assumed to be ``(batch, channels, height, width)``
                as produced by the training code (TODO confirm for other callers).
        """
        x = self.hidden1(input)
        x = self.hidden2(x)
        x = self.hidden3(x)
        # Flatten every sample into a single feature vector for the classifier.
        x = fluid.layers.reshape(x, shape=[-1, self._flat_dim])
        y = self.hidden4(x)
        return y
|
||
|
def train(epochs_num=20, learning_rate=0.001, save_path='FullNet_Model', log_every=50):
    """Train a ``FullNet_Model`` on ``train_reader`` in dygraph mode and save it.

    This was previously a disabled script kept inside a string literal. It is
    now a callable function and is not run on import; call ``train()``
    explicitly. The defaults match the original script.

    Args:
        epochs_num: number of passes over ``train_reader``.
        learning_rate: SGD learning rate.
        save_path: path prefix passed to ``fluid.save_dygraph``.
        log_every: print loss/accuracy every ``log_every`` batches
            (batch 0 is skipped).

    Returns:
        The trained model.
    """
    # Train using dynamic-graph (dygraph) mode.
    with fluid.dygraph.guard():
        model = FullNet_Model()
        model.train()
        opt = fluid.optimizer.SGDOptimizer(learning_rate=learning_rate,
                                           parameter_list=model.parameters())

        for epoch in range(epochs_num):
            for batch_id, data in enumerate(train_reader()):
                # Each sample is an (image, label) pair; the image is assumed to
                # hold 3*100*100 values -- TODO confirm against data_reader.
                images = np.array([x[0].reshape(3, 100, 100) for x in data], np.float32)
                labels = np.array([x[1] for x in data]).astype('int64')
                # cross_entropy / accuracy expect labels shaped (batch, 1).
                labels = labels[:, np.newaxis]

                image = fluid.dygraph.to_variable(images)
                label = fluid.dygraph.to_variable(labels)

                predict = model(image)

                # The model already applies softmax, which is the input that
                # fluid.layers.cross_entropy expects (probabilities, not logits).
                loss = fluid.layers.cross_entropy(predict, label)
                avg_loss = fluid.layers.mean(loss)
                acc = fluid.layers.accuracy(predict, label)

                if batch_id != 0 and batch_id % log_every == 0:
                    print(
                        "train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(epoch, batch_id, avg_loss.numpy(),
                                                                                      acc.numpy()))

                avg_loss.backward()
                opt.minimize(avg_loss)
                model.clear_gradients()

        # Save the model parameters (still inside the dygraph guard).
        fluid.save_dygraph(model.state_dict(), save_path)
    return model
|