# -*- coding:utf-8 -*-
"""Exploratory plots for a passenger-survival data set.

The script reads a training and a test CSV, prints summary information
about the training data and saves a series of charts (survival against
``Sex``, ``Pclass``, ``SibSp``, ``Parch``, ``Embarked``, ``Cabin`` and
``Age``) as JPEG files into ``<dest_path>/data_output``.
"""
import argparse
import os.path

import matplotlib

matplotlib.use('TkAgg')  # e.g. use the TkAgg backend; must run before pyplot is imported
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

_SEPARATOR = '----------------------------------------------------'


def set_outputfolder(output_path):
    """Create ``output_path`` (and any missing parents) if it does not exist.

    Args:
        output_path: Directory that will receive the generated images.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)


def _use_cjk_font():
    """Select a font able to render the Chinese labels used in the charts."""
    plt.rcParams['font.sans-serif'] = ['SimHei']
    # Keep the minus sign rendering correctly with the non-default font.
    plt.rcParams['axes.unicode_minus'] = False


def _configure_display():
    """Apply the font settings and widen pandas' console display limits."""
    _use_cjk_font()
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    pd.set_option('display.max_colwidth', 100)
    pd.set_option('display.width', 1000)


def _save_and_show(path, filename):
    """Save the current figure as ``filename`` inside ``path`` and display it."""
    plt.savefig(os.path.join(path, filename))
    plt.show()


def basic_show(train_1):
    """Print summary statistics, dtypes, the first rows and null counts.

    Args:
        train_1: DataFrame to describe.
    """
    print(_SEPARATOR)
    print("数据分布情况:")
    print(train_1.describe())
    print(_SEPARATOR)
    print("数据类型观察:")
    # info() writes its report itself and returns None, so it must not be
    # wrapped in print() (that would emit a stray "None" line).
    train_1.info()
    print(_SEPARATOR)
    print("数据内容:")
    print(train_1.head())
    print(_SEPARATOR)
    print("原始数据缺失情况:")
    print(train_1.isnull().sum())


def basic_show1(df, path):
    """Plot survival against several categorical columns and save each chart.

    Args:
        df: DataFrame with ``Survived``, ``Sex``, ``Pclass``, ``SibSp``,
            ``Parch`` and ``Embarked`` columns.
        path: Directory the JPEG files are written to.
    """
    # Mean of Survived per category: sex, cabin class, number of
    # siblings/spouses aboard, number of parents/children aboard.
    survival_rate_charts = (
        ("Sex", 'sex.jpg'),
        ("Pclass", 'Pclass.jpg'),
        ("SibSp", 'SibSp.jpg'),
        ("Parch", 'Parch.jpg'),
    )
    for column, filename in survival_rate_charts:
        # A fresh figure per chart so plots never stack on a previous one
        # when show() does not block (non-GUI backend, interactive mode).
        plt.figure()
        sns.barplot(x=column, y="Survived", data=df)
        _save_and_show(path, filename)

    # Passenger counts per port of embarkation, split by survival.
    plt.figure()
    sns.countplot(x='Embarked', hue='Survived', data=df, fill=True)
    _save_and_show(path, 'Embarked.jpg')


def p_class_sex(train_1, path):
    """Plot survivor / non-survivor counts by sex and cabin class.

    Four bar charts sharing one y axis are drawn side by side: women and men,
    each split into class 1-2 ("high") and class 3 ("low").

    Args:
        train_1: DataFrame with ``Survived``, ``Sex`` and ``Pclass`` columns.
        path: Directory the JPEG file is written to.
    """
    _configure_display()

    fig = plt.figure()
    fig.set(alpha=0.65)  # figure transparency; purely cosmetic
    # suptitle titles the figure itself; plt.title here would first create an
    # extra full-size axes underneath the four subplots.
    fig.suptitle(u"根据舱等级和性别的获救情况")

    is_female = train_1['Sex'] == 'female'
    is_male = train_1['Sex'] == 'male'
    is_low_class = train_1['Pclass'] == 3
    is_high_class = train_1['Pclass'] != 3

    panels = (
        (is_female & is_high_class, "female highclass", '#FA2479', u"女性/高级舱"),
        (is_female & is_low_class, 'female, low class', 'pink', u"女性/低级舱"),
        (is_male & is_high_class, 'male, high class', 'lightblue', u"男性/高级舱"),
        (is_male & is_low_class, 'male low class', 'steelblue', u"男性/低级舱"),
    )

    first_ax = None
    for position, (mask, label, color, legend_text) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 4, position, sharey=first_ax)
        if first_ax is None:
            first_ax = ax
        # value_counts() orders by frequency, which differs between groups
        # and is arbitrary on ties; pin the order to [0, 1] so the tick
        # labels below are always attached to the right bars.
        counts = (
            train_1.loc[mask, 'Survived']
            .value_counts()
            .reindex([0, 1], fill_value=0)
        )
        counts.plot(kind='bar', ax=ax, label=label, color=color)
        ax.set_xticklabels([u"未获救", u"获救"], rotation=0)
        ax.legend([legend_text], loc='best')

    _save_and_show(path, 'p_class_set.jpg')


def age(train_1, path):
    """Plot the age density of the passengers of each cabin class.

    Args:
        train_1: DataFrame with ``Pclass`` and ``Age`` columns.
        path: Directory the JPEG file is written to.
    """
    _configure_display()

    fig = plt.figure(figsize=(8, 5))
    fig.suptitle(u"各等级的乘客年龄分布")
    ax = fig.add_subplot(111)

    for pclass in (1, 2, 3):
        ages = train_1.loc[train_1['Pclass'] == pclass, 'Age']
        sns.kdeplot(ages, ax=ax, label=f'Pclass {pclass}')

    ax.set_xlabel(u"年龄")
    ax.set_ylabel(u"密度")
    ax.legend(loc='best')
    _save_and_show(path, 'age_distribution.jpg')


def carbin(full_1, path):
    """Plot survival rate and survival counts by presence of a cabin value.

    Note:
        A ``Has_Cabin`` column (1 when ``Cabin`` is not null, else 0) is
        added to ``full_1`` in place.

    Args:
        full_1: DataFrame with ``Cabin`` and ``Survived`` columns.
        path: Directory the JPEG files are written to.
    """
    full_1['Has_Cabin'] = full_1['Cabin'].notna().astype(int)

    plt.figure()
    sns.barplot(x='Has_Cabin', y='Survived', data=full_1)
    plt.title('Survival Rate')
    _save_and_show(path, 'survival_rate.jpg')

    cabin = pd.crosstab(full_1['Has_Cabin'], full_1['Survived'])
    cabin = cabin.rename(
        index={0: 'no cabin', 1: 'have cabin'},
        columns={0: 'Dead', 1: 'Survived'},
    )
    # DataFrame.plot.bar opens its own figure, so none is created here.
    cabin.plot.bar(figsize=(8, 5))
    plt.xticks(rotation=0, size='large')
    plt.title('Survived Count')
    plt.xlabel('')
    plt.legend()
    _save_and_show(path, 'carbin.jpg')


def age_scatter_plot(train_1, path):
    """Scatter-plot ``Survived`` against ``Age``.

    Args:
        train_1: DataFrame with ``Age`` and ``Survived`` columns.
        path: Directory the JPEG file is written to.
    """
    # The labels are Chinese; set the font here rather than relying on
    # another plotting function having been called first.
    _use_cjk_font()

    plt.figure()
    plt.scatter(train_1['Age'], train_1['Survived'])
    plt.ylabel(u"获救情况 (1为获救)")
    plt.xlabel(u"年龄")
    plt.title(u"年龄与获救情况散点图")
    _save_and_show(path, 'age_scatter_plot.jpg')


def main():
    """Parse the command line, load the CSV files and produce every chart."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dest_path", default='E:/python1/123', help="files save path")
    parser.add_argument('--train', default='train.csv', help="train data")
    parser.add_argument('--test', default='test.csv', help="test data")
    args = parser.parse_args()

    output_path = os.path.join(args.dest_path, 'data_output')
    set_outputfolder(output_path)

    # Load the training and test sets; `full` is their row-wise concatenation.
    train = pd.read_csv(args.train)
    test = pd.read_csv(args.test)
    full = pd.concat([train, test], ignore_index=True)

    basic_show(train)
    basic_show1(train, output_path)
    age(train, output_path)
    p_class_sex(train, output_path)
    carbin(full, output_path)
    age_scatter_plot(train, output_path)


if __name__ == '__main__':
    main()