模型训练

main
lyd 2 months ago
parent 52dba77e65
commit d7f2df1b7d

@ -0,0 +1,55 @@
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import joblib
# 加载数据集使用你之前合并的feature_dataset00.pkl
def load_dataset(pkl_path):
with open(pkl_path, 'rb') as f:
data = pickle.load(f)
return data['matrix'], data['label']
# 训练模型
def train_and_save_model(dataset_path, model_save_path, scaler_save_path):
# 加载数据
X, y = load_dataset(dataset_path)
print(f"加载数据集:{X.shape[0]}个样本,{X.shape[1]}维特征")
# 划分训练集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# 标准化
scaler = StandardScaler()
X_train_std = scaler.fit_transform(X_train)
X_test_std = scaler.transform(X_test)
# 训练SVM
svm = SVC(kernel='rbf', class_weight='balanced', probability=True, random_state=42)
svm.fit(X_train_std, y_train)
# 评估
y_pred = svm.predict(X_test_std)
print(f"模型准确率:{accuracy_score(y_test, y_pred):.4f}")
# 保存模型和标准化器
joblib.dump(svm, model_save_path)
joblib.dump(scaler, scaler_save_path)
print(f"模型已保存至:{model_save_path}")
print(f"标准化器已保存至:{scaler_save_path}")
if __name__ == "__main__":
# 替换为你的数据集路径
DATASET_PATH = r"D:\SummerSchool\mat_cv\mat_cv\feature_dataset.pkl"
# 模型保存路径与GUI代码中设置的路径一致
MODEL_PATH = "svm_model.pkl"
SCALER_PATH = "scaler.pkl"
train_and_save_model(DATASET_PATH, MODEL_PATH, SCALER_PATH)
Loading…
Cancel
Save