You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

2.7 KiB

构建模型进行预测

做好数据预处理后，就可以把数据喂给机器学习模型进行训练和预测了。不过在构建模型之前，需要先用与处理训练集完全相同的方式来处理测试集，以保证两者的特征一致。

# Preprocess the test set the same way the training set was preprocessed, so the
# model sees the same engineered features.
# NOTE(review): every mapping, fill value and threshold below must mirror the
# training-set preprocessing (not visible here); keep the two in sync.
test_data = pd.read_csv('./Titanic/test.csv')

# Extract the salutation (the run of letters before the first '.') from Name.
# One vectorized call is enough; the previous `for i in test_data:` loop
# recomputed this identical extraction once per column.
test_data['Initial'] = test_data['Name'].str.extract(r'([A-Za-z]+)\.', expand=False)

# Collapse rare titles into the groups Mr / Mrs / Miss / Master / Other.
# The result is assigned back: replace(inplace=True) on a .loc[...] selection is
# chained assignment and is not guaranteed to modify test_data (under pandas
# Copy-on-Write it never does).
test_data['Initial'] = test_data['Initial'].replace(
    ['Mlle', 'Mme', 'Ms', 'Dr', 'Major', 'Lady', 'Countess', 'Jonkheer', 'Col', 'Rev', 'Capt', 'Sir', 'Don'],
    ['Miss', 'Miss', 'Miss', 'Other', 'Mr', 'Mrs', 'Mrs', 'Other', 'Other', 'Other', 'Mr', 'Mr', 'Mr'],
)
# A title the mapping above does not cover would pass through the integer
# encoding further down as a raw string, which a numeric model cannot consume.
# Fold any such title into 'Other'.
test_data.loc[~test_data['Initial'].isin(['Mr', 'Mrs', 'Miss', 'Master', 'Other']), 'Initial'] = 'Other'

# Fill missing ages with a fixed value per title group. 'Master' has no entry,
# so those rows keep a NaN Age and end up in Age_band 0 below.
for initial, fill_age in {'Mr': 33, 'Mrs': 36, 'Miss': 22, 'Other': 46}.items():
    test_data.loc[test_data['Age'].isnull() & (test_data['Initial'] == initial), 'Age'] = fill_age

# Missing embarkation ports default to 'S' (assigned back, not chained inplace).
test_data['Embarked'] = test_data['Embarked'].fillna('S')

# Bucket Age into five bands. Rows whose Age is still NaN match no mask and
# keep the default band 0.
test_data['Age_band'] = 0
test_data.loc[test_data['Age'] <= 16, 'Age_band'] = 0
test_data.loc[(test_data['Age'] > 16) & (test_data['Age'] <= 32), 'Age_band'] = 1
test_data.loc[(test_data['Age'] > 32) & (test_data['Age'] <= 48), 'Age_band'] = 2
test_data.loc[(test_data['Age'] > 48) & (test_data['Age'] <= 64), 'Age_band'] = 3
test_data.loc[test_data['Age'] > 64, 'Age_band'] = 4

# Family size counts the passenger plus parents/children and siblings/spouses;
# Alone flags passengers travelling with nobody else.
test_data['Family_Size'] = test_data['Parch'] + test_data['SibSp'] + 1
test_data['Alone'] = 0
test_data.loc[test_data['Family_Size'] == 1, 'Alone'] = 1

# Bucket Fare into four categories. A NaN fare (or one above 513) matches no
# mask and keeps the default category 0.
test_data['Fare_cat'] = 0
test_data.loc[test_data['Fare'] <= 7.91, 'Fare_cat'] = 0
test_data.loc[(test_data['Fare'] > 7.91) & (test_data['Fare'] <= 14.454), 'Fare_cat'] = 1
test_data.loc[(test_data['Fare'] > 14.454) & (test_data['Fare'] <= 31), 'Fare_cat'] = 2
test_data.loc[(test_data['Fare'] > 31) & (test_data['Fare'] <= 513), 'Fare_cat'] = 3

# Encode categorical columns as integers. Values not listed are left unchanged.
# Each result is assigned back instead of using chained replace(inplace=True).
test_data['Sex'] = test_data['Sex'].replace(['male', 'female'], [0, 1])
test_data['Embarked'] = test_data['Embarked'].replace(['S', 'C', 'Q'], [0, 1, 2])
test_data['Initial'] = test_data['Initial'].replace(['Mr', 'Mrs', 'Miss', 'Master', 'Other'], [0, 1, 2, 3, 4])
# NOTE(review): replace() matches whole values only. If Cabin holds full cabin
# ids (e.g. 'C85') or NaN, those pass through unencoded. Confirm how the training
# set handled Cabin (deck letter only? dropped?) and mirror that here.
test_data['Cabin'] = test_data['Cabin'].replace(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'T'], [0, 1, 2, 3, 4, 5, 6, 7])

# Drop raw columns superseded by engineered features, plus the identifier.
# 'Fare_Range' is never created on test_data (only Fare_cat is), so a strict
# drop raised KeyError. errors='ignore' skips it when absent and still removes
# it if the CSV happens to contain it.
test_data = test_data.drop(['Name', 'Age', 'Ticket', 'Fare', 'Fare_Range', 'PassengerId'], axis=1, errors='ignore')

然后就可以使用机器学习模型进行训练和预测了，这里使用的是随机森林（Random Forest）分类器。

# Split labels from features, fit a random forest on the training set and report
# accuracy on the test set.
# `data` is presumably the preprocessed training DataFrame built earlier in the
# article (not shown here); it must contain a 'Survived' label column.
Y_train = data['Survived']
X_train = data.drop(['Survived'], axis=1)

# NOTE(review): Kaggle's official test.csv has no 'Survived' column. This
# assumes the test.csv used here carries ground-truth labels; confirm.
Y_test = test_data['Survived']
# The training and test frames are built by separate code paths, so reorder the
# test features to the training column order the fitted model expects. A column
# missing from test_data raises KeyError here; extra test-only columns are dropped.
X_test = test_data.drop(['Survived'], axis=1)[X_train.columns]

# Fix the seed so the reported accuracy is reproducible from run to run.
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(X_train, Y_train)
predict = clf.predict(X_test)
print(accuracy_score(Y_test, predict))

此时可以看到，预测的准确率达到了 0.8275。注意随机森林的训练带有随机性，若不固定 `random_state`，每次运行得到的准确率可能略有差异。