parent
91dd26172e
commit
101fd8cef0
@ -0,0 +1,53 @@
|
||||
import paddle.fluid as fluid
|
||||
from 口罩检测.VGGNet import VGGNet
|
||||
import paddle as paddle
|
||||
from 口罩检测.util import train_parameters,draw_process,draw_train_process
|
||||
from 口罩检测.generate_data import custom_reader
|
||||
import numpy as np
|
||||
|
||||
all_train_iter=0
|
||||
all_train_iters=[]
|
||||
all_train_costs=[]
|
||||
all_train_accs=[]
|
||||
|
||||
train_reader = paddle.batch(custom_reader(train_parameters['train_list_path']),
|
||||
batch_size=train_parameters['train_batch_size'],
|
||||
drop_last=True)
|
||||
with fluid.dygraph.guard():
|
||||
print(train_parameters['class_dim'])
|
||||
print(train_parameters['label_dict'])
|
||||
vgg =VGGNet()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=train_parameters['learning_stategry']['lr'],parameter_list=vgg.parameters())
|
||||
for epoch_num in range(train_parameters['num_epochs']):
|
||||
for batch_id,data in enumerate(train_reader()):
|
||||
x_data = np.array([x[0] for x in data]).astype('float32')
|
||||
label = np.array([x[1] for x in data]).astype('int64')
|
||||
label = label[:,np.newaxis]
|
||||
|
||||
img = fluid.dygraph.to_variable(x_data)
|
||||
label = fluid.dygraph.to_variable(label)
|
||||
out,acc = vgg(img,label)
|
||||
loss = fluid.layers.cross_entropy(out,label)
|
||||
avg_loss = fluid.layers.mean(loss)
|
||||
|
||||
avg_loss.backward()
|
||||
optimizer.minimize(avg_loss)
|
||||
vgg.clear_gradients()
|
||||
|
||||
all_train_iter = all_train_iter + train_parameters['train_batch_size']
|
||||
all_train_iters.append(all_train_iter)
|
||||
all_train_costs.append(loss.numpy()[0])
|
||||
all_train_accs.append(acc.numpy()[0])
|
||||
|
||||
if batch_id % 1 == 0:
|
||||
print(
|
||||
"Loss at epoch {} step {}: {}, acc: {}".format(epoch_num, batch_id, avg_loss.numpy(), acc.numpy()))
|
||||
|
||||
draw_train_process("training", all_train_iters, all_train_costs, all_train_accs, "trainning cost",
|
||||
"trainning acc")
|
||||
draw_process("trainning loss", "red", all_train_iters, all_train_costs, "trainning loss")
|
||||
draw_process("trainning acc", "green", all_train_iters, all_train_accs, "trainning acc")
|
||||
|
||||
# 保存模型参数
|
||||
fluid.save_dygraph(vgg.state_dict(), "vgg")
|
||||
print("Final loss: {}".format(avg_loss.numpy()))
|
Loading…
Reference in new issue