|
|
# -*- coding:utf-8 -*-
|
|
|
import argparse
|
|
|
import os.path
|
|
|
import matplotlib
|
|
|
matplotlib.use('TkAgg') # 例如使用TkAgg后端
|
|
|
import matplotlib.pyplot as plt
|
|
|
import pandas as pd
|
|
|
import seaborn as sns
|
|
|
|
|
|
def set_outputfolder(output_path):
    """Create the output directory (and any missing parents) if needed.

    Args:
        output_path: Directory path to create.

    Raises:
        FileExistsError: If ``output_path`` exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    # exist_ok=True makes the call idempotent and avoids the race between a
    # separate existence check and the creation.
    os.makedirs(output_path, exist_ok=True)
|
|
|
|
|
|
def basic_show(train_1):
    """Print a quick textual overview of a DataFrame to stdout.

    The report contains, in order: ``describe()`` statistics, the ``info()``
    summary, the first rows (``head()``) and the per-column count of missing
    values. Section headings are printed in Chinese, unchanged from the
    original output.

    Args:
        train_1: The pandas DataFrame to summarise.
    """
    separator = '----------------------------------------------------'

    print(separator)
    print("数据分布情况:")
    print(train_1.describe())

    print(separator)
    print("数据类型观察:")
    # DataFrame.info() writes its report itself and returns None, so it must
    # not be wrapped in print() (that would emit a stray "None" line).
    train_1.info()

    print(separator)
    print("数据内容:")
    print(train_1.head())

    print(separator)
    print("原始数据缺失情况:")
    print(train_1.isnull().sum())
|
|
|
|
|
|
def basic_show1(df, path):
    """Draw, save and show one chart per feature against ``Survived``.

    Bar plots of ``Survived`` are produced for ``Sex``, ``Pclass``, ``SibSp``
    and ``Parch``, followed by a count plot of ``Embarked`` split by
    ``Survived``. Each chart is written as a JPEG under ``path`` and then
    displayed with ``plt.show()``.

    Args:
        df: DataFrame providing the columns named above.
        path: Output directory prefix; file names are appended to it.
    """

    def _save_and_show(file_name):
        # Persist the current figure first, then display it.
        plt.savefig(path + file_name)
        plt.show()

    # (column, output file) pairs, in the original plotting order:
    # sex, cabin class, siblings/spouses aboard, parents/children aboard.
    # Original author's note: passengers with a moderate number of relatives
    # aboard appear to survive more often.
    bar_charts = (
        ("Sex", '/sex.jpg'),
        ("Pclass", '/Pclass.jpg'),
        ("SibSp", '/SibSp.jpg'),
        ("Parch", '/Parch.jpg'),
    )
    for column, file_name in bar_charts:
        sns.barplot(x=column, y="Survived", data=df)
        _save_and_show(file_name)

    # Port of embarkation versus survival, shown as counts per outcome.
    sns.countplot(x='Embarked', hue='Survived', data=df, fill=True)
    _save_and_show('/Embarked.jpg')
|
|
|
|
|
|
def p_class_sex(train_1, path):
    """Plot survival counts split by sex and cabin class in four panels.

    The panels are, left to right: female / class 1-2, female / class 3,
    male / class 1-2, male / class 3. Every panel shows two bars in the fixed
    order "not survived" (0) then "survived" (1), so the tick labels always
    match the data regardless of which outcome is more frequent. The figure
    is saved as ``p_class_set.jpg`` under ``path`` and then shown.

    Side effects: sets matplotlib font rcParams and several pandas display
    options globally, as the original implementation did.

    Args:
        train_1: DataFrame with ``Sex``, ``Pclass`` and ``Survived`` columns.
        path: Output directory prefix; the file name is appended to it.
    """
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    pd.set_option('max_colwidth', 100)
    pd.set_option('display.width', 1000)

    # A single figure: the original created (and leaked) an extra empty one.
    fig = plt.figure()
    fig.set(alpha=0.65)
    # suptitle titles the whole figure; plt.title here would implicitly
    # create a stray full-size axes behind the four subplots.
    fig.suptitle(u"根据舱等级和性别的获救情况")

    is_female = train_1['Sex'] == 'female'
    is_male = train_1['Sex'] == 'male'
    is_low_class = train_1['Pclass'] == 3
    is_high_class = train_1['Pclass'] != 3

    # (row mask, legend text, bar colour) for each panel, left to right.
    panels = (
        (is_female & is_high_class, u"女性/高级舱", '#FA2479'),
        (is_female & is_low_class, u"女性/低级舱", 'pink'),
        (is_male & is_high_class, u"男性/高级舱", 'lightblue'),
        (is_male & is_low_class, u"男性/低级舱", 'steelblue'),
    )

    first_ax = None
    for position, (mask, legend_text, colour) in enumerate(panels, start=1):
        # All panels share the y axis of the first one for comparability.
        ax = fig.add_subplot(1, 4, position, sharey=first_ax)
        if first_ax is None:
            first_ax = ax

        # Force the order [0, 1] (and a zero bar for a missing outcome) so
        # the fixed tick labels below are always correct.
        counts = (
            train_1.loc[mask, 'Survived']
            .value_counts()
            .reindex([0, 1], fill_value=0)
        )
        counts.plot(kind='bar', ax=ax, color=colour)
        ax.set_xticklabels([u"未获救", u"获救"], rotation=0)
        ax.legend([legend_text], loc='best')

    plt.savefig(path + '/p_class_set.jpg')
    plt.show()
|
|
|
|
|
|
def age(train_1, path):
    """Plot the age density of passengers for each of the classes 1, 2, 3.

    One KDE curve per ``Pclass`` value is drawn on a shared axes; the figure
    is saved as ``age_distribution.jpg`` under ``path`` and then shown.

    Side effects: sets matplotlib font rcParams and several pandas display
    options globally.

    Args:
        train_1: DataFrame with ``Pclass`` and ``Age`` columns.
        path: Output directory prefix; the file name is appended to it.
    """
    # SimHei is requested presumably so the Chinese labels render; the
    # unicode_minus switch keeps minus signs displaying properly.
    for rc_key, rc_value in (
        ('font.sans-serif', ['SimHei']),
        ('axes.unicode_minus', False),
    ):
        plt.rcParams[rc_key] = rc_value

    for option_name, option_value in (
        ('display.max_columns', None),
        ('display.max_rows', None),
        ('max_colwidth', 100),
        ('display.width', 1000),
    ):
        pd.set_option(option_name, option_value)

    figure = plt.figure(figsize=(8, 5))
    figure.suptitle(u"各等级的乘客年龄分布")
    axes = figure.add_subplot(1, 1, 1)

    for pclass in (1, 2, 3):
        ages_in_class = train_1.loc[train_1['Pclass'] == pclass, 'Age']
        sns.kdeplot(ages_in_class, ax=axes, label=f'Pclass {pclass}')

    axes.set_xlabel(u"年龄")
    axes.set_ylabel(u"密度")
    plt.legend(loc='best')

    plt.savefig(path + '/age_distribution.jpg')
    plt.show()
|
|
|
|
|
|
def carbin(full_1, path):
    """Chart survival against whether a cabin value is recorded.

    Adds a ``Has_Cabin`` column (1 when ``Cabin`` is present, 0 when it is
    missing) to ``full_1`` in place, then produces two charts:

    * a bar plot of ``Survived`` by ``Has_Cabin`` saved as
      ``survival_rate.jpg``;
    * a grouped bar chart of the ``Has_Cabin`` x ``Survived`` cross
      tabulation saved as ``carbin.jpg``.

    Each chart is shown after it is saved.

    Args:
        full_1: DataFrame with ``Cabin`` and ``Survived`` columns; it is
            modified by this call.
        path: Output directory prefix; file names are appended to it.
    """

    def _cabin_flag(value):
        # 0 for a missing cabin entry, 1 otherwise.
        if pd.isna(value):
            return 0
        return 1

    full_1['Has_Cabin'] = full_1['Cabin'].apply(_cabin_flag)

    # Chart 1: mean of Survived per cabin flag.
    sns.barplot(x='Has_Cabin', y='Survived', data=full_1)
    plt.title('Survival Rate')
    plt.savefig(path + '/survival_rate.jpg')
    plt.show()

    # Chart 2: absolute counts per (cabin flag, outcome) pair.
    counts = pd.crosstab(full_1['Has_Cabin'], full_1['Survived'])
    counts = counts.rename(
        index={0: 'no cabin', 1: 'have cabin'},
        columns={0: 'Dead', 1: 'Survived'},
    )
    counts.plot.bar(figsize=(8, 5))
    plt.xticks(rotation=0, size='large')
    plt.title('Survived Count')
    plt.xlabel('')
    plt.legend()
    plt.savefig(path + '/carbin.jpg')
    plt.show()
|
|
|
|
|
|
def age_scatter_plot(train_1, path):
    """Scatter ``Survived`` against ``Age``, save the chart and show it.

    The image is written as ``age_scatter_plot.jpg`` under ``path``.

    Args:
        train_1: DataFrame with ``Age`` and ``Survived`` columns.
        path: Output directory prefix; the file name is appended to it.
    """
    ages = train_1['Age']
    outcomes = train_1['Survived']
    plt.scatter(ages, outcomes)

    # Title and axis captions (Chinese text kept as in the original output).
    plt.title(u"年龄与获救情况散点图")
    plt.xlabel(u"年龄")
    plt.ylabel(u"获救情况 (1为获救)")

    plt.savefig(path + '/age_scatter_plot.jpg')
    plt.show()
|
|
|
|
|
|
def _main():
    """Parse the command line, load the CSV files and run every report.

    Options: ``--dest_path`` (base folder for outputs), ``--train`` and
    ``--test`` (CSV files). Charts are written to ``<dest_path>/data_output/``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dest_path", default='E:/python1/123', help="files save path")
    cli.add_argument('--train', default='train.csv', help="train data")
    cli.add_argument('--test', default='test.csv', help="test data")
    options = cli.parse_args()

    output_dir = options.dest_path + '/data_output/'
    set_outputfolder(output_dir)

    # Load the training and test sets, plus their row-wise concatenation.
    train_df = pd.read_csv(options.train)
    test_df = pd.read_csv(options.test)
    combined_df = pd.concat([train_df, test_df], ignore_index=True)

    basic_show(train_df)
    basic_show1(train_df, output_dir)
    age(train_df, output_dir)
    p_class_sex(train_df, output_dir)
    carbin(combined_df, output_dir)
    age_scatter_plot(train_df, output_dir)


if __name__ == '__main__':
    _main()