parent
							
								
									ede0508b33
								
							
						
					
					
						commit
						fc1c7c2c9b
					
				| @ -1,61 +0,0 @@ | |||||||
| import tensorflow as tf |  | ||||||
| from tensorflow.keras.models import load_model, Sequential |  | ||||||
| from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout |  | ||||||
| from tensorflow.keras.preprocessing.image import ImageDataGenerator |  | ||||||
| 
 |  | ||||||
| def create_model(): |  | ||||||
|     model = Sequential([ |  | ||||||
|         Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)), |  | ||||||
|         MaxPooling2D((2, 2)), |  | ||||||
|         Conv2D(64, (3, 3), activation='relu'), |  | ||||||
|         MaxPooling2D((2, 2)), |  | ||||||
|         Flatten(), |  | ||||||
|         Dense(128, activation='relu'), |  | ||||||
|         Dropout(0.5), |  | ||||||
|         Dense(10, activation='softmax')  # 假设有10个交通标志类别 |  | ||||||
|     ]) |  | ||||||
|     model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) |  | ||||||
|     return model |  | ||||||
| 
 |  | ||||||
| def train_model(model, train_data_dir, validation_data_dir): |  | ||||||
|     train_datagen = ImageDataGenerator(rescale=1./255) |  | ||||||
|     val_datagen = ImageDataGenerator(rescale=1./255) |  | ||||||
| 
 |  | ||||||
|     train_generator = train_datagen.flow_from_directory( |  | ||||||
|         train_data_dir, |  | ||||||
|         target_size=(640, 640), |  | ||||||
|         batch_size=32, |  | ||||||
|         class_mode='sparse' |  | ||||||
|     ) |  | ||||||
| 
 |  | ||||||
|     validation_generator = val_datagen.flow_from_directory( |  | ||||||
|         validation_data_dir, |  | ||||||
|         target_size=(640, 640), |  | ||||||
|         batch_size=32, |  | ||||||
|         class_mode='sparse' |  | ||||||
|     ) |  | ||||||
| 
 |  | ||||||
|     model.fit(train_generator, epochs=10, validation_data=validation_generator) |  | ||||||
| 
 |  | ||||||
| # 示例 |  | ||||||
| model = create_model() |  | ||||||
| train_data_dir = 'path/to/train/data' |  | ||||||
| validation_data_dir = 'path/to/validation/data' |  | ||||||
| train_model(model, train_data_dir, validation_data_dir) |  | ||||||
| 
 |  | ||||||
| # 保存模型 |  | ||||||
| model.save('path/to/your/model.h5') |  | ||||||
| 
 |  | ||||||
| # 加载模型并预测 |  | ||||||
| model = load_model('path/to/your/model.h5') |  | ||||||
| 
 |  | ||||||
| def predict_traffic_sign(image): |  | ||||||
|     # 扩展维度以适配模型输入 |  | ||||||
|     image = np.expand_dims(image, axis=0) |  | ||||||
|     # 模型预测 |  | ||||||
|     predictions = model.predict(image) |  | ||||||
|     return predictions |  | ||||||
| 
 |  | ||||||
| # 示例 |  | ||||||
| predictions = predict_traffic_sign(preprocessed_image) |  | ||||||
| predicted_class = np.argmax(predictions, axis=1) |  | ||||||
					Loading…
					
					
				
		Reference in new issue