You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
git-02/01src/train_cross_validated_model.py

91 lines
2.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
基于交叉验证代码的训练脚本
"""
import pickle
import numpy as np
from pathlib import Path
from sklearn.svm import SVC
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from scipy.stats import randint, loguniform
import joblib
def load_dataset(pkl_path):
    """Load a pickled dataset and return ``(matrix, label)``.

    The pickle must hold a mapping with ``'matrix'`` and ``'label'`` entries;
    a missing entry raises ``KeyError``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so only point this at files you trust.
    """
    with open(pkl_path, mode='rb') as handle:
        payload = pickle.load(handle)
    features = payload['matrix']
    labels = payload['label']
    return features, labels
def _to_signed_labels(y):
    """Map class labels to the -1/+1 convention used by the cross-validation code.

    ``0`` and an already-signed ``-1`` become ``-1``; every other value becomes
    ``+1``. (The previous ``y == 0`` test turned -1/+1 labels into all +1.)

    Raises:
        ValueError: if the mapped labels contain only one class, which the
            SVM and the ROC-AUC scorer cannot handle.
    """
    y = np.asarray(y)
    signed = np.where((y == 0) | (y == -1), -1, 1)
    if np.unique(signed).size < 2:
        raise ValueError(
            "training labels must contain both a negative (0/-1) and a positive class"
        )
    return signed


def _build_search(n_features, n_iter, random_state):
    """Build the randomized hyper-parameter search over SelectKBest + RBF SVC."""
    from functools import partial

    # Seed mutual-information estimation and SVC probability calibration as well,
    # so that a fixed random_state actually gives reproducible results.
    score_func = partial(mutual_info_classif, random_state=random_state)
    pipe = Pipeline([
        ('sel', SelectKBest(score_func)),
        ('svm', SVC(kernel='rbf', class_weight='balanced', probability=True,
                    random_state=random_state)),
    ])
    # Parameter distributions (kept consistent with the cross-validation code).
    param_dist = {
        'sel__k': randint(1, n_features + 1),  # 1..n_features inclusive
        'svm__C': loguniform(1e-3, 1e3),
        'svm__gamma': loguniform(1e-6, 1e1),
    }
    cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
    return RandomizedSearchCV(
        pipe,
        param_distributions=param_dist,
        n_iter=n_iter,
        scoring='roc_auc',
        cv=cv_inner,
        n_jobs=-1,
        random_state=random_state,
        verbose=1,
    )


def train_cross_validated_model(dataset_path, pipeline_save_path, scaler_save_path,
                                *, n_iter=30, random_state=42):
    """Tune and train the SelectKBest + RBF-SVC pipeline, then save it with a scaler.

    Args:
        dataset_path: pickle file readable by ``load_dataset``.
        pipeline_save_path: where the tuned pipeline is written (joblib).
        scaler_save_path: where a ``StandardScaler`` fitted on the training
            features is written (joblib).
        n_iter: number of random-search candidates (default 30, kept small
            for training speed).
        random_state: seed for the CV split, the search, mutual-information
            scoring and SVC probability calibration.

    Returns:
        ``(final_model, scaler)``: the refit best pipeline and the fitted scaler.

    Raises:
        ValueError: if the labels do not contain two classes.
    """
    X_train, y_train = load_dataset(dataset_path)
    print(f"加载数据集:{X_train.shape[0]}个样本,{X_train.shape[1]}维特征")
    # Convert labels to -1/+1 to match the cross-validation code.
    y_train_signed = _to_signed_labels(y_train)

    search = _build_search(X_train.shape[1], n_iter=n_iter, random_state=random_state)
    print("开始随机搜索优化...")
    search.fit(X_train, y_train_signed)
    best_params = search.best_params_
    print(f"\n最佳参数: {best_params}")
    print(f"最佳交叉验证AUC: {search.best_score_:.4f}")

    # With refit=True (the default), best_estimator_ has already been refit on
    # the whole training set, so a second fit would only repeat that work.
    final_model = search.best_estimator_

    # Scaler saved separately for feature standardization in the GUI.
    # NOTE(review): the pipeline above has no scaling step and was trained on
    # raw features. If the GUI applies this scaler before calling the pipeline,
    # its inputs will not match training. Confirm the GUI flow, then either add
    # the scaler to the pipeline or train on scaler-transformed data.
    scaler = StandardScaler().fit(X_train)

    joblib.dump(final_model, pipeline_save_path)
    joblib.dump(scaler, scaler_save_path)
    print(f"流水线模型已保存至: {pipeline_save_path}")
    print(f"标准化器已保存至: {scaler_save_path}")
    return final_model, scaler
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded setup. Override them from the
    # command line instead of editing this file.
    DATASET_PATH = r"D:\Python\空心检测\pythonProject\feature_dataset.pkl"
    PIPELINE_PATH = "pipeline_model.pkl"  # must match the path configured in the GUI
    SCALER_PATH = "scaler.pkl"

    parser = argparse.ArgumentParser(
        description="Train the SelectKBest + RBF-SVC pipeline and save it with a StandardScaler."
    )
    parser.add_argument("--dataset", default=DATASET_PATH,
                        help="pickled training set with 'matrix' and 'label' entries")
    parser.add_argument("--pipeline", default=PIPELINE_PATH,
                        help="output path for the trained pipeline")
    parser.add_argument("--scaler", default=SCALER_PATH,
                        help="output path for the fitted StandardScaler")
    args = parser.parse_args()

    train_cross_validated_model(args.dataset, args.pipeline, args.scaler)