diff --git a/main b/main deleted file mode 100644 index 307e909..0000000 --- a/main +++ /dev/null @@ -1,61 +0,0 @@ -import tensorflow as tf -from tensorflow.keras.models import load_model, Sequential -from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout -from tensorflow.keras.preprocessing.image import ImageDataGenerator - -def create_model(): - model = Sequential([ - Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)), - MaxPooling2D((2, 2)), - Conv2D(64, (3, 3), activation='relu'), - MaxPooling2D((2, 2)), - Flatten(), - Dense(128, activation='relu'), - Dropout(0.5), - Dense(10, activation='softmax') # 假设有10个交通标志类别 - ]) - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) - return model - -def train_model(model, train_data_dir, validation_data_dir): - train_datagen = ImageDataGenerator(rescale=1./255) - val_datagen = ImageDataGenerator(rescale=1./255) - - train_generator = train_datagen.flow_from_directory( - train_data_dir, - target_size=(640, 640), - batch_size=32, - class_mode='sparse' - ) - - validation_generator = val_datagen.flow_from_directory( - validation_data_dir, - target_size=(640, 640), - batch_size=32, - class_mode='sparse' - ) - - model.fit(train_generator, epochs=10, validation_data=validation_generator) - -# 示例 -model = create_model() -train_data_dir = 'path/to/train/data' -validation_data_dir = 'path/to/validation/data' -train_model(model, train_data_dir, validation_data_dir) - -# 保存模型 -model.save('path/to/your/model.h5') - -# 加载模型并预测 -model = load_model('path/to/your/model.h5') - -def predict_traffic_sign(image): - # 扩展维度以适配模型输入 - image = np.expand_dims(image, axis=0) - # 模型预测 - predictions = model.predict(image) - return predictions - -# 示例 -predictions = predict_traffic_sign(preprocessed_image) -predicted_class = np.argmax(predictions, axis=1) \ No newline at end of file