You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

43 lines
1.5 KiB

import os
import numpy as np
import matplotlib.pyplot as plt #绘画
from tensorflow.keras.preprocessing import image #图片预处理
from tensorflow.keras.models import load_model #加载模型
#生成图像数据
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 设置数据目录
test_dir = 'D:/hand/archive/leapGestRecog'
# 加载模型
model = load_model('hand_gesture_model.h5')
print("模型已从 hand_gesture_model.h5 加载")
# 使用 ImageDataGenerator 加载测试数据
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(test_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')
# 评估模型在测试数据集上的性能
test_loss, test_accuracy = model.evaluate(test_generator)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
# 加载单个图像并进行预测
img_path = 'D:/hand/archive/leapGestRecog/00/01_palm/frame_00_01_0001.png' # 修改为实际图像路径
img = image.load_img(img_path, target_size=(150, 150))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0) / 255.0 # 标准化图像数据
# 进行预测
predictions = model.predict(img_array)
predicted_class = np.argmax(predictions, axis=1)
# 输出预测结果
class_labels = {v: k for k, v in test_generator.class_indices.items()}
print(f"预测类别: {class_labels[predicted_class[0]]}")
# 显示图像和预测结果
plt.imshow(img)
plt.title(f"预测类别: {class_labels[predicted_class[0]]}")
plt.axis('off')
plt.show()