From 5bb4574553faac59983c890b05b83a066bf7175e Mon Sep 17 00:00:00 2001 From: p6tnmhi3f <2598669852@qq.com> Date: Sun, 30 Jun 2024 22:09:59 +0800 Subject: [PATCH] =?UTF-8?q?Add=20lyh=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lyh部分 | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 lyh部分 diff --git a/lyh部分 b/lyh部分 new file mode 100644 index 0000000..307e909 --- /dev/null +++ b/lyh部分 @@ -0,0 +1,61 @@ +import tensorflow as tf +from tensorflow.keras.models import load_model, Sequential +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout +from tensorflow.keras.preprocessing.image import ImageDataGenerator + +def create_model(): + model = Sequential([ + Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)), + MaxPooling2D((2, 2)), + Conv2D(64, (3, 3), activation='relu'), + MaxPooling2D((2, 2)), + Flatten(), + Dense(128, activation='relu'), + Dropout(0.5), + Dense(10, activation='softmax') # 假设有10个交通标志类别 + ]) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + return model + +def train_model(model, train_data_dir, validation_data_dir): + train_datagen = ImageDataGenerator(rescale=1./255) + val_datagen = ImageDataGenerator(rescale=1./255) + + train_generator = train_datagen.flow_from_directory( + train_data_dir, + target_size=(640, 640), + batch_size=32, + class_mode='sparse' + ) + + validation_generator = val_datagen.flow_from_directory( + validation_data_dir, + target_size=(640, 640), + batch_size=32, + class_mode='sparse' + ) + + model.fit(train_generator, epochs=10, validation_data=validation_generator) + +# 示例 +model = create_model() +train_data_dir = 'path/to/train/data' +validation_data_dir = 'path/to/validation/data' +train_model(model, train_data_dir, validation_data_dir) + +# 保存模型 +model.save('path/to/your/model.h5') + +# 加载模型并预测 +model = load_model('path/to/your/model.h5') + +def predict_traffic_sign(image): + # 扩展维度以适配模型输入 + image = np.expand_dims(image, axis=0) + # 模型预测 + predictions = model.predict(image) + return predictions + +# 示例 +predictions = predict_traffic_sign(preprocessed_image) +predicted_class = np.argmax(predictions, axis=1) \ No newline at end of file