You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

80 lines
2.5 KiB

import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from sklearn.utils import class_weight
# Dataset location and loader settings.
data_dir = 'dataset/animal'  # dataset root directory
batch_size = 16
image_size = (180, 180)      # every image is resized to this (height, width)
validation_fraction = 0.2    # share of the images held out for validation

# Training pipeline: rescale pixel values to [0, 1] and apply light random
# augmentation (rotation, shifts, shear, zoom, horizontal flip).
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_fraction,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.15,
    zoom_range=0.1,
    horizontal_flip=True,
)

# Validation pipeline: rescale only. Augmenting the held-out images would make
# val_loss noisy and pessimistic, and val_loss is what early stopping monitors.
# The same validation_split is used so both pipelines agree on which files
# belong to which subset.
validation_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_fraction,
)

# Batches of (images, one-hot labels) read from the dataset folder.
generator = datagen.flow_from_directory(
    data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
)
validation_generator = validation_datagen.flow_from_directory(
    data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
)
# Balanced per-class weights, so that under-represented classes contribute as
# much to the loss as frequent ones. Keras expects a {class_index: weight}
# dict passed to fit(); a per-sample weight array that is never handed to
# fit() has no effect on training.
class_labels = np.unique(generator.classes)
balanced_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=class_labels,
    y=generator.classes,
)
class_weights = dict(zip(class_labels.tolist(), balanced_weights.tolist()))

# Early stopping: give up after 3 epochs without val_loss improvement and ask
# Keras to restore the best weights seen.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=1,
    restore_best_weights=True,
)

# Device setup: MirroredStrategy replicates the model across the devices
# available on this machine.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Model variables and the optimizer must be created inside the strategy scope.
with strategy.scope():
    # Small CNN: three conv/pool stages followed by a dense classifier head
    # with one softmax output per class found in the dataset folder.
    model = tf.keras.Sequential([
        Conv2D(16, (3, 3), activation='relu', input_shape=(180, 180, 3)),
        MaxPooling2D((2, 2)),
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(len(generator.class_indices), activation='softmax'),
    ])

    # Compile the model.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

# Optional (currently disabled): checkpoint the best model during training.
# checkpoint_path = "./model/animal_model.h5"
# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss',
#                                    save_best_only=True, save_weights_only=False, verbose=1)

# Train the model, applying the balanced class weights computed above.
model.fit(
    generator,
    epochs=20,
    validation_data=validation_generator,
    class_weight=class_weights,
    callbacks=[early_stopping],
)

# Save the model as an .h5 file; create the output directory first so the
# save does not fail on a fresh checkout.
model_path = "./model/animal_model.h5"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model.save(model_path, save_format='h5')