You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

27 lines
1002 B

from 口罩检测.generate_data import custom_reader
from 口罩检测.util import train_parameters
import paddle as paddle
import paddle.fluid as fluid
from 口罩检测.VGGNet import VGGNet
import numpy as np
# Evaluate the saved 'vgg' mask-detection model on the evaluation list and
# print the mean per-batch accuracy.

# NOTE(review): evaluation reuses the *training* batch size and drops the
# final partial batch (drop_last=True), so up to batch_size - 1 evaluation
# samples are never scored. Left unchanged because it is not visible here
# whether VGGNet accepts a variable batch size -- confirm before changing.
eval_reader = paddle.batch(custom_reader(train_parameters['eval_list_path']),
                           batch_size=train_parameters['train_batch_size'],
                           drop_last=True)

with fluid.dygraph.guard():
    # load_dygraph returns a (parameter dict, optimizer dict) pair; only the
    # parameters are needed for inference.
    model, _ = fluid.load_dygraph('vgg')
    vgg = VGGNet()
    # Apply the loaded parameters. Previously this step was missing, so a
    # freshly constructed (randomly initialised) network was evaluated.
    vgg.set_dict(model)
    vgg.eval()

    accs = []
    for data in eval_reader():
        # Each sample is presumably an (image, label) pair produced by
        # custom_reader -- TODO confirm against generate_data.
        x_data = np.array([x[0] for x in data]).astype('float32')
        y_data = np.array([x[1] for x in data]).astype('int64')
        # Turn the 1-D label array into a column vector of shape (batch, 1).
        y_data = y_data[:, np.newaxis]
        img = fluid.dygraph.to_variable(x_data)
        label = fluid.dygraph.to_variable(y_data)
        # VGGNet is called with the labels and returns (out, acc); acc is
        # presumably the batch accuracy tensor.
        out, acc = vgg(img, label)
        accs.append(acc.numpy()[0])

    if accs:
        # All batches have the same size (drop_last=True), so the unweighted
        # mean of per-batch accuracies equals the overall sample accuracy.
        print(np.mean(accs))
    else:
        print('No evaluation batches were produced; check eval_list_path '
              'and the batch size.')