|
|
|
|
@ -0,0 +1,166 @@
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
"""
|
|
|
|
|
交叉验证(最小改动版)——与原结构一致,仅在第6段训练后新增模型与scaler导出
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
import pickle
|
|
|
|
|
import numpy as np
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
# 新增/补全的 import
|
|
|
|
|
import joblib
|
|
|
|
|
from scipy.stats import randint, loguniform, norm
|
|
|
|
|
from sklearn.svm import SVC
|
|
|
|
|
from sklearn.model_selection import StratifiedKFold, RandomizedSearchCV, cross_val_score
|
|
|
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
|
|
from sklearn.feature_selection import SelectKBest, mutual_info_classif
|
|
|
|
|
from sklearn.pipeline import Pipeline
|
|
|
|
|
|
|
|
|
|
# ---------- 1. 数据路径 ----------
|
|
|
|
|
BASE_DIR = Path(r'D:\SummerSchool\mat_cv\mat_cv')
|
|
|
|
|
TRAIN_PKL = BASE_DIR / 'cv10_train.pkl'
|
|
|
|
|
TEST_FILES = [BASE_DIR / 'cv10_test.pkl'] # 也可放多个测试集 pkl 文件
|
|
|
|
|
|
|
|
|
|
# ---------- 2. 工具 ----------
|
|
|
|
|
def load_pkl_matrix(path: Path):
|
|
|
|
|
with open(path, 'rb') as f:
|
|
|
|
|
data = pickle.load(f)
|
|
|
|
|
return data['matrix'], data.get('label')
|
|
|
|
|
|
|
|
|
|
# ---------- 3. 读取训练集 ----------
|
|
|
|
|
X_train, y_train = load_pkl_matrix(TRAIN_PKL)
|
|
|
|
|
if y_train is None:
|
|
|
|
|
raise ValueError('训练集缺少 label 字段')
|
|
|
|
|
y_train = y_train.ravel()
|
|
|
|
|
# {0,1} → {-1,+1}
|
|
|
|
|
y_train_signed = np.where(y_train == 0, -1, 1)
|
|
|
|
|
|
|
|
|
|
# ---------- 4. 标准化 ----------
|
|
|
|
|
scaler = StandardScaler().fit(X_train)
|
|
|
|
|
X_train_std = scaler.transform(X_train)
|
|
|
|
|
n_features = X_train_std.shape[1]
|
|
|
|
|
|
|
|
|
|
# ---------- 5. RandomizedSearchCV 搜索 ----------
|
|
|
|
|
pipe = Pipeline([
|
|
|
|
|
('sel', SelectKBest(mutual_info_classif)),
|
|
|
|
|
('svm', SVC(kernel='rbf', class_weight='balanced', probability=True))
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
param_dist = {
|
|
|
|
|
'sel__k': randint(1, n_features + 1),
|
|
|
|
|
'svm__C': loguniform(1e-3, 1e3),
|
|
|
|
|
'svm__gamma': loguniform(1e-6, 1e1)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
|
|
|
|
search = RandomizedSearchCV(
|
|
|
|
|
pipe,
|
|
|
|
|
param_distributions=param_dist,
|
|
|
|
|
n_iter=60, # 搜索次数
|
|
|
|
|
scoring='roc_auc',
|
|
|
|
|
cv=cv_inner,
|
|
|
|
|
n_jobs=-1,
|
|
|
|
|
random_state=42,
|
|
|
|
|
verbose=1
|
|
|
|
|
)
|
|
|
|
|
search.fit(X_train_std, y_train_signed)
|
|
|
|
|
best_params = search.best_params_
|
|
|
|
|
print("\n▶ RandomizedSearch 最佳参数:", best_params)
|
|
|
|
|
print(f" 内层 5-折 AUC ≈ {search.best_score_:.4f}")
|
|
|
|
|
|
|
|
|
|
# ---------- 6. 训练最终流水线 ----------
|
|
|
|
|
final_model = search.best_estimator_
|
|
|
|
|
final_model.fit(X_train_std, y_train_signed)
|
|
|
|
|
|
|
|
|
|
# ---------- 6.5 新增:导出模型与标准化器(供 GUI 使用) ----------
|
|
|
|
|
# 输出到 BASE_DIR 下,也可按需改路径
|
|
|
|
|
model_out = BASE_DIR / 'svm_model.pkl'
|
|
|
|
|
scaler_out = BASE_DIR / 'scaler.pkl'
|
|
|
|
|
joblib.dump(final_model, model_out)
|
|
|
|
|
joblib.dump(scaler, scaler_out)
|
|
|
|
|
print(f"\n✅ 已导出模型与标尺:\n 模型: {model_out}\n 标尺: {scaler_out}\n has_predict_proba: {hasattr(final_model, 'predict_proba')}")
|
|
|
|
|
|
|
|
|
|
# ---------- 7. 外层 5-折交叉验证 ----------
|
|
|
|
|
cv_outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
|
|
|
|
cv_auc = cross_val_score(final_model, X_train_std, y_train_signed,
|
|
|
|
|
cv=cv_outer, scoring='roc_auc', n_jobs=-1)
|
|
|
|
|
cv_acc = cross_val_score(final_model, X_train_std, y_train_signed,
|
|
|
|
|
cv=cv_outer, scoring='accuracy', n_jobs=-1)
|
|
|
|
|
|
|
|
|
|
print('\n========== 外层 5-折交叉验证 ==========')
|
|
|
|
|
print(f'AUC = {cv_auc.mean():.4f} ± {cv_auc.std():.4f}')
|
|
|
|
|
print(f'ACC = {cv_acc.mean():.4f} ± {cv_acc.std():.4f}')
|
|
|
|
|
|
|
|
|
|
# ---------- 8. 推断 ----------
|
|
|
|
|
THRESHOLD = 0.5
|
|
|
|
|
Z = norm.ppf(0.975)
|
|
|
|
|
infer_results = []
|
|
|
|
|
print('\n========== 推断结果 ==========')
|
|
|
|
|
|
|
|
|
|
for pkl_path in TEST_FILES:
|
|
|
|
|
X_test, _ = load_pkl_matrix(pkl_path)
|
|
|
|
|
X_test_std = scaler.transform(X_test)
|
|
|
|
|
pred_signed = final_model.predict(X_test_std)
|
|
|
|
|
proba_pos = final_model.predict_proba(X_test_std)[:, 1]
|
|
|
|
|
pred_label = np.where(pred_signed == -1, 0, 1)
|
|
|
|
|
|
|
|
|
|
mean_p = proba_pos.mean()
|
|
|
|
|
sem_p = proba_pos.std(ddof=1) / np.sqrt(len(proba_pos)) if len(proba_pos) > 1 else 0.0
|
|
|
|
|
ci_low, ci_high = mean_p - Z * sem_p, mean_p + Z * sem_p
|
|
|
|
|
file_label = int(mean_p >= THRESHOLD)
|
|
|
|
|
|
|
|
|
|
print(f'\n▶ 文件: {pkl_path.name} (样本 {len(pred_label)})')
|
|
|
|
|
for i, (lbl, prob) in enumerate(zip(pred_label, proba_pos), 1):
|
|
|
|
|
print(f' Sample {i:02d}: pred={lbl} prob(1)={prob:.4f}')
|
|
|
|
|
print(' ---- 文件级融合 ----')
|
|
|
|
|
print(f' mean_prob(1) = {mean_p:.4f} (95% CI {ci_low:.4f} ~ {ci_high:.4f})')
|
|
|
|
|
print(f' Final label = {file_label} (阈值 {THRESHOLD})')
|
|
|
|
|
|
|
|
|
|
infer_results.append(dict(
|
|
|
|
|
file=pkl_path.name,
|
|
|
|
|
pred=pred_label.tolist(),
|
|
|
|
|
prob=proba_pos.tolist(),
|
|
|
|
|
mean_prob=float(mean_p),
|
|
|
|
|
ci_low=float(ci_low),
|
|
|
|
|
ci_high=float(ci_high),
|
|
|
|
|
final_label=int(file_label)
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
# 打印测试文件的原始标签(若有)
|
|
|
|
|
try:
|
|
|
|
|
print("TEST_FILES 标签:", load_pkl_matrix(TEST_FILES[0])[1])
|
|
|
|
|
except Exception:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# ---------- 9. 保存 & 可视化 ----------
|
|
|
|
|
out_pkl = BASE_DIR / 'infer_results.pkl'
|
|
|
|
|
with open(out_pkl, 'wb') as f:
|
|
|
|
|
pickle.dump(infer_results, f)
|
|
|
|
|
print(f'\n所有文件结果已保存到: {out_pkl}')
|
|
|
|
|
|
|
|
|
|
plt.rcParams['font.sans-serif'] = ['SimHei']
|
|
|
|
|
plt.rcParams['axes.unicode_minus'] = False
|
|
|
|
|
|
|
|
|
|
labels = [r['file'] for r in infer_results]
|
|
|
|
|
means = [r['mean_prob'] for r in infer_results]
|
|
|
|
|
yerr = [(r['mean_prob'] - r['ci_low'], r['ci_high'] - r['mean_prob'])
|
|
|
|
|
for r in infer_results]
|
|
|
|
|
|
|
|
|
|
fig, ax = plt.subplots(figsize=(6, 4))
|
|
|
|
|
ax.bar(range(len(means)), means,
|
|
|
|
|
yerr=np.array(yerr).T, capsize=5, alpha=0.8)
|
|
|
|
|
ax.axhline(THRESHOLD, color='red', ls='--', label=f'阈值 {THRESHOLD}')
|
|
|
|
|
ax.set_xticks(range(len(labels)))
|
|
|
|
|
ax.set_xticklabels(labels, rotation=15)
|
|
|
|
|
ax.set_ylim(0, 1)
|
|
|
|
|
ax.set_ylabel('mean_prob(空心=0)')
|
|
|
|
|
ax.set_title('文件级空心概率 (±95% CI)')
|
|
|
|
|
ax.legend()
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
desktop = Path.home() / 'Desktop'
|
|
|
|
|
save_path = desktop / 'infer_summary.png'
|
|
|
|
|
fig.savefig(save_path, dpi=300, bbox_inches='tight')
|
|
|
|
|
print(f'可视化图已保存至: {save_path}')
|
|
|
|
|
plt.show()
|