You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
2.9 KiB
91 lines
2.9 KiB
5 months ago
|
import os
|
||
|
import numpy as np
|
||
|
from PIL import Image
|
||
|
import tensorflow as tf
|
||
|
from tensorflow.keras import layers
|
||
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, RandomFlip, RandomRotation, RandomZoom, Attention
|
||
|
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
|
||
|
from sklearn.utils import class_weight
|
||
|
|
||
|
|
||
|
# Dataset location and loader settings.
data_dir = 'dataset/flowers'  # dataset root directory
batch_size = 32
image_size = (180, 180)
validation_fraction = 0.2

# Training images: rescaled to [0, 1] and randomly augmented on every epoch.
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_fraction,
    rotation_range=10,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation images: rescaled only. Random augmentation here would make
# val_loss noisy, and val_loss is what early stopping monitors.
# The same validation_split is used so both generators carve up data_dir with
# the same fraction (NOTE(review): relies on the split being derived from file
# order rather than randomness -- confirm against the Keras version in use).
validation_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_fraction,
)

# Training subset (augmented).
generator = datagen.flow_from_directory(
    data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
)

# Validation subset (not augmented).
validation_generator = validation_datagen.flow_from_directory(
    data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
)

# Compute sample weights using the 'balanced' heuristic over the labels of
# the training subset.
# NOTE(review): despite its name this holds sample weights, and it is not
# referenced again in the visible part of this script.
class_weights = class_weight.compute_sample_weight(
    class_weight='balanced',
    y=generator.classes,
)

# Configure early stopping on validation loss; the best weights seen are
# restored when training stops.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=1,
    restore_best_weights=True,
)

# # Data augmentation (disabled)
# data_augmentation = tf.keras.Sequential(
#     [
#         RandomFlip("horizontal", input_shape=(180, 180, 3)),
#         RandomRotation(0.1),
#         RandomZoom(0.1),
#     ]
# )

# Configure GPU acceleration: create a MirroredStrategy and report how many
# replicas it will keep in sync.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    # Small CNN: three conv + max-pool stages, then a dense classifier head.
    model = tf.keras.Sequential()
    # data_augmentation,  (disabled)
    model.add(Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(180, 180, 3)))
    model.add(MaxPooling2D())
    model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D())
    model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D())
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    # One softmax output per class discovered by the training generator.
    model.add(Dense(len(generator.class_indices), activation='softmax'))

    # Compile the model.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

# # Configure a model checkpoint that saves the best model (disabled)
# checkpoint_path = "./model/animal_model.h5"
# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss',
#                                    save_best_only=True, save_weights_only=False, verbose=1)

# Per-class loss weights using the 'balanced' heuristic over the training
# labels. Generator-based training takes a class_weight mapping rather than a
# per-sample weight array.
num_classes = len(generator.class_indices)
present_classes = np.unique(generator.classes)
balanced_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=generator.classes,
)
# Start every class at a neutral weight so the mapping covers all class
# indices even if a class has no images in the training subset.
class_weight_by_index = {index: 1.0 for index in range(num_classes)}
class_weight_by_index.update(
    zip(present_classes.tolist(), balanced_weights.tolist())
)

# Train the model.
model.fit(
    generator,
    epochs=20,
    validation_data=validation_generator,
    class_weight=class_weight_by_index,
    callbacks=[early_stopping],
)

# Save the model as an .h5 file, creating the output directory first so the
# save does not fail on a fresh checkout.
model_path = "./model/flower_model.h5"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model.save(model_path, save_format='h5')