#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Training script based on the cross-validation code.
"""
import pickle
import numpy as np
from sklearn.svm import SVC
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from scipy.stats import randint, loguniform
import joblib


def load_dataset(pkl_path):
    """Load the feature matrix and labels from a pickled dataset."""
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    return data['matrix'], data['label']


def train_cross_validated_model(dataset_path, pipeline_save_path, scaler_save_path):
    """Run a randomized hyperparameter search over the feature-selection + SVM
    pipeline, then save the fitted pipeline and a standalone scaler."""
    # Load the data
    X_train, y_train = load_dataset(dataset_path)
    print(f"Loaded dataset: {X_train.shape[0]} samples, {X_train.shape[1]} features")

    # Convert labels to the -1/1 format (consistent with the cross-validation code)
    y_train_signed = np.where(y_train == 0, -1, 1)

    # Build the pipeline (feature selection followed by an SVM)
    pipe = Pipeline([
        ('sel', SelectKBest(mutual_info_classif)),
        ('svm', SVC(kernel='rbf', class_weight='balanced', probability=True))
    ])

    # Parameter distributions (consistent with the cross-validation code)
    n_features = X_train.shape[1]
    param_dist = {
        'sel__k': randint(1, n_features + 1),
        'svm__C': loguniform(1e-3, 1e3),
        'svm__gamma': loguniform(1e-6, 1e1)
    }

    # Randomized search
    cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    search = RandomizedSearchCV(
        pipe,
        param_distributions=param_dist,
        n_iter=30,  # fewer iterations to speed up training
        scoring='roc_auc',
        cv=cv_inner,
        n_jobs=-1,
        random_state=42,
        verbose=1
    )

    print("Starting randomized search optimization...")
    search.fit(X_train, y_train_signed)

    best_params = search.best_params_
    print(f"\nBest parameters: {best_params}")
    print(f"Best cross-validation AUC: {search.best_score_:.4f}")

    # Train the final model on the full training set
    final_model = search.best_estimator_
    final_model.fit(X_train, y_train_signed)

    # Fit and save a separate scaler (used for feature standardization in the GUI)
    scaler = StandardScaler().fit(X_train)

    # Save the artifacts
    joblib.dump(final_model, pipeline_save_path)
    joblib.dump(scaler, scaler_save_path)

    print(f"Pipeline model saved to: {pipeline_save_path}")
    print(f"Scaler saved to: {scaler_save_path}")

    return final_model, scaler


if __name__ == "__main__":
    # Path to your training set
    DATASET_PATH = r"D:\Python\空心检测\pythonProject\feature_dataset.pkl"
    PIPELINE_PATH = "pipeline_model.pkl"  # must match the path configured in the GUI
    SCALER_PATH = "scaler.pkl"

    train_cross_validated_model(DATASET_PATH, PIPELINE_PATH, SCALER_PATH)
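

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original training script): one way a downstream
# consumer, such as the GUI referenced in the comments above, might load the
# saved pipeline and score new samples. The function name and the raw-feature
# input are illustrative assumptions; the pipeline is fitted on the raw feature
# matrix above, so it is applied to raw features here, while scaler.pkl remains
# available separately for any GUI-side standardization.
def score_new_samples(feature_matrix, pipeline_path="pipeline_model.pkl"):
    """Load the saved pipeline and return -1/1 predictions plus the
    probability of the positive (+1) class for a 2-D feature matrix."""
    pipeline = joblib.load(pipeline_path)
    labels = pipeline.predict(feature_matrix)             # predictions in -1/1 format
    proba = pipeline.predict_proba(feature_matrix)[:, 1]  # P(class == +1)
    return labels, proba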