From 2cc5c55aee7c3f42e3c6526685b7403b5eba61a7 Mon Sep 17 00:00:00 2001 From: wh <2627521256@qq.com> Date: Mon, 10 Nov 2025 21:18:31 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E4=B8=8E?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E9=83=A8=E5=88=86=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 75%25%.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 75%25%.py diff --git a/75%25%.py b/75%25%.py new file mode 100644 index 0000000..cd68ac0 --- /dev/null +++ b/75%25%.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +使用 75% 训练 / 25% 测试 的方式评估 SVM(输出 ACC & AUC) +""" + +import pickle +import numpy as np +from pathlib import Path +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC +from sklearn.metrics import accuracy_score, roc_auc_score + + +# ---------- 1. 数据路径 ---------- +PKL_PATH = Path(r"D:\Python\空心检测\pythonProject\feature_dataset.pkl") + +# ---------- 2. 读取特征 ---------- +def load_pkl_matrix(path: Path): + with open(path, "rb") as f: + data = pickle.load(f) + return data["matrix"], data["label"] + +X, y = load_pkl_matrix(PKL_PATH) +y = y.ravel() # shape (N,) + +# ---------- 3. 75% / 25% 拆分 ---------- +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.25, random_state=42, stratify=y, shuffle=True +) + +# ---------- 4. 标准化 + SVM ---------- +scaler = StandardScaler().fit(X_train) +X_train_std = scaler.transform(X_train) +X_test_std = scaler.transform(X_test) + +svm = SVC( + kernel="rbf", + C=10, + gamma="scale", + probability=True, + class_weight="balanced", + random_state=42, +) +svm.fit(X_train_std, y_train) + +# ---------- 5. 评估 ---------- +y_pred = svm.predict(X_test_std) +y_proba_pos = svm.predict_proba(X_test_std)[:, list(svm.classes_).index(1)] + +acc = accuracy_score(y_test, y_pred) +auc = roc_auc_score(y_test, y_proba_pos) + +print("\n========== 评估结果 ==========") +print(f"样本总数: {len(y)} | 训练: {len(y_train)} 测试: {len(y_test)}") +print(f"ACC = {acc:.4f}") +print(f"AUC = {auc:.4f}")