From 04368e09190d854a0f19b5c0b988544d46858363 Mon Sep 17 00:00:00 2001 From: sztu202200202027 Date: Wed, 19 Jun 2024 00:35:26 +0800 Subject: [PATCH] ADD file via upload --- animal_train.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 animal_train.py diff --git a/animal_train.py b/animal_train.py new file mode 100644 index 0000000..2f50e90 --- /dev/null +++ b/animal_train.py @@ -0,0 +1,80 @@ +import os +import numpy as np +from PIL import Image +import tensorflow as tf +from tensorflow.keras.preprocessing.image import ImageDataGenerator +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense +from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback +from sklearn.utils import class_weight + + +# 数据路径 +data_dir = 'dataset/animal' # 数据集根目录 +batch_size = 16 + +# 图片生成器,用于从文件夹加载图片数据 +datagen = ImageDataGenerator( + rescale=1./255, + validation_split=0.2, + rotation_range=10, + width_shift_range=0.1, + height_shift_range=0.1, + shear_range=0.15, + zoom_range=0.1, + horizontal_flip=True +) + +generator = datagen.flow_from_directory( + data_dir, + target_size=(180, 180), + batch_size=batch_size, + class_mode='categorical', + subset='training' +) + +validation_generator = datagen.flow_from_directory( + data_dir, + target_size=(180, 180), + batch_size=batch_size, + class_mode='categorical', + subset='validation' +) + +# 计算样本权重 +class_weights = class_weight.compute_sample_weight('balanced', generator.classes) + +# 配置早停 +early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True) + +# 配置GPU加速 +strategy = tf.distribute.MirroredStrategy() +print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + +with strategy.scope(): + # 构建更复杂的模型 + model = tf.keras.Sequential([ + Conv2D(16, (3, 3), activation='relu', input_shape=(180, 180, 3)), + MaxPooling2D((2, 2)), + Conv2D(32, (3, 3), activation='relu'), + MaxPooling2D((2, 2)), + Conv2D(64, (3, 3), activation='relu'), + MaxPooling2D((2, 2)), + Flatten(), + Dense(128, activation='relu'), + Dense(len(generator.class_indices), activation='softmax') + ]) + + # 编译模型 + optimizer = tf.keras.optimizers.Adam(learning_rate=0.005) + model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy']) + +# # 配置模型检查点,保存最优模型 +# checkpoint_path = "./model/animal_model.h5" +# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss', +# save_best_only=True, save_weights_only=False, verbose=1) + +# 训练模型 +model.fit(generator, epochs=20, validation_data=validation_generator, callbacks=[early_stopping]) + +# 保存模型为 .h5 文件 +model.save("./model/animal_model.h5", save_format='h5') \ No newline at end of file