ADD file via upload

main
p7jo5irft 10 months ago
parent dbc3c9514d
commit 528141d9c3

@ -0,0 +1,153 @@
# -*- coding:utf-8 -*-
import argparse
import os.path
import matplotlib
matplotlib.use('TkAgg') # 例如使用TkAgg后端
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
def set_outputfolder(output_path):
    """Create the output directory, including missing parents, if needed.

    Args:
        output_path: Path of the directory to create.

    Raises:
        OSError: If the directory cannot be created, for example because a
            regular file already occupies ``output_path``.
    """
    # exist_ok=True removes the check-then-create race that a separate
    # os.path.exists() test has when another process creates the directory.
    os.makedirs(output_path, exist_ok=True)
def basic_show(train_1):
    """Print a textual overview of a DataFrame to standard output.

    The overview consists of four sections separated by dashed rules:
    summary statistics, column dtypes / non-null counts, the first rows,
    and the number of missing values per column.

    Args:
        train_1: The pandas DataFrame to describe.
    """
    separator = '----------------------------------------------------'

    print(separator)
    print("数据分布情况:")
    print(train_1.describe())

    print(separator)
    print("数据类型观察:")
    # DataFrame.info() writes its report itself and returns None, so it must
    # not be wrapped in print() (that would emit a stray "None" line).
    train_1.info()

    print(separator)
    print("数据内容:")
    print(train_1.head())

    print(separator)
    print("原始数据缺失情况:")
    print(train_1.isnull().sum())
def basic_show1(df, path):
    """Draw, save and display one survival chart per passenger attribute.

    Each chart is drawn with seaborn, written to ``path`` as a JPEG and then
    shown with ``plt.show()``.

    Args:
        df: DataFrame holding the ``Survived`` column and the ``Sex``,
            ``Pclass``, ``SibSp``, ``Parch`` and ``Embarked`` columns.
        path: Directory prefix the image file names are appended to.
    """
    # (seaborn plotting function, keyword arguments, file name suffix),
    # listed in the order the charts are produced.
    charts = (
        # Effect of sex on survival.
        (sns.barplot, {"x": "Sex", "y": "Survived"}, '/sex.jpg'),
        # Effect of ticket class on survival.
        (sns.barplot, {"x": "Pclass", "y": "Survived"}, '/Pclass.jpg'),
        # Survival versus number of siblings / spouses aboard.
        (sns.barplot, {"x": "SibSp", "y": "Survived"}, '/SibSp.jpg'),
        # Survival versus number of parents / children aboard.
        (sns.barplot, {"x": "Parch", "y": "Survived"}, '/Parch.jpg'),
        # Survival counts per port of embarkation.
        (
            sns.countplot,
            {"x": 'Embarked', "hue": 'Survived', "fill": True},
            '/Embarked.jpg',
        ),
    )

    for draw, options, suffix in charts:
        draw(data=df, **options)
        plt.savefig(path + suffix)
        plt.show()
def p_class_sex(train_1, path):
    """Plot survival counts split by sex and by high/low ticket class.

    Four bar charts are drawn side by side on one figure (female / high
    class, female / low class, male / high class, male / low class), where
    "low class" means ``Pclass == 3``. Every panel shows the bars in the
    fixed order not-survived (0), survived (1), so the tick labels always
    match the data regardless of which outcome is more frequent. The figure
    is saved as ``p_class_set.jpg`` under ``path`` and then displayed.

    Args:
        train_1: DataFrame with ``Survived``, ``Sex`` and ``Pclass`` columns.
        path: Directory prefix the image file name is appended to.
    """
    # Font able to render the Chinese labels; keep the minus sign readable.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    pd.set_option('max_colwidth', 100)
    pd.set_option('display.width', 1000)

    # A single figure: creating a second one would leave an empty window.
    fig = plt.figure()
    fig.set(alpha=0.65)
    # suptitle titles the whole figure; plt.title here would first create an
    # implicit full-size axes hidden underneath the four subplots.
    fig.suptitle(u"根据舱等级和性别的获救情况")

    is_female = train_1['Sex'] == 'female'
    is_male = train_1['Sex'] == 'male'
    is_high_class = train_1['Pclass'] != 3
    is_low_class = train_1['Pclass'] == 3

    # (row mask, legend text, bar colour) for each panel, left to right.
    panels = (
        (is_female & is_high_class, u"女性/高级舱", '#FA2479'),
        (is_female & is_low_class, u"女性/低级舱", 'pink'),
        (is_male & is_high_class, u"男性/高级舱", 'lightblue'),
        (is_male & is_low_class, u"男性/低级舱", 'steelblue'),
    )

    first_ax = None
    for position, (mask, legend_text, colour) in enumerate(panels, start=1):
        # All panels share the y axis of the first one for comparability.
        ax = fig.add_subplot(1, 4, position, sharey=first_ax)
        if first_ax is None:
            first_ax = ax

        # One combined mask instead of chained boolean indexing, and a fixed
        # 0/1 order instead of value_counts' frequency order (which differs
        # between groups and is arbitrary on ties).
        counts = (
            train_1.loc[mask, 'Survived']
            .value_counts()
            .reindex([0, 1], fill_value=0)
        )
        counts.plot(kind='bar', ax=ax, color=colour)
        ax.set_xticklabels([u"未获救", u"获救"], rotation=0)
        ax.legend([legend_text], loc='best')

    fig.savefig(path + '/p_class_set.jpg')
    plt.show()
def age(train_1, path):
    """Plot the age density of passengers for each ticket class.

    One kernel-density curve per ``Pclass`` value (1, 2, 3) is drawn on a
    shared axes. The figure is saved as ``age_distribution.jpg`` under
    ``path`` and then displayed.

    Args:
        train_1: DataFrame with ``Pclass`` and ``Age`` columns.
        path: Directory prefix the image file name is appended to.
    """
    # Font able to render the Chinese labels; keep the minus sign readable.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    pd.set_option('max_colwidth', 100)
    pd.set_option('display.width', 1000)

    figure = plt.figure(figsize=(8, 5))
    figure.suptitle(u"各等级的乘客年龄分布")
    axes = figure.add_subplot(111)

    for pclass in (1, 2, 3):
        ages_in_class = train_1.loc[train_1['Pclass'] == pclass, 'Age']
        sns.kdeplot(ages_in_class, ax=axes, label=f'Pclass {pclass}')

    axes.set_xlabel(u"年龄")
    axes.set_ylabel(u"密度")
    axes.legend(loc='best')
    figure.savefig(path + '/age_distribution.jpg')
    plt.show()
def carbin(full_1, path):
    """Plot survival against whether a cabin value is recorded.

    Two charts are produced, each saved under ``path`` and then displayed:
    a bar plot of ``Survived`` by the cabin indicator
    (``survival_rate.jpg``) and grouped bars of the cross-tabulated counts
    (``carbin.jpg``).

    Note:
        A ``Has_Cabin`` column (1 when ``Cabin`` is present, else 0) is
        added to ``full_1`` in place.

    Args:
        full_1: DataFrame with ``Cabin`` and ``Survived`` columns.
        path: Directory prefix the image file names are appended to.
    """
    def cabin_flag(value):
        # 0 for a missing cabin entry, 1 otherwise.
        if pd.isna(value):
            return 0
        return 1

    full_1['Has_Cabin'] = full_1['Cabin'].apply(cabin_flag)

    # Chart 1: bar plot of Survived for each indicator value.
    sns.barplot(data=full_1, x='Has_Cabin', y='Survived')
    plt.title('Survival Rate')
    plt.savefig(path + '/survival_rate.jpg')
    plt.show()

    # Chart 2: counts per (indicator, outcome) pair.
    index_labels = {0: 'no cabin', 1: 'have cabin'}
    column_labels = {0: 'Dead', 1: 'Survived'}
    counts = pd.crosstab(full_1['Has_Cabin'], full_1['Survived'])
    counts.rename(index=index_labels, columns=column_labels, inplace=True)
    counts.plot(kind='bar', figsize=(8, 5))
    plt.xticks(rotation=0, size='large')
    plt.title('Survived Count')
    plt.xlabel('')
    plt.legend()
    plt.savefig(path + '/carbin.jpg')
    plt.show()
def age_scatter_plot(train_1, path):
    """Scatter-plot survival outcome against passenger age.

    The figure is saved as ``age_scatter_plot.jpg`` under ``path`` and then
    displayed.

    Args:
        train_1: DataFrame with ``Age`` and ``Survived`` columns.
        path: Directory prefix the image file name is appended to.
    """
    # The labels below are Chinese, so configure a CJK-capable font here
    # (as the other Chinese-labelled plots in this file do) instead of
    # relying on another function having been called first.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    plt.scatter(train_1['Age'], train_1['Survived'])
    plt.ylabel(u"获救情况 (1为获救)")
    plt.xlabel(u"年龄")
    plt.title(u"年龄与获救情况散点图")
    plt.savefig(path + '/age_scatter_plot.jpg')
    plt.show()
if __name__ == '__main__':
    # Command-line options: output location and the two input CSV files.
    cli = argparse.ArgumentParser()
    cli.add_argument("--dest_path", default='E:/python1/123', help="files save path")
    cli.add_argument('--train', default='train.csv', help="train data")
    cli.add_argument('--test', default='test.csv', help="test data")
    options = cli.parse_args()

    # All figures are written below <dest_path>/data_output/.
    figure_dir = options.dest_path + '/data_output/'
    set_outputfolder(figure_dir)

    # Load the training and test sets, plus their row-wise concatenation.
    train_df = pd.read_csv(options.train)
    test_df = pd.read_csv(options.test)
    full_df = pd.concat([train_df, test_df], ignore_index=True)

    # Text summary first, then every chart in its original order.
    basic_show(train_df)
    plot_steps = (
        (basic_show1, train_df),
        (age, train_df),
        (p_class_sex, train_df),
        (carbin, full_df),
        (age_scatter_plot, train_df),
    )
    for make_plot, frame in plot_steps:
        make_plot(frame, figure_dir)
Loading…
Cancel
Save