parent
0ad3a78d90
commit
5bb4574553
@ -0,0 +1,61 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras.models import load_model, Sequential
|
||||||
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
|
||||||
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
def create_model():
|
||||||
|
model = Sequential([
|
||||||
|
Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)),
|
||||||
|
MaxPooling2D((2, 2)),
|
||||||
|
Conv2D(64, (3, 3), activation='relu'),
|
||||||
|
MaxPooling2D((2, 2)),
|
||||||
|
Flatten(),
|
||||||
|
Dense(128, activation='relu'),
|
||||||
|
Dropout(0.5),
|
||||||
|
Dense(10, activation='softmax') # 假设有10个交通标志类别
|
||||||
|
])
|
||||||
|
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
|
||||||
|
return model
|
||||||
|
|
||||||
|
def train_model(model, train_data_dir, validation_data_dir):
|
||||||
|
train_datagen = ImageDataGenerator(rescale=1./255)
|
||||||
|
val_datagen = ImageDataGenerator(rescale=1./255)
|
||||||
|
|
||||||
|
train_generator = train_datagen.flow_from_directory(
|
||||||
|
train_data_dir,
|
||||||
|
target_size=(640, 640),
|
||||||
|
batch_size=32,
|
||||||
|
class_mode='sparse'
|
||||||
|
)
|
||||||
|
|
||||||
|
validation_generator = val_datagen.flow_from_directory(
|
||||||
|
validation_data_dir,
|
||||||
|
target_size=(640, 640),
|
||||||
|
batch_size=32,
|
||||||
|
class_mode='sparse'
|
||||||
|
)
|
||||||
|
|
||||||
|
model.fit(train_generator, epochs=10, validation_data=validation_generator)
|
||||||
|
|
||||||
|
# 示例
|
||||||
|
model = create_model()
|
||||||
|
train_data_dir = 'path/to/train/data'
|
||||||
|
validation_data_dir = 'path/to/validation/data'
|
||||||
|
train_model(model, train_data_dir, validation_data_dir)
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
model.save('path/to/your/model.h5')
|
||||||
|
|
||||||
|
# 加载模型并预测
|
||||||
|
model = load_model('path/to/your/model.h5')
|
||||||
|
|
||||||
|
def predict_traffic_sign(image):
|
||||||
|
# 扩展维度以适配模型输入
|
||||||
|
image = np.expand_dims(image, axis=0)
|
||||||
|
# 模型预测
|
||||||
|
predictions = model.predict(image)
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
# 示例
|
||||||
|
predictions = predict_traffic_sign(preprocessed_image)
|
||||||
|
predicted_class = np.argmax(predictions, axis=1)
|
Loading…
Reference in new issue