You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

54 lines
2.3 KiB

import paddle.fluid as fluid
from 口罩检测.VGGNet import VGGNet
import paddle as paddle
from 口罩检测.util import train_parameters,draw_process,draw_train_process
from 口罩检测.generate_data import custom_reader
import numpy as np
# Training history consumed by the training loop that follows: a running
# count of samples processed, plus one entry per batch for the plot helpers.
all_train_iter = 0
all_train_iters, all_train_costs, all_train_accs = [], [], []

# Group the sample-level reader for the training list into fixed-size batches.
# drop_last=True discards a trailing batch smaller than train_batch_size.
train_reader = paddle.batch(
    custom_reader(train_parameters['train_list_path']),
    batch_size=train_parameters['train_batch_size'],
    drop_last=True,
)
# Train VGGNet in dygraph (imperative) mode, plot the curves, save the weights.
with fluid.dygraph.guard():
    # Echo class_dim and label_dict from the config as a quick sanity check.
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])
    vgg = VGGNet()
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=train_parameters['learning_stategry']['lr'],
        parameter_list=vgg.parameters())

    # Print progress every `log_interval` batches. A value of 1 logs every
    # batch, which is what the former always-true `batch_id % 1 == 0` did.
    log_interval = 1
    # Host copy of the most recent batch-mean loss. It stays None if the
    # reader yields no batches, so the final report cannot raise NameError.
    avg_loss_np = None

    for epoch_num in range(train_parameters['num_epochs']):
        for batch_id, data in enumerate(train_reader()):
            # Each item of `data` is indexed as (image, label) here.
            x_data = np.array([x[0] for x in data]).astype('float32')
            y_data = np.array([x[1] for x in data]).astype('int64')
            # Add a trailing axis so labels have shape (batch, 1).
            y_data = y_data[:, np.newaxis]

            img = fluid.dygraph.to_variable(x_data)
            label = fluid.dygraph.to_variable(y_data)

            # VGGNet is called with labels and returns (out, acc). `out` is
            # fed to cross_entropy, so it is presumably softmax probabilities.
            out, acc = vgg(img, label)
            loss = fluid.layers.cross_entropy(out, label)  # one loss per sample
            avg_loss = fluid.layers.mean(loss)

            avg_loss.backward()
            optimizer.minimize(avg_loss)
            # Clear gradients so they do not accumulate into the next step.
            vgg.clear_gradients()

            # Fetch host-side values once; reuse them for history and logging.
            avg_loss_np = avg_loss.numpy()
            acc_np = acc.numpy()

            # The curves' x-axis counts training samples seen so far.
            all_train_iter = all_train_iter + train_parameters['train_batch_size']
            all_train_iters.append(all_train_iter)
            # Record the batch-mean loss. `loss` is per-sample, so the former
            # `loss.numpy()[0]` plotted only the first sample's loss.
            all_train_costs.append(avg_loss_np[0])
            all_train_accs.append(acc_np[0])

            if batch_id % log_interval == 0:
                print("Loss at epoch {} step {}: {}, acc: {}".format(
                    epoch_num, batch_id, avg_loss_np, acc_np))

    draw_train_process("training", all_train_iters, all_train_costs, all_train_accs, "trainning cost",
                       "trainning acc")
    draw_process("trainning loss", "red", all_train_iters, all_train_costs, "trainning loss")
    draw_process("trainning acc", "green", all_train_iters, all_train_accs, "trainning acc")

    # Save the model parameters using the path prefix "vgg".
    fluid.save_dygraph(vgg.state_dict(), "vgg")

# Only report a final loss if at least one batch was trained.
if avg_loss_np is not None:
    print("Final loss: {}".format(avg_loss_np))