You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
2.3 KiB
54 lines
2.3 KiB
# Train the VGGNet mask-detection classifier in Paddle dygraph mode, plot the
# training curves, and save the learned parameters under the name "vgg".
#
# All hyper-parameters (list path, batch size, epochs, learning rate, class
# info) are read from ``train_parameters`` in ``口罩检测.util``.
import numpy as np
import paddle
import paddle.fluid as fluid

from 口罩检测.VGGNet import VGGNet
from 口罩检测.generate_data import custom_reader
from 口罩检测.util import train_parameters, draw_process, draw_train_process

# Print a progress line every LOG_INTERVAL batches (1 == every batch).
LOG_INTERVAL = 1

# Running sample counter plus per-batch history used for the training plots.
all_train_iter = 0
all_train_iters = []
all_train_costs = []
all_train_accs = []

# drop_last=True keeps every batch at exactly train_batch_size samples, so the
# sample counter below advances by a constant amount per step.
train_reader = paddle.batch(custom_reader(train_parameters['train_list_path']),
                            batch_size=train_parameters['train_batch_size'],
                            drop_last=True)

# Mean loss of the most recent batch; stays None if no full batch is produced.
final_loss = None

with fluid.dygraph.guard():
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])

    vgg = VGGNet()
    # NOTE: the key really is spelled 'learning_stategry' here; renaming it
    # requires changing util.train_parameters as well.
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=train_parameters['learning_stategry']['lr'],
        parameter_list=vgg.parameters())

    for epoch_num in range(train_parameters['num_epochs']):
        for batch_id, data in enumerate(train_reader()):
            # Each reader item is indexed as (image, label); stack into arrays.
            x_data = np.array([x[0] for x in data]).astype('float32')
            label_data = np.array([x[1] for x in data]).astype('int64')
            # cross_entropy takes labels as a column: (batch,) -> (batch, 1).
            label_data = label_data[:, np.newaxis]

            img = fluid.dygraph.to_variable(x_data)
            label = fluid.dygraph.to_variable(label_data)

            out, acc = vgg(img, label)
            loss = fluid.layers.cross_entropy(out, label)
            avg_loss = fluid.layers.mean(loss)

            avg_loss.backward()
            optimizer.minimize(avg_loss)
            # Reset gradients so they do not accumulate into the next step.
            vgg.clear_gradients()

            avg_loss_value = avg_loss.numpy()
            acc_value = acc.numpy()
            final_loss = avg_loss_value

            all_train_iter += train_parameters['train_batch_size']
            all_train_iters.append(all_train_iter)
            # Record the batch MEAN loss; indexing the per-sample `loss` with
            # [0] would only capture the first image of each batch.
            all_train_costs.append(avg_loss_value[0])
            all_train_accs.append(acc_value[0])

            if batch_id % LOG_INTERVAL == 0:
                print("Loss at epoch {} step {}: {}, acc: {}".format(
                    epoch_num, batch_id, avg_loss_value, acc_value))

    draw_train_process("training", all_train_iters, all_train_costs, all_train_accs,
                       "trainning cost", "trainning acc")
    draw_process("trainning loss", "red", all_train_iters, all_train_costs, "trainning loss")
    draw_process("trainning acc", "green", all_train_iters, all_train_accs, "trainning acc")

    # Save the model parameters.
    fluid.save_dygraph(vgg.state_dict(), "vgg")

# avg_loss only exists if at least one batch ran; report via final_loss instead
# so an empty reader does not crash with NameError.
if final_loss is None:
    print("No training batches were produced; check train_list_path and train_batch_size.")
else:
    print("Final loss: {}".format(final_loss))