"""Train, save, reload and run a small CNN traffic-sign classifier."""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.preprocessing.image import (
    ImageDataGenerator,
    img_to_array,
    load_img,
)


def create_model(num_classes=10, input_shape=(640, 640, 3)):
    """Build and compile the CNN classifier.

    Args:
        num_classes: Number of units in the final softmax layer. Defaults to
            10 (the original code assumed 10 traffic-sign classes).
        input_shape: (height, width, channels) of one input image.

    Returns:
        A compiled ``Sequential`` model using the Adam optimizer,
        sparse categorical cross-entropy loss and an accuracy metric.
    """
    # NOTE(review): with the default 640x640 input, the two unpadded 3x3
    # convolutions and two 2x2 poolings leave a 158x158x64 feature map, so
    # Flatten feeds ~1.6M features into Dense(128) (~204M weights).
    # Consider a smaller input size or global pooling if memory is an issue.
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def train_model(model, train_data_dir, validation_data_dir, *,
                epochs=10, batch_size=32, target_size=(640, 640)):
    """Fit ``model`` on images read from class-per-subdirectory folders.

    Args:
        model: A compiled Keras model (e.g. from ``create_model``).
        train_data_dir: Directory of training images, passed to
            ``flow_from_directory``.
        validation_data_dir: Directory of validation images, passed to
            ``flow_from_directory``.
        epochs: Number of training epochs.
        batch_size: Batch size used by both generators.
        target_size: (height, width) every image is resized to.

    Returns:
        The object returned by ``model.fit`` (the training history).
    """
    # Pixels are scaled to [0, 1]; inference inputs must be scaled the same way.
    train_datagen = ImageDataGenerator(rescale=1./255)
    val_datagen = ImageDataGenerator(rescale=1./255)

    # class_mode='sparse' yields integer labels, matching the
    # sparse_categorical_crossentropy loss set in create_model.
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='sparse',
    )
    validation_generator = val_datagen.flow_from_directory(
        validation_data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='sparse',
    )

    return model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
    )


def load_and_preprocess_image(image_path, target_size=(640, 640)):
    """Load one image file and prepare it the same way as the training data.

    Args:
        image_path: Path of the image file to load.
        target_size: (height, width) the image is resized to.

    Returns:
        A float array of shape (height, width, 3) scaled by 1/255.
    """
    image = load_img(image_path, target_size=target_size)
    # Same 1/255 scaling that train_model applies through ImageDataGenerator.
    return img_to_array(image) / 255.0


def predict_traffic_sign(image):
    """Run the module-level ``model`` on a single image.

    Args:
        image: One preprocessed image without a batch dimension.

    Returns:
        The output of ``model.predict`` for a batch containing only ``image``.
    """
    # Add a leading batch axis so the input matches what the model expects.
    image = np.expand_dims(image, axis=0)
    # Model inference.
    predictions = model.predict(image)
    return predictions


# Example: build and train the model.
model = create_model()
train_data_dir = 'path/to/train/data'
validation_data_dir = 'path/to/validation/data'
train_model(model, train_data_dir, validation_data_dir)

# Save the model.
model.save('path/to/your/model.h5')

# Load the model back for prediction.
model = load_model('path/to/your/model.h5')

# Example: classify a single image (replace the placeholder path).
image_path = 'path/to/your/image.jpg'
preprocessed_image = load_and_preprocess_image(image_path)
predictions = predict_traffic_sign(preprocessed_image)
predicted_class = np.argmax(predictions, axis=1)