You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
1.9 KiB
61 lines
1.9 KiB
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator


def create_model(input_shape=(640, 640, 3), num_classes=10):
    """Build and compile a small CNN classifier for traffic-sign images.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
            Defaults to 640x640 RGB, the size used throughout this script.
        num_classes: Number of traffic-sign classes, i.e. the width of the
            final softmax layer. Defaults to 10 (the originally assumed count).

    Returns:
        A compiled ``Sequential`` model with a softmax output of
        ``num_classes`` units, compiled for integer (sparse) labels.
    """
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Flatten(),
        # NOTE(review): at 640x640 the flattened feature map is
        # 158 * 158 * 64 ~= 1.6M values, so this Dense layer alone holds
        # ~204M weights. Consider GlobalAveragePooling2D or smaller inputs.
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax'),
    ])

    # Sparse categorical cross-entropy takes integer class indices as labels,
    # matching flow_from_directory(class_mode='sparse') used for training.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def train_model(model, train_data_dir, validation_data_dir, *,
                target_size=None, batch_size=32, epochs=10):
    """Train ``model`` on images stored one-subdirectory-per-class.

    Both directories are read with ``ImageDataGenerator.flow_from_directory``.
    Pixel values are rescaled by 1/255, and labels are integer class indices.

    Args:
        model: A compiled Keras model whose loss accepts integer (sparse)
            labels, such as the one returned by ``create_model``.
        train_data_dir: Root directory of the training images.
        validation_data_dir: Root directory of the validation images.
        target_size: ``(height, width)`` to resize images to. Defaults to the
            spatial size of ``model.input_shape`` so the data always matches
            the network input.
        batch_size: Images per batch for both generators.
        epochs: Number of training epochs.

    Returns:
        The ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: If the training directory contains more classes than the
            model has outputs (those labels could not be represented).
    """
    if target_size is None:
        # Assumes a channels-last input (batch, height, width, channels),
        # as built by create_model.
        target_size = tuple(model.input_shape[1:3])

    train_datagen = ImageDataGenerator(rescale=1./255)
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='sparse',
    )

    num_outputs = model.output_shape[-1]
    if train_generator.num_classes > num_outputs:
        raise ValueError(
            f"Training data in {train_data_dir!r} has "
            f"{train_generator.num_classes} classes, but the model only has "
            f"{num_outputs} outputs; rebuild the model with a matching "
            f"output layer."
        )

    # Reuse the training folder->index mapping so that label i refers to the
    # same class in both sets, even if the validation directory is missing a
    # class folder or contains an extra one.
    class_indices = train_generator.class_indices
    train_classes = sorted(class_indices, key=class_indices.get)

    validation_generator = val_datagen.flow_from_directory(
        validation_data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='sparse',
        classes=train_classes,
    )

    return model.fit(train_generator, epochs=epochs,
                     validation_data=validation_generator)


# Example: build the model, train it on the placeholder data directories,
# then save it to disk.
train_data_dir = 'path/to/train/data'
validation_data_dir = 'path/to/validation/data'
model = create_model()
train_model(model,
            train_data_dir=train_data_dir,
            validation_data_dir=validation_data_dir)
model.save('path/to/your/model.h5')

# Reload the saved model from disk; predict_traffic_sign reads this
# module-level ``model``.
model = load_model('path/to/your/model.h5')


def predict_traffic_sign(image, classifier=None):
    """Run the classifier on one image or a batch of images.

    Args:
        image: Either a single image of shape ``(height, width, channels)``
            or a batch of shape ``(n, height, width, channels)``. It should be
            preprocessed like the training data, which ``train_model``
            rescales by 1/255.
        classifier: Model to predict with. Defaults to the module-level
            ``model`` loaded from disk.

    Returns:
        The array returned by ``classifier.predict``, with one row per image
        (a single image yields a batch of one).
    """
    if classifier is None:
        classifier = model

    batch = np.asarray(image)
    # The network expects a leading batch axis; add it only for a single
    # image so an already-batched input is not turned into a 5-D tensor.
    if batch.ndim == 3:
        batch = np.expand_dims(batch, axis=0)

    return classifier.predict(batch)


# Example: classify a single image file with the model loaded above.
# The image is resized to the model's input size and rescaled by 1/255,
# mirroring the preprocessing applied to the training data in train_model.
raw_image = tf.keras.preprocessing.image.load_img(
    'path/to/your/image.jpg', target_size=tuple(model.input_shape[1:3]))
preprocessed_image = tf.keras.preprocessing.image.img_to_array(raw_image) / 255.0

predictions = predict_traffic_sign(preprocessed_image)
# Index of the highest-scoring class for each image in the batch.
predicted_class = np.argmax(predictions, axis=1)