You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
git-02/01src/train_model.py

56 lines
1.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import pickle
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import joblib
# Load the dataset (use the feature_dataset00.pkl file you merged earlier)
def load_dataset(pkl_path):
    """Load the feature matrix and label vector from a pickled dataset file.

    The file is expected to unpickle to a mapping that has a ``'matrix'``
    entry and a ``'label'`` entry.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        A ``(matrix, label)`` tuple taken from the loaded mapping.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the unpickled object does not support key lookup.
        KeyError: If the ``'matrix'`` or ``'label'`` entry is missing.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load dataset files that come from a trusted source.
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    try:
        return data['matrix'], data['label']
    except KeyError as err:
        # Same exception type as before, but say which file and which schema.
        raise KeyError(
            f"{pkl_path}: expected entries 'matrix' and 'label', missing {err}"
        ) from err
    except TypeError as err:
        raise TypeError(
            f"{pkl_path}: expected a mapping with 'matrix' and 'label', "
            f"got {type(data).__name__}"
        ) from err
# Train the model
def train_and_save_model(dataset_path, model_save_path, scaler_save_path,
                         *, test_size=0.2, random_state=42):
    """Train an RBF-kernel SVM on a pickled dataset and save it with its scaler.

    The dataset is split into stratified train/test parts, features are
    standardized, an ``SVC`` is fitted on the training part, its accuracy on
    the held-out part is printed, and both the classifier and the fitted
    ``StandardScaler`` are written with ``joblib.dump``.

    Args:
        dataset_path: Path of the pickle file passed to ``load_dataset``.
        model_save_path: Destination file for the fitted ``SVC``.
        scaler_save_path: Destination file for the fitted ``StandardScaler``.
        test_size: Fraction of samples held out for evaluation (keyword-only).
        random_state: Seed used for both the split and the SVC (keyword-only).

    Returns:
        The accuracy of the fitted model on the held-out test split.
    """
    # Load the data.
    X, y = load_dataset(dataset_path)
    print(f"加载数据集:{X.shape[0]}个样本,{X.shape[1]}维特征")

    # Split into train/test sets; stratify=y keeps the class proportions
    # the same in both parts.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )

    # Standardize. The scaler is fitted on the training split only and then
    # applied unchanged to the test split.
    scaler = StandardScaler()
    X_train_std = scaler.fit_transform(X_train)
    X_test_std = scaler.transform(X_test)

    # Train the SVM.
    # NOTE(review): probability=True is presumably needed by the consumer of
    # the saved model (the GUI) for class-probability output -- confirm.
    svm = SVC(
        kernel='rbf',
        class_weight='balanced',
        probability=True,
        random_state=random_state,
    )
    svm.fit(X_train_std, y_train)

    # Evaluate on the held-out split.
    y_pred = svm.predict(X_test_std)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"模型准确率:{accuracy:.4f}")

    # Save the model and the scaler; inference must apply the same scaler
    # before calling the model.
    joblib.dump(svm, model_save_path)
    joblib.dump(scaler, scaler_save_path)
    print(f"模型已保存至:{model_save_path}")
    print(f"标准化器已保存至:{scaler_save_path}")

    return accuracy
if __name__ == "__main__":
    # Replace with the path to your own dataset.
    DATASET_PATH = r"D:\SummerSchool\mat_cv\mat_cv\feature_dataset.pkl"
    # Output locations; these must match the paths configured in the GUI code.
    MODEL_PATH = "svm_model.pkl"
    SCALER_PATH = "scaler.pkl"
    train_and_save_model(
        dataset_path=DATASET_PATH,
        model_save_path=MODEL_PATH,
        scaler_save_path=SCALER_PATH,
    )