ADD file via upload

main
sztu202200202027 5 months ago
parent 19da1f9579
commit 04368e0919

@ -0,0 +1,80 @@
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from sklearn.utils import class_weight
# 数据路径
data_dir = 'dataset/animal' # 数据集根目录
batch_size = 16
# 图片生成器,用于从文件夹加载图片数据
datagen = ImageDataGenerator(
rescale=1./255,
validation_split=0.2,
rotation_range=10,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.15,
zoom_range=0.1,
horizontal_flip=True
)
generator = datagen.flow_from_directory(
data_dir,
target_size=(180, 180),
batch_size=batch_size,
class_mode='categorical',
subset='training'
)
validation_generator = datagen.flow_from_directory(
data_dir,
target_size=(180, 180),
batch_size=batch_size,
class_mode='categorical',
subset='validation'
)
# 计算样本权重
class_weights = class_weight.compute_sample_weight('balanced', generator.classes)
# 配置早停
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
# 配置GPU加速
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
with strategy.scope():
# 构建更复杂的模型
model = tf.keras.Sequential([
Conv2D(16, (3, 3), activation='relu', input_shape=(180, 180, 3)),
MaxPooling2D((2, 2)),
Conv2D(32, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Flatten(),
Dense(128, activation='relu'),
Dense(len(generator.class_indices), activation='softmax')
])
# 编译模型
optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# # 配置模型检查点,保存最优模型
# checkpoint_path = "./model/animal_model.h5"
# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss',
# save_best_only=True, save_weights_only=False, verbose=1)
# 训练模型
model.fit(generator, epochs=20, validation_data=validation_generator, callbacks=[early_stopping])
# 保存模型为 .h5 文件
model.save("./model/animal_model.h5", save_format='h5')
Loading…
Cancel
Save