|
|
# This is a sample Python script.
|
|
|
|
|
|
# Press Shift+F10 to execute it or replace it with your code.
|
|
|
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
|
|
|
|
|
|
|
|
|
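# Task: decide whether a patient has diabetes (binary classification).
# NOTE(review): this header looks stale -- run() reads "isDetect"/"flag" columns and
# titles its ROC plot as signal-interference detection; confirm and update.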
|
|
|
# 模型:LR
|
|
|
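# Dataset: Pima Indians Diabetes dataset.
# NOTE(review): the script actually loads '20221003-17_47-00-model.csv'; confirm this is still accurate.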
|
|
|
import pandas as pd
|
|
|
import matplotlib.pyplot as plt
|
|
|
import seaborn as sns
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
from sklearn.metrics import precision_score
|
|
|
from sklearn.metrics import recall_score
|
|
|
from sklearn.metrics import classification_report
|
|
|
from sklearn.metrics import confusion_matrix
|
|
|
from sklearn.metrics import roc_curve
|
|
|
from sklearn.metrics import roc_auc_score
|
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
import csv
|
|
|
|
|
|
|
|
|
def run(csvf):
    """Train a logistic-regression classifier on a CSV file and report on it.

    The CSV is loaded with pandas. The "isDetect" column is discarded, the
    "flag" column is used as the label (expected to hold the classes 0 and 1,
    which is what the confusion-matrix labels below assume), and every
    remaining column is used as a feature.

    The data is split 80/20 with a fixed seed, standardized with statistics
    fitted on the training split only, and fed to ``LogisticRegression``.
    The function prints the data frame, the scaler statistics, the model
    coefficients and the accuracy / precision / recall on the test split,
    then shows a confusion-matrix heat map and a ROC curve with its AUC.

    Args:
        csvf: Path (or any other source accepted by ``pandas.read_csv``) of
            the CSV file to load.

    Returns:
        None. All results are printed or plotted.

    Raises:
        KeyError: If the CSV lacks an "isDetect" or a "flag" column.
    """
    # --- Load the data and split it into features / label -----------------
    df = pd.read_csv(csvf)
    print(df)
    # "isDetect" is neither a feature nor the label; pop it so it does not
    # leak into the feature matrix.
    df.pop("isDetect")
    target = df.pop("flag")
    features = df.values
    X_train, X_test, Y_train, Y_test = train_test_split(
        features, target, test_size=0.2, random_state=7
    )

    # --- Standardize the features ------------------------------------------
    # Fit on the training split only, then apply the same transform to the
    # test split so no test statistics leak into training.
    ss = StandardScaler()
    X_train_s = ss.fit_transform(X_train)
    X_test_s = ss.transform(X_test)
    # Per-feature standard deviation and mean of the training split.
    print(ss.scale_)
    print(ss.mean_)

    # --- Fit the logistic-regression model and predict ---------------------
    lr = LogisticRegression()
    lr.fit(X_train_s, Y_train)
    lr_pres = lr.predict(X_test_s)
    print(lr.coef_)
    print(lr.intercept_)
    print('准确率:', accuracy_score(Y_test, lr_pres))
    print('精确率:', precision_score(Y_test, lr_pres))
    print('召回率:', recall_score(Y_test, lr_pres))

    # --- Confusion-matrix heat map -----------------------------------------
    # A CJK-capable font is needed to render the Chinese plot titles.
    plt.rcParams['font.sans-serif'] = ['KaiTi']
    labels = [0, 1]
    cm = confusion_matrix(Y_test, lr_pres, labels=labels)
    sns.heatmap(cm, annot=True, annot_kws={'size': 20, 'weight': 'bold', 'color': 'blue'})
    plt.title('混淆矩阵', fontsize=20)
    # confusion_matrix puts the true labels on the rows and the predictions
    # on the columns; heatmap draws rows along y and columns along x.
    plt.xlabel('Predict', fontsize=14)
    plt.ylabel('Actual', fontsize=14)
    plt.show()

    # --- ROC curve and AUC -------------------------------------------------
    # The model was trained on standardized features, so the probabilities
    # must be computed from the standardized test split as well.
    lr_pres_proba = lr.predict_proba(X_test_s)[:, 1]
    fpr, tpr, thresholds = roc_curve(Y_test, lr_pres_proba)
    auc = roc_auc_score(Y_test, lr_pres_proba)
    plt.figure(figsize=(5, 3), dpi=100)
    plt.plot(fpr, tpr, label="AUC={:.2f}".format(auc))
    plt.legend(loc=4, fontsize=10)
    plt.title('信号干扰检测LR分类的ROC和AUC', fontsize=10)
    plt.xlabel('FPR', fontsize=14)
    plt.ylabel('TPR', fontsize=14)
    plt.show()
|
|
|
|
|
|
|
|
|
# Press the green button in the gutter to run the script.
|
|
|
if __name__ == '__main__':
    import sys

    # Use the CSV path given as the first command-line argument when there is
    # one; otherwise fall back to the sample file this script was written for.
    csvf = sys.argv[1] if len(sys.argv) > 1 else '20221003-17_47-00-model.csv'
    run(csvf)
|
|
|
|
|
|
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|