# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
#
# 111 lines
# 4.0 KiB
#
# 6 months ago
import os
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import tensorflow as tf
# Configure a CJK-capable font so the Chinese plot titles and legends render.
font_path = 'C:/Windows/Fonts/msyh.ttc'
# Constructing FontProperties with fname only records the path; it does not
# open the file, so this is safe even when the font is absent.
prop = fm.FontProperties(fname=font_path)
if os.path.isfile(font_path):
    # Register the file explicitly so that the family name assigned below can
    # be resolved by matplotlib's font lookup even if the font cache does not
    # already contain it.
    fm.fontManager.addfont(font_path)
    plt.rcParams['font.family'] = prop.get_name()
    # Many CJK fonts lack the Unicode minus glyph; fall back to ASCII '-'.
    plt.rcParams['axes.unicode_minus'] = False
else:
    # Previously prop.get_name() raised FileNotFoundError here; keep going with
    # matplotlib's default font instead of aborting the whole training run.
    print(f"字体文件不存在: {font_path},图表中的中文可能无法正常显示")
# Dataset root. Training images are read from the leapGestRecog/ subfolder
# and validation images from the validation/ subfolder beneath it.
data_dir = 'D:/hand/archive'
train_dir, validation_dir = (
    os.path.join(data_dir, leaf) for leaf in ('leapGestRecog', 'validation')
)

# Report up front whether each expected directory is present on disk.
print("训练数据目录存在:", os.path.exists(train_dir))
print("验证数据目录存在:", os.path.exists(validation_dir))
# Train, evaluate, save/reload and plot the gesture classifier, but only when
# both dataset directories exist; otherwise report the offending paths.
if os.path.exists(train_dir) and os.path.exists(validation_dir):
    # Only rescale pixel values by 1/255; no data augmentation is applied.
    datagen = ImageDataGenerator(rescale=1./255)
    train_generator = datagen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')
    validation_generator = datagen.flow_from_directory(validation_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')

    # Check the class names and counts of both splits.
    train_classes = train_generator.class_indices
    validation_classes = validation_generator.class_indices
    print("训练数据集类别:", train_classes)
    print("验证数据集类别:", validation_classes)
    # The model's output layer is sized from the training classes. If the
    # validation split has different classes (or a different name->index
    # mapping), validation either fails on a shape mismatch or reports a
    # meaningless accuracy, so stop early with a clear message.
    if train_classes != validation_classes:
        raise ValueError(f"训练集和验证集的类别不一致: {train_classes} vs {validation_classes}")

    # The one-hot label width must equal the number of class sub-directories
    # found, not a hard-coded 10, or tf.data rejects the batches at runtime.
    num_classes = train_generator.num_classes
    train_dataset = tf.data.Dataset.from_generator(
        lambda: train_generator,
        output_signature=(
            tf.TensorSpec(shape=(None, 150, 150, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, num_classes), dtype=tf.float32)
        )
    ).repeat()
    validation_dataset = tf.data.Dataset.from_generator(
        lambda: validation_generator,
        output_signature=(
            tf.TensorSpec(shape=(None, 150, 150, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, num_classes), dtype=tf.float32)
        )
    ).repeat()

    # Build a small CNN: four conv/pool stages followed by a dense classifier.
    model = models.Sequential([
        layers.Input(shape=(150, 150, 3)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(num_classes, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # len() of a directory iterator is ceil(samples / batch_size). Each epoch
    # therefore covers every image, including the final partial batch, and
    # matches the iterator's own epoch boundary. The step count is never 0
    # for a non-empty directory, unlike samples // batch_size.
    steps_per_epoch = len(train_generator)
    validation_steps = len(validation_generator)
    history = model.fit(train_dataset, steps_per_epoch=steps_per_epoch,
                        epochs=30, validation_data=validation_dataset,
                        validation_steps=validation_steps)
    validation_loss, validation_accuracy = model.evaluate(validation_dataset, steps=validation_steps)
    print(f"Validation Accuracy: {validation_accuracy * 100:.2f}%")

    # Save the model.
    model.save('hand_gesture_model.h5')
    print("模型已保存到 hand_gesture_model.h5")
    # Reload it from disk. load_model was never imported bare (NameError);
    # use the one exposed by the already-imported keras `models` module.
    loaded_model = models.load_model('hand_gesture_model.h5')
    print("模型已从 hand_gesture_model.h5 加载")

    # Plot the training curves.
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))
    plt.plot(epochs, acc, 'r', label='训练准确度')
    plt.plot(epochs, val_acc, 'b', label='验证准确度')
    plt.title('训练和验证准确度')
    plt.legend()
    plt.figure()
    plt.plot(epochs, loss, 'r', label='训练损失')
    plt.plot(epochs, val_loss, 'b', label='验证损失')
    plt.title('训练和验证损失')
    plt.legend()
    plt.show()
else:
    print(f"路径错误,请检查以下路径是否正确:\n训练数据目录: {train_dir}\n验证数据目录: {validation_dir}")