parent
5bb4574553
commit
cfd0ee0945
@ -1,61 +0,0 @@
|
|||||||
import tensorflow as tf
|
|
||||||
from tensorflow.keras.models import load_model, Sequential
|
|
||||||
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
|
|
||||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
||||||
|
|
||||||
def create_model():
|
|
||||||
model = Sequential([
|
|
||||||
Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)),
|
|
||||||
MaxPooling2D((2, 2)),
|
|
||||||
Conv2D(64, (3, 3), activation='relu'),
|
|
||||||
MaxPooling2D((2, 2)),
|
|
||||||
Flatten(),
|
|
||||||
Dense(128, activation='relu'),
|
|
||||||
Dropout(0.5),
|
|
||||||
Dense(10, activation='softmax') # 假设有10个交通标志类别
|
|
||||||
])
|
|
||||||
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
|
|
||||||
return model
|
|
||||||
|
|
||||||
def train_model(model, train_data_dir, validation_data_dir):
|
|
||||||
train_datagen = ImageDataGenerator(rescale=1./255)
|
|
||||||
val_datagen = ImageDataGenerator(rescale=1./255)
|
|
||||||
|
|
||||||
train_generator = train_datagen.flow_from_directory(
|
|
||||||
train_data_dir,
|
|
||||||
target_size=(640, 640),
|
|
||||||
batch_size=32,
|
|
||||||
class_mode='sparse'
|
|
||||||
)
|
|
||||||
|
|
||||||
validation_generator = val_datagen.flow_from_directory(
|
|
||||||
validation_data_dir,
|
|
||||||
target_size=(640, 640),
|
|
||||||
batch_size=32,
|
|
||||||
class_mode='sparse'
|
|
||||||
)
|
|
||||||
|
|
||||||
model.fit(train_generator, epochs=10, validation_data=validation_generator)
|
|
||||||
|
|
||||||
# 示例
|
|
||||||
model = create_model()
|
|
||||||
train_data_dir = 'path/to/train/data'
|
|
||||||
validation_data_dir = 'path/to/validation/data'
|
|
||||||
train_model(model, train_data_dir, validation_data_dir)
|
|
||||||
|
|
||||||
# 保存模型
|
|
||||||
model.save('path/to/your/model.h5')
|
|
||||||
|
|
||||||
# 加载模型并预测
|
|
||||||
model = load_model('path/to/your/model.h5')
|
|
||||||
|
|
||||||
def predict_traffic_sign(image):
|
|
||||||
# 扩展维度以适配模型输入
|
|
||||||
image = np.expand_dims(image, axis=0)
|
|
||||||
# 模型预测
|
|
||||||
predictions = model.predict(image)
|
|
||||||
return predictions
|
|
||||||
|
|
||||||
# 示例
|
|
||||||
predictions = predict_traffic_sign(preprocessed_image)
|
|
||||||
predicted_class = np.argmax(predictions, axis=1)
|
|
Loading…
Reference in new issue