"""Train a small CNN hand-gesture classifier from image folders and plot its curves.

Runs top to bottom as a script:
  1. Configure a CJK-capable font so the Chinese plot labels render.
  2. Load images with ``ImageDataGenerator.flow_from_directory`` from
     ``<data_dir>/leapGestRecog`` (training) and ``<data_dir>/validation``.
  3. Train a 4-conv-block CNN for ``EPOCHS`` epochs and evaluate on validation.
  4. Save the model to ``MODEL_PATH``, reload it, and plot accuracy/loss curves.
"""
import os

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = (150, 150)  # target_size passed to flow_from_directory
BATCH_SIZE = 32
EPOCHS = 30
MODEL_PATH = 'hand_gesture_model.h5'

# Font setup: a CJK-capable font so the Chinese titles/legends render.
font_path = 'C:/Windows/Fonts/msyh.ttc'
if os.path.exists(font_path):
    prop = fm.FontProperties(fname=font_path)
    plt.rcParams['font.family'] = prop.get_name()
    # Many CJK fonts lack the U+2212 minus glyph; render negatives with ASCII '-'.
    plt.rcParams['axes.unicode_minus'] = False
else:
    # get_name() has to open the font file, so a missing file would abort the
    # whole script; fall back to matplotlib's default font instead.
    print(f"字体文件不存在, 使用默认字体: {font_path}")


def _to_dataset(iterator):
    """Wrap a Keras directory iterator as a repeating ``tf.data.Dataset``.

    The label width is taken from ``iterator.num_classes`` instead of being
    hard-coded, so the dataset signature always matches the folders on disk
    (and the model's output layer, which uses the same value).
    """
    return tf.data.Dataset.from_generator(
        lambda: iterator,
        output_signature=(
            tf.TensorSpec(shape=(None, *IMAGE_SIZE, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, iterator.num_classes), dtype=tf.float32),
        ),
    ).repeat()


# Data directories.
data_dir = 'D:/hand/archive'
train_dir = os.path.join(data_dir, 'leapGestRecog')
validation_dir = os.path.join(data_dir, 'validation')

# Report whether each directory exists before trying to read it.
print("训练数据目录存在:", os.path.exists(train_dir))
print("验证数据目录存在:", os.path.exists(validation_dir))

if os.path.exists(train_dir) and os.path.exists(validation_dir):
    # Only rescale pixel values to [0, 1]; no augmentation.
    datagen = ImageDataGenerator(rescale=1./255)

    # NOTE(review): flow_from_directory treats each *immediate* subdirectory of
    # the given directory as one class. If gesture folders are nested under
    # per-subject folders (e.g. leapGestRecog/00/01_palm/...), the labels will
    # be the subject folders, not the gestures -- confirm the on-disk layout.
    train_generator = datagen.flow_from_directory(
        train_dir,
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
    )
    validation_generator = datagen.flow_from_directory(
        validation_dir,
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
    )

    # Show the class-name -> index mapping of each split.
    train_classes = train_generator.class_indices
    validation_classes = validation_generator.class_indices
    print("训练数据集类别:", train_classes)
    print("验证数据集类别:", validation_classes)

    # With no images, the step counts below would be 0 and fit() would fail.
    if train_generator.samples == 0 or validation_generator.samples == 0:
        raise SystemExit(
            f"未找到图像, 请检查目录内容:\n训练数据目录: {train_dir}\n验证数据目录: {validation_dir}"
        )
    # Labels are one-hot by class index; differing mappings would make the
    # validation metrics meaningless (or fail with a shape mismatch).
    if train_classes != validation_classes:
        raise SystemExit("训练集与验证集的类别不一致, 请检查目录结构")

    train_dataset = _to_dataset(train_generator)
    validation_dataset = _to_dataset(validation_generator)

    # Build the model: four Conv/MaxPool stages, then a dense classifier head.
    model = models.Sequential([
        layers.Input(shape=(*IMAGE_SIZE, 3)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(train_generator.num_classes, activation='softmax'),
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # len(iterator) == ceil(samples / batch_size). The iterator also yields the
    # final partial batch, so this keeps every Keras epoch aligned with exactly
    # one full pass over the images (floor division would drift and could be 0).
    steps_per_epoch = len(train_generator)
    validation_steps = len(validation_generator)

    history = model.fit(train_dataset,
                        steps_per_epoch=steps_per_epoch,
                        epochs=EPOCHS,
                        validation_data=validation_dataset,
                        validation_steps=validation_steps)

    validation_loss, validation_accuracy = model.evaluate(validation_dataset, steps=validation_steps)
    print(f"Validation Accuracy: {validation_accuracy * 100:.2f}%")

    # Save the trained model.
    model.save(MODEL_PATH)
    print(f"模型已保存到 {MODEL_PATH}")

    # Reload it from disk. `load_model` was never imported in the original, so
    # it is used through the already-imported `models` module.
    loaded_model = models.load_model(MODEL_PATH)
    print(f"模型已从 {MODEL_PATH} 加载")

    # Plot the training curves recorded by fit().
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    epochs = range(len(acc))

    plt.plot(epochs, acc, 'r', label='训练准确度')
    plt.plot(epochs, val_acc, 'b', label='验证准确度')
    plt.title('训练和验证准确度')
    plt.legend()

    plt.figure()

    plt.plot(epochs, loss, 'r', label='训练损失')
    plt.plot(epochs, val_loss, 'b', label='验证损失')
    plt.title('训练和验证损失')
    plt.legend()

    plt.show()
else:
    print(f"路径错误,请检查以下路径是否正确:\n训练数据目录: {train_dir}\n验证数据目录: {validation_dir}")