# Mask-detection (口罩检测) VGG training script — PaddlePaddle 1.x dygraph mode.
# (Removed pasted repository-web-page text that was not valid Python source.)
import paddle.fluid as fluid
from 口罩检测.VGGNet import VGGNet
import paddle as paddle
from 口罩检测.util import train_parameters,draw_process,draw_train_process
from 口罩检测.generate_data import custom_reader
import numpy as np
# Running totals collected during training, used by draw_process /
# draw_train_process to plot the loss and accuracy curves afterwards.
all_train_iter=0  # cumulative number of training samples seen so far
all_train_iters=[]  # x-axis values (sample counts) for the plots
all_train_costs=[]  # recorded loss value per logged batch
all_train_accs=[]  # recorded accuracy value per logged batch
# Batched reader over the training list file; drop_last=True discards the
# final partial batch so every batch has exactly train_batch_size samples.
train_reader = paddle.batch(custom_reader(train_parameters['train_list_path']),
batch_size=train_parameters['train_batch_size'],
drop_last=True)
with fluid.dygraph.guard():
    # Train the VGG classifier in dygraph (imperative) mode.
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])
    vgg = VGGNet()
    # NOTE(review): 'learning_stategry' (sic) must match the key spelled in
    # util.train_parameters — do not "fix" the spelling here alone.
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=train_parameters['learning_stategry']['lr'],
        parameter_list=vgg.parameters())
    for epoch_num in range(train_parameters['num_epochs']):
        for batch_id, data in enumerate(train_reader()):
            # Stack the batch: images as float32, labels as int64 column vector
            # (shape (N, 1)), which fluid.layers.cross_entropy expects.
            x_data = np.array([x[0] for x in data]).astype('float32')
            label = np.array([x[1] for x in data]).astype('int64')
            label = label[:, np.newaxis]
            img = fluid.dygraph.to_variable(x_data)
            label = fluid.dygraph.to_variable(label)

            out, acc = vgg(img, label)
            loss = fluid.layers.cross_entropy(out, label)
            avg_loss = fluid.layers.mean(loss)
            avg_loss.backward()
            optimizer.minimize(avg_loss)
            vgg.clear_gradients()

            # Accumulate curve data for plotting after training.
            all_train_iter = all_train_iter + train_parameters['train_batch_size']
            all_train_iters.append(all_train_iter)
            # FIX: record the batch-mean loss; the original appended
            # loss.numpy()[0], i.e. only the first sample's loss of the
            # un-reduced per-sample loss vector.
            all_train_costs.append(avg_loss.numpy()[0])
            all_train_accs.append(acc.numpy()[0])

            if batch_id % 1 == 0:  # always true: log every batch
                print(
                    "Loss at epoch {} step {}: {}, acc: {}".format(epoch_num, batch_id, avg_loss.numpy(), acc.numpy()))

    draw_train_process("training", all_train_iters, all_train_costs, all_train_accs, "trainning cost",
                       "trainning acc")
    draw_process("trainning loss", "red", all_train_iters, all_train_costs, "trainning loss")
    draw_process("trainning acc", "green", all_train_iters, all_train_accs, "trainning acc")
    # Save model parameters (writes "vgg.pdparams").
    fluid.save_dygraph(vgg.state_dict(), "vgg")
    print("Final loss: {}".format(avg_loss.numpy()))