parent
92337dadcb
commit
985584d778
@ -0,0 +1,90 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, RandomFlip, RandomRotation, RandomZoom, Attention
|
||||
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
|
||||
from sklearn.utils import class_weight
|
||||
|
||||
|
||||
# 数据路径
|
||||
data_dir = 'dataset/flowers' # 数据集根目录
|
||||
batch_size = 32
|
||||
|
||||
# 图片生成器,用于从文件夹加载图片数据
|
||||
datagen = ImageDataGenerator(
|
||||
rescale=1./255,
|
||||
validation_split=0.2,
|
||||
rotation_range=10,
|
||||
width_shift_range=0.2,
|
||||
height_shift_range=0.2,
|
||||
shear_range=0.15,
|
||||
zoom_range=0.2,
|
||||
horizontal_flip=True,
|
||||
)
|
||||
|
||||
generator = datagen.flow_from_directory(
|
||||
data_dir,
|
||||
target_size=(180, 180),
|
||||
batch_size=batch_size,
|
||||
class_mode='categorical',
|
||||
subset='training'
|
||||
)
|
||||
|
||||
validation_generator = datagen.flow_from_directory(
|
||||
data_dir,
|
||||
target_size=(180, 180),
|
||||
batch_size=batch_size,
|
||||
class_mode='categorical',
|
||||
subset='validation'
|
||||
)
|
||||
|
||||
# 计算样本权重
|
||||
class_weights = class_weight.compute_sample_weight('balanced', generator.classes)
|
||||
|
||||
# 配置早停
|
||||
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
|
||||
|
||||
# # 数据增强
|
||||
# data_augmentation = tf.keras.Sequential(
|
||||
# [
|
||||
# RandomFlip("horizontal", input_shape=(180, 180, 3)),
|
||||
# RandomRotation(0.1),
|
||||
# RandomZoom(0.1),
|
||||
# ]
|
||||
# )
|
||||
|
||||
# 配置GPU加速
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
|
||||
|
||||
with strategy.scope():
|
||||
model = tf.keras.Sequential([
|
||||
# data_augmentation,
|
||||
Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(180, 180, 3)),
|
||||
MaxPooling2D(),
|
||||
Conv2D(64, (3, 3), padding='same', activation='relu'),
|
||||
MaxPooling2D(),
|
||||
Conv2D(128, (3, 3), padding='same', activation='relu'),
|
||||
MaxPooling2D(),
|
||||
Flatten(),
|
||||
Dense(128, activation='relu'),
|
||||
Dense(len(generator.class_indices), activation='softmax')
|
||||
])
|
||||
|
||||
# 编译模型
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
|
||||
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
|
||||
|
||||
# # 配置模型检查点,保存最优模型
|
||||
# checkpoint_path = "./model/animal_model.h5"
|
||||
# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss',
|
||||
# save_best_only=True, save_weights_only=False, verbose=1)
|
||||
|
||||
# 训练模型
|
||||
model.fit(generator, epochs=20, validation_data=validation_generator, callbacks=[early_stopping])
|
||||
|
||||
# 保存模型为 .h5 文件
|
||||
model.save("./model/flower_model.h5", save_format='h5')
|
Loading…
Reference in new issue