diff --git a/train_model.py b/train_model.py new file mode 100644 index 0000000..33b6bde --- /dev/null +++ b/train_model.py @@ -0,0 +1,55 @@ +import pickle +import numpy as np +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC +from sklearn.metrics import accuracy_score +import joblib + + +# 加载数据集(使用你之前合并的feature_dataset00.pkl) +def load_dataset(pkl_path): + with open(pkl_path, 'rb') as f: + data = pickle.load(f) + return data['matrix'], data['label'] + + +# 训练模型 +def train_and_save_model(dataset_path, model_save_path, scaler_save_path): + # 加载数据 + X, y = load_dataset(dataset_path) + print(f"加载数据集:{X.shape[0]}个样本,{X.shape[1]}维特征") + + # 划分训练集 + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y + ) + + # 标准化 + scaler = StandardScaler() + X_train_std = scaler.fit_transform(X_train) + X_test_std = scaler.transform(X_test) + + # 训练SVM + svm = SVC(kernel='rbf', class_weight='balanced', probability=True, random_state=42) + svm.fit(X_train_std, y_train) + + # 评估 + y_pred = svm.predict(X_test_std) + print(f"模型准确率:{accuracy_score(y_test, y_pred):.4f}") + + # 保存模型和标准化器 + joblib.dump(svm, model_save_path) + joblib.dump(scaler, scaler_save_path) + print(f"模型已保存至:{model_save_path}") + print(f"标准化器已保存至:{scaler_save_path}") + + +if __name__ == "__main__": + # 替换为你的数据集路径 + DATASET_PATH = r"D:\SummerSchool\mat_cv\mat_cv\feature_dataset.pkl" + # 模型保存路径(与GUI代码中设置的路径一致) + MODEL_PATH = "svm_model.pkl" + SCALER_PATH = "scaler.pkl" + + train_and_save_model(DATASET_PATH, MODEL_PATH, SCALER_PATH)