parent
92337dadcb
commit
985584d778
@ -0,0 +1,90 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras import layers
|
||||||
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, RandomFlip, RandomRotation, RandomZoom, Attention
|
||||||
|
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
|
||||||
|
from sklearn.utils import class_weight
|
||||||
|
|
||||||
|
|
||||||
|
# 数据路径
|
||||||
|
data_dir = 'dataset/flowers' # 数据集根目录
|
||||||
|
batch_size = 32
|
||||||
|
|
||||||
|
# 图片生成器,用于从文件夹加载图片数据
|
||||||
|
datagen = ImageDataGenerator(
|
||||||
|
rescale=1./255,
|
||||||
|
validation_split=0.2,
|
||||||
|
rotation_range=10,
|
||||||
|
width_shift_range=0.2,
|
||||||
|
height_shift_range=0.2,
|
||||||
|
shear_range=0.15,
|
||||||
|
zoom_range=0.2,
|
||||||
|
horizontal_flip=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = datagen.flow_from_directory(
|
||||||
|
data_dir,
|
||||||
|
target_size=(180, 180),
|
||||||
|
batch_size=batch_size,
|
||||||
|
class_mode='categorical',
|
||||||
|
subset='training'
|
||||||
|
)
|
||||||
|
|
||||||
|
validation_generator = datagen.flow_from_directory(
|
||||||
|
data_dir,
|
||||||
|
target_size=(180, 180),
|
||||||
|
batch_size=batch_size,
|
||||||
|
class_mode='categorical',
|
||||||
|
subset='validation'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算样本权重
|
||||||
|
class_weights = class_weight.compute_sample_weight('balanced', generator.classes)
|
||||||
|
|
||||||
|
# 配置早停
|
||||||
|
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
|
||||||
|
|
||||||
|
# # 数据增强
|
||||||
|
# data_augmentation = tf.keras.Sequential(
|
||||||
|
# [
|
||||||
|
# RandomFlip("horizontal", input_shape=(180, 180, 3)),
|
||||||
|
# RandomRotation(0.1),
|
||||||
|
# RandomZoom(0.1),
|
||||||
|
# ]
|
||||||
|
# )
|
||||||
|
|
||||||
|
# 配置GPU加速
|
||||||
|
strategy = tf.distribute.MirroredStrategy()
|
||||||
|
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
|
||||||
|
|
||||||
|
with strategy.scope():
|
||||||
|
model = tf.keras.Sequential([
|
||||||
|
# data_augmentation,
|
||||||
|
Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(180, 180, 3)),
|
||||||
|
MaxPooling2D(),
|
||||||
|
Conv2D(64, (3, 3), padding='same', activation='relu'),
|
||||||
|
MaxPooling2D(),
|
||||||
|
Conv2D(128, (3, 3), padding='same', activation='relu'),
|
||||||
|
MaxPooling2D(),
|
||||||
|
Flatten(),
|
||||||
|
Dense(128, activation='relu'),
|
||||||
|
Dense(len(generator.class_indices), activation='softmax')
|
||||||
|
])
|
||||||
|
|
||||||
|
# 编译模型
|
||||||
|
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
|
||||||
|
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
|
||||||
|
|
||||||
|
# # 配置模型检查点,保存最优模型
|
||||||
|
# checkpoint_path = "./model/animal_model.h5"
|
||||||
|
# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss',
|
||||||
|
# save_best_only=True, save_weights_only=False, verbose=1)
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
model.fit(generator, epochs=20, validation_data=validation_generator, callbacks=[early_stopping])
|
||||||
|
|
||||||
|
# 保存模型为 .h5 文件
|
||||||
|
model.save("./model/flower_model.h5", save_format='h5')
|
Loading…
Reference in new issue