ADD file via upload

main
sztu202200202027 5 months ago
parent 92337dadcb
commit 985584d778

@ -0,0 +1,90 @@
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, RandomFlip, RandomRotation, RandomZoom, Attention
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from sklearn.utils import class_weight
# Data paths
data_dir = 'dataset/flowers'  # root directory of the dataset
batch_size = 32
# Image generator: loads images from the class folders and applies augmentation
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=10,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
    zoom_range=0.2,
    horizontal_flip=True,
)
generator = datagen.flow_from_directory(
    data_dir,
    target_size=(180, 180),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training'
)
validation_generator = datagen.flow_from_directory(
    data_dir,
    target_size=(180, 180),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'
)
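# Note: flow_from_directory infers class labels from the sub-folder names under
# data_dir, so the expected layout is one folder per flower class, e.g.
# (folder names here are only illustrative):
#   dataset/flowers/daisy/xxx.jpg
#   dataset/flowers/rose/yyy.jpg
# validation_split=0.2 reserves 20% of each class for the validation generator.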
# Compute balanced class weights so under-represented classes count more during training
class_weight_dict = dict(enumerate(class_weight.compute_class_weight(
    'balanced', classes=np.unique(generator.classes), y=generator.classes)))
# Early stopping: halt training if validation loss has not improved for 3 epochs
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
# # Data augmentation via Keras preprocessing layers (currently disabled)
# data_augmentation = tf.keras.Sequential(
# [
# RandomFlip("horizontal", input_shape=(180, 180, 3)),
# RandomRotation(0.1),
# RandomZoom(0.1),
# ]
# )
# Configure GPU acceleration (MirroredStrategy distributes training across available GPUs)
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
with strategy.scope():
    model = tf.keras.Sequential([
        # data_augmentation,
        Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(180, 180, 3)),
        MaxPooling2D(),
        Conv2D(64, (3, 3), padding='same', activation='relu'),
        MaxPooling2D(),
        Conv2D(128, (3, 3), padding='same', activation='relu'),
        MaxPooling2D(),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(len(generator.class_indices), activation='softmax')
    ])
    # Compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
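    # Optional addition: print the network architecture for a quick sanity check
    model.summary()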
# # Model checkpoint: save the best model seen so far (currently disabled)
# checkpoint_path = "./model/animal_model.h5"
# model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss',
#                                    save_best_only=True, save_weights_only=False, verbose=1)
# Train the model, applying the balanced class weights computed above
model.fit(generator,
          epochs=20,
          validation_data=validation_generator,
          class_weight=class_weight_dict,
          callbacks=[early_stopping])
# Save the model as an .h5 file (create the output directory first)
os.makedirs("./model", exist_ok=True)
model.save("./model/flower_model.h5", save_format='h5')
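# Optional sketch (not part of the original script): load the saved model and classify
# one image. The example path below is hypothetical; predicted class indices map back
# to folder names via generator.class_indices.
def predict_flower(image_path, model_path="./model/flower_model.h5"):
    """Return the predicted class index for a single image file."""
    loaded = tf.keras.models.load_model(model_path)
    img = Image.open(image_path).convert('RGB').resize((180, 180))
    # Apply the same 1/255 rescaling used by the training generator
    x = np.asarray(img, dtype=np.float32)[np.newaxis, ...] / 255.0
    probs = loaded.predict(x)
    return int(np.argmax(probs, axis=-1)[0])

# Example usage (hypothetical path):
# print(predict_flower("dataset/flowers/daisy/sample.jpg"))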