You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
4.0 KiB
111 lines
4.0 KiB
6 months ago
|
import math
import os

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||
|
|
||
|
# Configure a font capable of rendering the Chinese plot titles and legend
# labels used further down. The path is Windows-specific; on machines where the
# file is absent we keep matplotlib's default font instead of crashing (the
# original unguarded prop.get_name() raised FileNotFoundError there).
font_path = 'C:/Windows/Fonts/msyh.ttc'
if os.path.exists(font_path):
    # Register the file with matplotlib's font manager so that the family name
    # assigned to rcParams below can actually be resolved when text is drawn.
    fm.fontManager.addfont(font_path)
    prop = fm.FontProperties(fname=font_path)
    plt.rcParams['font.family'] = prop.get_name()
else:
    print(f"Font file not found, using matplotlib default font: {font_path}")
|
||
|
|
||
|
# Dataset locations: one archive root, with training images under
# "leapGestRecog" and validation images under "validation".
data_dir = 'D:/hand/archive'
train_dir, validation_dir = (
    os.path.join(data_dir, sub_dir) for sub_dir in ('leapGestRecog', 'validation')
)

# Report up front whether each directory is present, before anything
# attempts to read images from it.
print("训练数据目录存在:", os.path.exists(train_dir))
print("验证数据目录存在:", os.path.exists(validation_dir))
|
||
|
|
||
|
if os.path.exists(train_dir) and os.path.exists(validation_dir):
    # Simple normalisation: rescale pixel values by 1/255.
    datagen = ImageDataGenerator(rescale=1./255)

    # Images are resized to 150x150 and labels are one-hot ('categorical');
    # the class list comes from the sub-directories of each data directory.
    train_generator = datagen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')
    validation_generator = datagen.flow_from_directory(validation_dir, target_size=(150, 150), batch_size=32, class_mode='categorical')

    def _to_dataset(generator):
        """Wrap a Keras directory iterator in a repeating tf.data.Dataset.

        The label spec uses the iterator's own ``num_classes`` instead of a
        hard-coded 10, so the pipeline matches however many class folders
        ``flow_from_directory`` discovered.
        """
        return tf.data.Dataset.from_generator(
            lambda: generator,
            output_signature=(
                tf.TensorSpec(shape=(None, 150, 150, 3), dtype=tf.float32),
                tf.TensorSpec(shape=(None, generator.num_classes), dtype=tf.float32),
            ),
        ).repeat()

    train_dataset = _to_dataset(train_generator)
    validation_dataset = _to_dataset(validation_generator)

    # Check class names and counts.
    train_classes = train_generator.class_indices
    validation_classes = validation_generator.class_indices

    print("训练数据集类别:", train_classes)
    print("验证数据集类别:", validation_classes)

    # If the two splits have different class folders, one-hot label positions
    # do not refer to the same class, and validation metrics are meaningless.
    if train_classes != validation_classes:
        print("Warning: training and validation class folders differ; "
              "validation labels will not line up with training labels.")

    # Build the model: four conv/pool stages, then a dense classifier whose
    # output width equals the number of training classes.
    model = models.Sequential([
        layers.Input(shape=(150, 150, 3)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(train_generator.num_classes, activation='softmax')
    ])

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # Round up so the final partial batch is counted and one epoch covers every
    # image once. Floor division dropped it and gave 0 validation steps when a
    # split held fewer images than one batch; max(1, ...) keeps steps positive.
    steps_per_epoch = max(1, math.ceil(train_generator.samples / train_generator.batch_size))
    validation_steps = max(1, math.ceil(validation_generator.samples / validation_generator.batch_size))

    history = model.fit(train_dataset, steps_per_epoch=steps_per_epoch,
                        epochs=30, validation_data=validation_dataset,
                        validation_steps=validation_steps)

    validation_loss, validation_accuracy = model.evaluate(validation_dataset, steps=validation_steps)
    print(f"Validation Accuracy: {validation_accuracy * 100:.2f}%")

    # Save the model.
    model.save('hand_gesture_model.h5')
    print("模型已保存到 hand_gesture_model.h5")

    # Load the model back from file. `load_model` was never imported in the
    # original (NameError); it is reachable through the imported `models`.
    loaded_model = models.load_model('hand_gesture_model.h5')
    print("模型已从 hand_gesture_model.h5 加载")

    # Visualise the training history.
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    epochs = range(len(acc))

    plt.plot(epochs, acc, 'r', label='训练准确度')
    plt.plot(epochs, val_acc, 'b', label='验证准确度')
    plt.title('训练和验证准确度')
    plt.legend()

    plt.figure()

    plt.plot(epochs, loss, 'r', label='训练损失')
    plt.plot(epochs, val_loss, 'b', label='验证损失')
    plt.title('训练和验证损失')
    plt.legend()

    plt.show()
else:
    print(f"路径错误,请检查以下路径是否正确:\n训练数据目录: {train_dir}\n验证数据目录: {validation_dir}")
|