diff --git a/lyh部分 b/lyh部分 new file mode 100644 index 0000000..307e909 --- /dev/null +++ b/lyh部分 @@ -0,0 +1,61 @@ +import tensorflow as tf +from tensorflow.keras.models import load_model, Sequential +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout +from tensorflow.keras.preprocessing.image import ImageDataGenerator + +def create_model(): + model = Sequential([ + Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)), + MaxPooling2D((2, 2)), + Conv2D(64, (3, 3), activation='relu'), + MaxPooling2D((2, 2)), + Flatten(), + Dense(128, activation='relu'), + Dropout(0.5), + Dense(10, activation='softmax') # 假设有10个交通标志类别 + ]) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + return model + +def train_model(model, train_data_dir, validation_data_dir): + train_datagen = ImageDataGenerator(rescale=1./255) + val_datagen = ImageDataGenerator(rescale=1./255) + + train_generator = train_datagen.flow_from_directory( + train_data_dir, + target_size=(640, 640), + batch_size=32, + class_mode='sparse' + ) + + validation_generator = val_datagen.flow_from_directory( + validation_data_dir, + target_size=(640, 640), + batch_size=32, + class_mode='sparse' + ) + + model.fit(train_generator, epochs=10, validation_data=validation_generator) + +# 示例 +model = create_model() +train_data_dir = 'path/to/train/data' +validation_data_dir = 'path/to/validation/data' +train_model(model, train_data_dir, validation_data_dir) + +# 保存模型 +model.save('path/to/your/model.h5') + +# 加载模型并预测 +model = load_model('path/to/your/model.h5') + +def predict_traffic_sign(image): + # 扩展维度以适配模型输入 + image = np.expand_dims(image, axis=0) + # 模型预测 + predictions = model.predict(image) + return predictions + +# 示例 +predictions = predict_traffic_sign(preprocessed_image) +predicted_class = np.argmax(predictions, axis=1) \ No newline at end of file