You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

153 lines
5.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding:utf-8 -*-
import argparse
import os.path
import matplotlib
matplotlib.use('TkAgg') # 例如使用TkAgg后端
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
def set_outputfolder(output_path):
    """Make sure the output directory exists.

    Missing parent directories are created as well. Calling this for a
    directory that already exists is a no-op.

    Args:
        output_path: Directory path that should exist after the call.

    Raises:
        FileExistsError: If ``output_path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first and only then calling makedirs().
    os.makedirs(output_path, exist_ok=True)
def basic_show(train_1):
    """Print a quick textual overview of a DataFrame to stdout.

    Four sections are printed, each preceded by a dashed separator line:
    summary statistics (``describe``), the column/dtype report (``info``),
    the first rows (``head``) and the number of missing values per column.

    Args:
        train_1: pandas DataFrame to summarise.
    """
    separator = '----------------------------------------------------'

    print(separator)
    print("数据分布情况:")
    print(train_1.describe())

    print(separator)
    print("数据类型观察:")
    # DataFrame.info() writes its report itself and returns None, so it must
    # not be wrapped in print() (that would emit a stray "None" line).
    train_1.info()

    print(separator)
    print("数据内容:")
    print(train_1.head())

    print(separator)
    print("原始数据缺失情况:")
    print(train_1.isnull().sum())
def basic_show1(df, path):
    """Plot, save and show survival charts for several passenger columns.

    For each of Sex, Pclass, SibSp and Parch a seaborn bar plot of
    ``Survived`` against that column is drawn, written to ``path`` and
    displayed. A count plot of ``Embarked`` split by ``Survived`` follows.

    Args:
        df: DataFrame with the columns Sex, Pclass, SibSp, Parch, Embarked
            and Survived.
        path: Directory prefix the image file names are appended to.
    """
    # (x-axis column, file name suffix) for the bar plots, in drawing order:
    # sex, cabin class, siblings/spouses aboard, parents/children aboard.
    bar_charts = (
        ("Sex", '/sex.jpg'),
        ("Pclass", '/Pclass.jpg'),
        ("SibSp", '/SibSp.jpg'),
        ("Parch", '/Parch.jpg'),
    )
    for column, suffix in bar_charts:
        sns.barplot(x=column, y="Survived", data=df)
        plt.savefig(path + suffix)
        plt.show()

    # Port of embarkation versus survival, as raw counts.
    sns.countplot(x='Embarked', hue='Survived', data=df, fill=True)
    plt.savefig(path + '/Embarked.jpg')
    plt.show()
def p_class_sex(train_1, path):
    """Plot survival counts split by sex and cabin class, then save and show.

    One figure with four side-by-side bar charts (sharing the y axis) is
    produced: female / class 1-2, female / class 3, male / class 1-2 and
    male / class 3. Every chart shows the "not survived" bar first and the
    "survived" bar second. The figure is written to
    ``path + '/p_class_set.jpg'``.

    As a side effect the matplotlib font settings (SimHei, so the Chinese
    labels render) and several pandas display options are changed globally.

    Args:
        train_1: DataFrame with the columns Survived (0/1), Sex
            ('female'/'male') and Pclass.
        path: Directory prefix the image file name is appended to.
    """
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    pd.set_option('max_colwidth', 100)
    pd.set_option('display.width', 1000)

    # (sex, is third class, legend text, bar colour) for each panel.
    panels = (
        ('female', False, u"女性/高级舱", '#FA2479'),
        ('female', True, u"女性/低级舱", 'pink'),
        ('male', False, u"男性/高级舱", 'lightblue'),
        ('male', True, u"男性/低级舱", 'steelblue'),
    )

    # A single figure only: creating a second one leaves an empty window
    # behind when plt.show() is called.
    fig = plt.figure()
    fig.set(alpha=0.65)
    # suptitle titles the whole figure; plt.title() here would implicitly
    # create an extra full-size axes underneath the four subplots.
    fig.suptitle(u"根据舱等级和性别的获救情况")

    third_class = train_1['Pclass'] == 3
    first_ax = None
    for position, (sex, low_class, legend_text, color) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 4, position, sharey=first_ax)
        if first_ax is None:
            first_ax = ax

        class_mask = third_class if low_class else ~third_class
        survived = train_1.loc[(train_1['Sex'] == sex) & class_mask, 'Survived']
        # value_counts() orders by frequency, which differs between groups;
        # reindex to a fixed 0/1 order so the tick labels always match.
        counts = survived.value_counts().reindex([0, 1], fill_value=0)
        counts.plot(kind='bar', ax=ax, color=color)

        ax.set_xticklabels([u"未获救", u"获救"], rotation=0)
        ax.legend([legend_text], loc='best')

    plt.savefig(path + '/p_class_set.jpg')
    plt.show()
def age(train_1, path):
    """Plot the age density of each passenger class, then save and show it.

    One KDE curve per Pclass value (1, 2, 3) is drawn on a single axes. The
    figure is written to ``path + '/age_distribution.jpg'``.

    As a side effect the matplotlib font settings and several pandas
    display options are changed globally.

    Args:
        train_1: DataFrame with the columns Pclass and Age.
        path: Directory prefix the image file name is appended to.
    """
    plt.rcParams.update({
        'font.sans-serif': ['SimHei'],
        'axes.unicode_minus': False,  # render minus signs correctly
    })
    display_options = (
        ('display.max_columns', None),
        ('display.max_rows', None),
        ('max_colwidth', 100),
        ('display.width', 1000),
    )
    for option_name, option_value in display_options:
        pd.set_option(option_name, option_value)

    fig = plt.figure(figsize=(8, 5))
    fig.suptitle(u"各等级的乘客年龄分布")
    ax = fig.add_subplot(111)

    for pclass in (1, 2, 3):
        class_ages = train_1.loc[train_1['Pclass'] == pclass, 'Age']
        sns.kdeplot(class_ages, ax=ax, label=f'Pclass {pclass}')

    ax.set_xlabel(u"年龄")
    ax.set_ylabel(u"密度")
    ax.legend(loc='best')

    target = path + '/age_distribution.jpg'
    plt.savefig(target)
    plt.show()
def carbin(full_1, path):
    """Plot how having a cabin record relates to survival.

    Two charts are saved and shown: a seaborn bar plot of ``Survived``
    against a 0/1 "has cabin" flag (``path + '/survival_rate.jpg'``) and a
    grouped bar chart of the survived/dead counts per flag value
    (``path + '/carbin.jpg'``).

    The flag is computed on a copy, so the caller's DataFrame is not
    modified.

    Args:
        full_1: DataFrame with the columns Cabin and Survived.
        path: Directory prefix the image file names are appended to.
    """
    # assign() returns a new frame; writing the column into full_1 directly
    # would silently add 'Has_Cabin' to the caller's data.
    data = full_1.assign(Has_Cabin=full_1['Cabin'].notna().astype(int))

    sns.barplot(x='Has_Cabin', y='Survived', data=data)
    plt.title('Survival Rate')
    plt.savefig(path + '/survival_rate.jpg')
    plt.show()

    cabin = pd.crosstab(data['Has_Cabin'], data['Survived']).rename(
        index={0: 'no cabin', 1: 'have cabin'},
        columns={0: 'Dead', 1: 'Survived'},
    )
    cabin.plot.bar(figsize=(8, 5))
    plt.xticks(rotation=0, size='large')
    plt.title('Survived Count')
    plt.xlabel('')
    plt.legend()
    plt.savefig(path + '/carbin.jpg')
    plt.show()
def age_scatter_plot(train_1, path):
    """Scatter-plot Survived against Age, then save and show the figure.

    The image is written to ``path + '/age_scatter_plot.jpg'``.

    Args:
        train_1: DataFrame with the columns Age and Survived.
        path: Directory prefix the image file name is appended to.
    """
    # Draw on the current pyplot axes (created on demand if there is none).
    axes = plt.gca()
    axes.scatter(train_1['Age'], train_1['Survived'])
    axes.set_title(u"年龄与获救情况散点图")
    axes.set_xlabel(u"年龄")
    axes.set_ylabel(u"获救情况 (1为获救)")

    target = path + '/age_scatter_plot.jpg'
    plt.savefig(target)
    plt.show()
def main():
    """Command-line entry point: load the data and produce every report.

    Reads the train and test CSV files named on the command line, makes sure
    ``<dest_path>/data_output`` exists, prints the textual overview of the
    training data and then draws, saves and shows each chart in turn.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dest_path", default='E:/python1/123', help="files save path")
    parser.add_argument('--train', default='train.csv', help="train data")
    parser.add_argument('--test', default='test.csv', help="test data")
    args = parser.parse_args()

    # No trailing separator: the plotting functions append '/<name>.jpg'
    # themselves, so a trailing '/' would yield paths containing '//'.
    output_path = os.path.join(args.dest_path, 'data_output')
    set_outputfolder(output_path)

    # Load the training and test sets; `full` is both stacked together.
    train = pd.read_csv(args.train)
    test = pd.read_csv(args.test)
    full = pd.concat([train, test], ignore_index=True)

    basic_show(train)
    basic_show1(train, output_path)
    age(train, output_path)
    p_class_sex(train, output_path)
    carbin(full, output_path)
    age_scatter_plot(train, output_path)


if __name__ == '__main__':
    main()