You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
1006 B
27 lines
1006 B
6 months ago
|
from 手势识别.generate_data import data_reader
|
||
|
from 手势识别.build_model import FullNet_Model
|
||
|
import paddle.fluid as fluid
|
||
|
import paddle as paddle
|
||
|
import numpy as np
|
||
|
# Evaluate the trained FullNet_Model on the samples listed in
# './test_data.list' and print the overall top-1 accuracy.
test_reader = paddle.batch(reader=data_reader('./test_data.list'), batch_size=32)

with fluid.dygraph.guard():
    accs = []  # per-batch accuracy values, kept for inspection
    num_correct = 0.0  # sum over batches of (batch accuracy * batch size)
    num_samples = 0  # total number of evaluated samples

    model_dict, _ = fluid.load_dygraph('FullNet_Model')
    model = FullNet_Model()
    model.load_dict(model_dict)  # load the trained parameters
    model.eval()  # switch the model to evaluation mode

    for data in test_reader():
        # NOTE(review): assumes each sample is (image array with 3*100*100
        # elements, integer label) as produced by data_reader — confirm.
        images = np.array([x[0].reshape(3, 100, 100) for x in data], np.float32)
        labels = np.array([x[1] for x in data]).astype("int64")
        labels = labels[:, np.newaxis]  # reshape labels into an (N, 1) column

        image = fluid.dygraph.to_variable(images)
        label = fluid.dygraph.to_variable(labels)

        predict = model(image)
        acc = fluid.layers.accuracy(predict, label)
        batch_acc = float(acc.numpy()[0])
        accs.append(batch_acc)

        # Weight each batch by its size so a short final batch does not
        # count as much as a full one in the overall accuracy.
        batch_size = len(data)
        num_correct += batch_acc * batch_size
        num_samples += batch_size

    if num_samples == 0:
        # Without this guard avg_acc would be undefined (or NaN) for an
        # empty test list.
        raise RuntimeError("test set './test_data.list' yielded no samples")

    avg_acc = num_correct / num_samples
    print('平均准确率为', avg_acc)
|