From dcd0c28e0b27d0b16b3b786d3858ccbdd7a690c6 Mon Sep 17 00:00:00 2001 From: m56ftjlkq <189777132@qq.com> Date: Fri, 20 Dec 2024 21:06:28 +0800 Subject: [PATCH] =?UTF-8?q?Add=20=E7=83=AD=E5=8A=9B=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 热力图 | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 热力图 diff --git a/热力图 b/热力图 new file mode 100644 index 0000000..6f5bef4 --- /dev/null +++ b/热力图 @@ -0,0 +1,54 @@ +import pandas as pd +import numpy as np +from plotly import __version__ +print (__version__) +from plotly.offline import init_notebook_mode +init_notebook_mode(connected=True) +from plotly.graph_objs import Figure, Layout, Pie,Bar +import matplotlib.pyplot as plt +import seaborn as sns +import plotly.graph_objects as go +import plotly.io as pio + +colors = ['#e43620', '#f16d30','#d99a6c','#fed976', '#b3cb95', '#41bfb3','#229bac', '#256894'] +data = pd.read_csv('/visualization//HR_comma_sep.csv') +df = data.rename(columns = {"sales":"department","promotion_last_5years":"promotion","Work_accident":"work_accident"}) +df['department'] = df['department'].astype('category')#, categories=cat.categories) +df['salary'] = df['salary'].astype('category')#, categories=cat.categories) +salary_dict = dict(enumerate(df['salary'].cat.categories)) +department_dict = dict(enumerate(df['department'].cat.categories)) +#print( dict(enumerate(df['salary'].cat.categories))) +for feature in df.columns: + if str(df[feature].dtype) == 'category': + df[feature] = df[feature].cat.codes + # df[feature] = pd.Categorical(df[feature]).codes + df[feature] = df[feature].astype("int64") # 设置数据类型为int64 +cols = df.columns +cols = list(cols[:6]) + list(cols[7:]) + [cols[6]] +print('Reordered Columns:',cols) +df = df[cols] +left_summary = df.groupby(by=['left']).mean() +corr = df.corr() +#print(corr) # pearson相关系数 +mask = np.zeros_like(corr) +#print(mask) +mask[np.tril_indices_from(mask)]=True +with sns.axes_style("white"):#seaborn设置坐标风格 + sns.set(rc={'figure.figsize':(11,7)})#宽度高度 + ax = sns.heatmap(corr, + xticklabels=True, yticklabels=True, #表示在热力图的 x 轴和 y 轴上显示对应的标签数据框中各列的名称,对应相关系数矩阵的行和列索引 + cmap='RdBu', # 颜色红蓝 + mask=mask, # 使用掩码只绘制矩阵的一部分 + fmt='.3f', # 相关系数格式设置保留3位 + annot=True, # 方格内写入数据 + linewidths=.5, # 热力图矩阵之间的间隔大小设置了热力图中每个方格之间的间隔线条的宽度为 0.5 + vmax=.4, # 指定了热力图颜色映射中颜色所对应的最大值,突出显示相关系数绝对值在 0 到 0.4 这个区间内的变化情况 + square = True #每个方格呈现正方形形状 + # center = 0 + ) +plt.title("Correlation") +label_x = ax.get_xticklabels() +plt.setp(label_x,rotation=45, horizontalalignment='right') +#plt.show() + +