From fb5b2f172664cc87b322d02b528dc421b580d2db Mon Sep 17 00:00:00 2001 From: sztu202200202027 Date: Wed, 19 Jun 2024 00:38:54 +0800 Subject: [PATCH] ADD file via upload --- scenery_train.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 scenery_train.py diff --git a/scenery_train.py b/scenery_train.py new file mode 100644 index 0000000..9200f7c --- /dev/null +++ b/scenery_train.py @@ -0,0 +1,79 @@ +import os +import numpy as np +from PIL import Image +from tensorflow.keras.preprocessing.image import ImageDataGenerator +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense +import tensorflow as tf +from tensorflow.keras.callbacks import Callback, EarlyStopping + +# 数据路径 +train_data_dir = 'dataset/scenery/seg_train' +test_data_dir = 'dataset/scenery/seg_test' +batch_size = 32 + +# 图片生成器 +train_datagen = ImageDataGenerator(rescale=1./255) +test_datagen = ImageDataGenerator(rescale=1./255) + +train_generator = train_datagen.flow_from_directory( + train_data_dir, + target_size=(180, 180), + batch_size=batch_size, + class_mode='categorical' +) + +test_generator = test_datagen.flow_from_directory( + test_data_dir, + target_size=(180, 180), + batch_size=batch_size, + class_mode='categorical' +) + +# 配置GPU加速 +strategy = tf.distribute.MirroredStrategy() +print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + +# 配置早停 +early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True) + +with strategy.scope(): + model = Sequential([ + Conv2D(32, (3, 3), activation='relu', input_shape=(180, 180, 3)), + MaxPooling2D((2, 2)), + Conv2D(64, (3, 3), activation='relu'), + MaxPooling2D((2, 2)), + Conv2D(128, (3, 3), activation='relu'), + MaxPooling2D((2, 2)), + Flatten(), + Dense(128, activation='relu'), + Dense(len(train_generator.class_indices), activation='softmax') + ]) + # 编译模型 + optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) + model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy']) + +# # 配置模型检查点,保存最优模型 +# checkpoint_path = "./model/animal_model.h5" +# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss', +# save_best_only=True, save_weights_only=False, verbose=1) + +# 训练模型 +model.fit(train_generator, epochs=10, validation_data=test_generator, callbacks=[early_stopping]) + +# 保存模型为 .h5 文件 +model.save("./model/scenery_model.h5", save_format='h5') + + +# model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) +# +# # 自定义回调函数,保存模型为 .h5 格式 +# class CustomModelCheckpoint(Callback): +# def on_epoch_end(self, epoch, logs=None): +# self.model.save("./model/scenery_model.h5") +# +# # 创建自定义回调函数实例 +# custom_checkpoint = CustomModelCheckpoint() +# +# # 训练模型时使用自定义回调函数 +# model.fit(train_generator, epochs=10, validation_data=test_generator, callbacks=[custom_checkpoint])