parent
							
								
									989dd3a43c
								
							
						
					
					
						commit
						ede0508b33
					
				| @ -0,0 +1,61 @@ | |||||||
|  | import tensorflow as tf | ||||||
|  | from tensorflow.keras.models import load_model, Sequential | ||||||
|  | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout | ||||||
|  | from tensorflow.keras.preprocessing.image import ImageDataGenerator | ||||||
|  | 
 | ||||||
|  | def create_model(): | ||||||
|  |     model = Sequential([ | ||||||
|  |         Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)), | ||||||
|  |         MaxPooling2D((2, 2)), | ||||||
|  |         Conv2D(64, (3, 3), activation='relu'), | ||||||
|  |         MaxPooling2D((2, 2)), | ||||||
|  |         Flatten(), | ||||||
|  |         Dense(128, activation='relu'), | ||||||
|  |         Dropout(0.5), | ||||||
|  |         Dense(10, activation='softmax')  # 假设有10个交通标志类别 | ||||||
|  |     ]) | ||||||
|  |     model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) | ||||||
|  |     return model | ||||||
|  | 
 | ||||||
|  | def train_model(model, train_data_dir, validation_data_dir): | ||||||
|  |     train_datagen = ImageDataGenerator(rescale=1./255) | ||||||
|  |     val_datagen = ImageDataGenerator(rescale=1./255) | ||||||
|  | 
 | ||||||
|  |     train_generator = train_datagen.flow_from_directory( | ||||||
|  |         train_data_dir, | ||||||
|  |         target_size=(640, 640), | ||||||
|  |         batch_size=32, | ||||||
|  |         class_mode='sparse' | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     validation_generator = val_datagen.flow_from_directory( | ||||||
|  |         validation_data_dir, | ||||||
|  |         target_size=(640, 640), | ||||||
|  |         batch_size=32, | ||||||
|  |         class_mode='sparse' | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     model.fit(train_generator, epochs=10, validation_data=validation_generator) | ||||||
|  | 
 | ||||||
|  | # 示例 | ||||||
|  | model = create_model() | ||||||
|  | train_data_dir = 'path/to/train/data' | ||||||
|  | validation_data_dir = 'path/to/validation/data' | ||||||
|  | train_model(model, train_data_dir, validation_data_dir) | ||||||
|  | 
 | ||||||
|  | # 保存模型 | ||||||
|  | model.save('path/to/your/model.h5') | ||||||
|  | 
 | ||||||
|  | # 加载模型并预测 | ||||||
|  | model = load_model('path/to/your/model.h5') | ||||||
|  | 
 | ||||||
|  | def predict_traffic_sign(image): | ||||||
|  |     # 扩展维度以适配模型输入 | ||||||
|  |     image = np.expand_dims(image, axis=0) | ||||||
|  |     # 模型预测 | ||||||
|  |     predictions = model.predict(image) | ||||||
|  |     return predictions | ||||||
|  | 
 | ||||||
|  | # 示例 | ||||||
|  | predictions = predict_traffic_sign(preprocessed_image) | ||||||
|  | predicted_class = np.argmax(predictions, axis=1) | ||||||
					Loading…
					
					
				
		Reference in new issue