Add lyh部分

main
p6tnmhi3f 1 year ago
parent 0ad3a78d90
commit 5bb4574553

@ -0,0 +1,61 @@
import tensorflow as tf
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def create_model():
model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(640, 640, 3)),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Flatten(),
Dense(128, activation='relu'),
Dropout(0.5),
Dense(10, activation='softmax') # 假设有10个交通标志类别
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
return model
def train_model(model, train_data_dir, validation_data_dir):
train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(640, 640),
batch_size=32,
class_mode='sparse'
)
validation_generator = val_datagen.flow_from_directory(
validation_data_dir,
target_size=(640, 640),
batch_size=32,
class_mode='sparse'
)
model.fit(train_generator, epochs=10, validation_data=validation_generator)
# 示例
model = create_model()
train_data_dir = 'path/to/train/data'
validation_data_dir = 'path/to/validation/data'
train_model(model, train_data_dir, validation_data_dir)
# 保存模型
model.save('path/to/your/model.h5')
# 加载模型并预测
model = load_model('path/to/your/model.h5')
def predict_traffic_sign(image):
# 扩展维度以适配模型输入
image = np.expand_dims(image, axis=0)
# 模型预测
predictions = model.predict(image)
return predictions
# 示例
predictions = predict_traffic_sign(preprocessed_image)
predicted_class = np.argmax(predictions, axis=1)
Loading…
Cancel
Save