You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

86 lines
2.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# This is a sample Python script.
# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
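# Task: whether the subject has diabetes (binary classification)
# Model: logistic regression (LR)
# Dataset: Pima Indians diabetes dataset
# NOTE(review): the header above looks inherited from a diabetes example; the
# ROC plot title in run() refers to signal-interference detection - confirm
# which task/dataset this script is actually used for.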
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
import csv
def run(csvf):
    """Train a logistic-regression binary classifier from a CSV and report on it.

    The CSV is loaded with pandas. The "isDetect" column is discarded, the
    "flag" column is used as the target, and every remaining column is used
    as a feature. The data is split 80/20 (fixed random_state), standardized
    with statistics fitted on the training split only, and a
    LogisticRegression model is fitted. The function then prints the scaler
    statistics, model parameters, accuracy, precision and recall, and shows
    two figures: a confusion-matrix heatmap and the ROC curve with its AUC.

    Args:
        csvf: Path (or anything ``pandas.read_csv`` accepts) of the input CSV.
            It must contain the columns "isDetect" and "flag".

    Returns:
        None. Results are printed and plotted.

    Raises:
        KeyError: If "isDetect" or "flag" is missing from the CSV.
    """
    # Load the data set.
    df = pd.read_csv(csvf)
    print(df)

    # "isDetect" is never used afterwards; popping it keeps it out of the
    # feature matrix. "flag" is the label column.
    df.pop("isDetect")
    target = df.pop("flag")
    features = df.values

    X_train, X_test, Y_train, Y_test = train_test_split(
        features, target, test_size=0.2, random_state=7
    )

    # Standardize: fit on the training split only, then apply the same
    # transform to the test split.
    ss = StandardScaler()
    X_train_s = ss.fit_transform(X_train)
    X_test_s = ss.transform(X_test)
    # Standard deviation and mean of the original (training) features.
    print(ss.scale_)
    print(ss.mean_)

    # Fit the logistic-regression model and predict on the scaled test split.
    lr = LogisticRegression()
    lr.fit(X_train_s, Y_train)
    lr_pres = lr.predict(X_test_s)
    print(lr.coef_)
    print(lr.intercept_)
    print('准确率:', accuracy_score(Y_test, lr_pres))
    print('精确率:', precision_score(Y_test, lr_pres))
    print('召回率:', recall_score(Y_test, lr_pres))

    # Confusion-matrix heatmap.
    labels = [0, 1]
    cm = confusion_matrix(Y_test, lr_pres, labels=labels)
    sns.heatmap(
        cm,
        annot=True,
        annot_kws={'size': 20, 'weight': 'bold', 'color': 'blue'},
    )
    # Font capable of rendering the Chinese title.
    plt.rcParams['font.sans-serif'] = ['KaiTi']
    plt.title('混淆矩阵', fontsize=20)
    # confusion_matrix puts true labels on rows and predictions on columns;
    # heatmap draws rows along y and columns along x.
    plt.xlabel('Predict', fontsize=14)
    plt.ylabel('Actual', fontsize=14)
    plt.show()

    # ROC curve and AUC. The model was trained on standardized features, so
    # the probabilities must be computed from the standardized test split.
    lr_pres_proba = lr.predict_proba(X_test_s)[:, 1]
    fpr, tpr, _thresholds = roc_curve(Y_test, lr_pres_proba)
    auc = roc_auc_score(Y_test, lr_pres_proba)
    plt.figure(figsize=(5, 3), dpi=100)
    plt.plot(fpr, tpr, label=f"AUC={auc:.2f}")
    plt.legend(loc=4, fontsize=10)
    plt.title('信号干扰检测LR分类的ROC和AUC', fontsize=10)
    plt.xlabel('FPR', fontsize=14)
    plt.ylabel('TPR', fontsize=14)
    plt.show()
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    import sys

    # The CSV path may be given as the first command-line argument; with no
    # argument the original hard-coded file name is used.
    csvf = sys.argv[1] if len(sys.argv) > 1 else '20221003-17_47-00-model.csv'
    run(csvf)
# See PyCharm help at https://www.jetbrains.com/help/pycharm/