parent
0ad3a78d90
commit
5bb4574553
@ -0,0 +1,61 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model, Sequential
|
||||
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
def create_model():
|
||||
model = Sequential([
|
||||
Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)),
|
||||
MaxPooling2D((2, 2)),
|
||||
Conv2D(64, (3, 3), activation='relu'),
|
||||
MaxPooling2D((2, 2)),
|
||||
Flatten(),
|
||||
Dense(128, activation='relu'),
|
||||
Dropout(0.5),
|
||||
Dense(10, activation='softmax') # 假设有10个交通标志类别
|
||||
])
|
||||
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
|
||||
return model
|
||||
|
||||
def train_model(model, train_data_dir, validation_data_dir):
|
||||
train_datagen = ImageDataGenerator(rescale=1./255)
|
||||
val_datagen = ImageDataGenerator(rescale=1./255)
|
||||
|
||||
train_generator = train_datagen.flow_from_directory(
|
||||
train_data_dir,
|
||||
target_size=(640, 640),
|
||||
batch_size=32,
|
||||
class_mode='sparse'
|
||||
)
|
||||
|
||||
validation_generator = val_datagen.flow_from_directory(
|
||||
validation_data_dir,
|
||||
target_size=(640, 640),
|
||||
batch_size=32,
|
||||
class_mode='sparse'
|
||||
)
|
||||
|
||||
model.fit(train_generator, epochs=10, validation_data=validation_generator)
|
||||
|
||||
# 示例
|
||||
model = create_model()
|
||||
train_data_dir = 'path/to/train/data'
|
||||
validation_data_dir = 'path/to/validation/data'
|
||||
train_model(model, train_data_dir, validation_data_dir)
|
||||
|
||||
# 保存模型
|
||||
model.save('path/to/your/model.h5')
|
||||
|
||||
# 加载模型并预测
|
||||
model = load_model('path/to/your/model.h5')
|
||||
|
||||
def predict_traffic_sign(image):
|
||||
# 扩展维度以适配模型输入
|
||||
image = np.expand_dims(image, axis=0)
|
||||
# 模型预测
|
||||
predictions = model.predict(image)
|
||||
return predictions
|
||||
|
||||
# 示例
|
||||
predictions = predict_traffic_sign(preprocessed_image)
|
||||
predicted_class = np.argmax(predictions, axis=1)
|
Loading…
Reference in new issue