You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
763 lines
38 KiB
763 lines
38 KiB
6 years ago
|
|
||
|
<!DOCTYPE HTML>
|
||
|
<html lang="zh-CN">
|
||
|
<head>
|
||
|
<meta charset="UTF-8">
|
||
|
<meta content="text/html; charset=utf-8" http-equiv="Content-Type">
|
||
|
<title>使用sklearn进行机器学习 · GitBook</title>
|
||
|
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
|
||
|
<meta name="description" content="">
|
||
|
<meta name="generator" content="GitBook 3.2.3">
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
<link rel="stylesheet" href="gitbook/style.css">
|
||
|
|
||
|
|
||
|
|
||
|
|
||
6 years ago
|
<link rel="stylesheet" href="gitbook/gitbook-plugin-katex/katex.min.css">
|
||
|
|
||
|
|
||
|
|
||
6 years ago
|
<link rel="stylesheet" href="gitbook/gitbook-plugin-highlight/website.css">
|
||
|
|
||
|
|
||
|
|
||
|
<link rel="stylesheet" href="gitbook/gitbook-plugin-search/search.css">
|
||
|
|
||
|
|
||
|
|
||
|
<link rel="stylesheet" href="gitbook/gitbook-plugin-fontsettings/website.css">
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
<meta name="HandheldFriendly" content="true"/>
|
||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||
|
<meta name="apple-mobile-web-app-capable" content="yes">
|
||
|
<meta name="apple-mobile-web-app-status-bar-style" content="black">
|
||
|
<link rel="apple-touch-icon-precomposed" sizes="152x152" href="gitbook/images/apple-touch-icon-precomposed-152.png">
|
||
|
<link rel="shortcut icon" href="gitbook/images/favicon.ico" type="image/x-icon">
|
||
|
|
||
|
|
||
|
|
||
|
<link rel="prev" href="cluster_metrics.html" />
|
||
|
|
||
|
|
||
|
</head>
|
||
|
<body>
|
||
|
|
||
|
<div class="book">
|
||
|
<div class="book-summary">
|
||
|
|
||
|
|
||
|
<div id="book-search-input" role="search">
|
||
|
<input type="text" placeholder="Type to search" />
|
||
|
</div>
|
||
|
|
||
|
|
||
|
<nav role="navigation">
|
||
|
|
||
|
|
||
|
|
||
|
<ul class="summary">
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
<li class="chapter " data-level="1.1" data-path="./">
|
||
|
|
||
|
<a href="./">
|
||
|
|
||
|
|
||
|
简介
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.2" data-path="machine_learning.html">
|
||
|
|
||
|
<a href="machine_learning.html">
|
||
|
|
||
|
|
||
|
机器学习概述
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.3" data-path="algorithm.html">
|
||
|
|
||
|
<a href="algorithm.html">
|
||
|
|
||
|
|
||
|
常见机器学习算法
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
<ul class="articles">
|
||
|
|
||
|
|
||
|
<li class="chapter " data-level="1.3.1" data-path="kNN.html">
|
||
|
|
||
|
<a href="kNN.html">
|
||
|
|
||
|
|
||
|
近朱者赤近墨者黑-kNN
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.3.2" data-path="linear_regression.html">
|
||
|
|
||
|
<a href="linear_regression.html">
|
||
|
|
||
|
|
||
|
最简单的回归算法-线性回归
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.3.3" data-path="logistic_regression.html">
|
||
|
|
||
|
<a href="logistic_regression.html">
|
||
|
|
||
|
|
||
|
使用回归的思想进行分类-逻辑回归
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.3.4" data-path="decision_tree.html">
|
||
|
|
||
|
<a href="decision_tree.html">
|
||
|
|
||
|
|
||
|
最接近人类思维的算法-决策树
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.3.5" data-path="random_forest.html">
|
||
|
|
||
|
<a href="random_forest.html">
|
||
|
|
||
|
|
||
|
群众的力量是伟大的-随机森林
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.3.6" data-path="kMeans.html">
|
||
|
|
||
|
<a href="kMeans.html">
|
||
|
|
||
|
|
||
|
物以类聚人以群分-kMeans
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.3.7" data-path="AGNES.html">
|
||
|
|
||
|
<a href="AGNES.html">
|
||
|
|
||
|
|
||
|
以距离为尺-AGNES
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
|
||
|
</ul>
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.4" data-path="metrics.html">
|
||
|
|
||
|
<a href="metrics.html">
|
||
|
|
||
|
|
||
|
模型评估指标
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
<ul class="articles">
|
||
|
|
||
|
|
||
|
<li class="chapter " data-level="1.4.1" data-path="classification_metrics.html">
|
||
|
|
||
|
<a href="classification_metrics.html">
|
||
|
|
||
|
|
||
|
分类性能评估指标
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.4.2" data-path="regression_metrics.html">
|
||
|
|
||
|
<a href="regression_metrics.html">
|
||
|
|
||
|
|
||
|
回归性能评估指标
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.4.3" data-path="cluster_metrics.html">
|
||
|
|
||
|
<a href="cluster_metrics.html">
|
||
|
|
||
|
|
||
|
聚类性能评估指标
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
|
||
|
</ul>
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter active" data-level="1.5" data-path="sklearn.html">
|
||
|
|
||
|
<a href="sklearn.html">
|
||
|
|
||
|
|
||
|
使用sklearn进行机器学习
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
6 years ago
|
<li class="chapter " data-level="1.6" >
|
||
|
|
||
|
<span>
|
||
|
|
||
|
|
||
|
综合实战案例
|
||
|
|
||
|
</span>
|
||
|
|
||
|
|
||
|
|
||
|
<ul class="articles">
|
||
|
|
||
|
|
||
|
<li class="chapter " data-level="1.6.1" >
|
||
|
|
||
|
<span>
|
||
|
|
||
|
|
||
|
泰坦尼克生还预测
|
||
|
|
||
|
</span>
|
||
|
|
||
|
|
||
|
|
||
|
<ul class="articles">
|
||
|
|
||
|
|
||
|
<li class="chapter " data-level="1.6.1.1" data-path="titanic/introduction.html">
|
||
|
|
||
|
<a href="titanic/introduction.html">
|
||
|
|
||
|
|
||
|
简介
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.6.1.2" data-path="titanic/EDA.html">
|
||
|
|
||
|
<a href="titanic/EDA.html">
|
||
|
|
||
|
|
||
|
探索性数据分析(EDA)
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.6.1.3" data-path="titanic/feature engerning.html">
|
||
|
|
||
|
<a href="titanic/feature engerning.html">
|
||
|
|
||
|
|
||
|
特征工程
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.6.1.4" data-path="titanic/fit and predict.html">
|
||
|
|
||
|
<a href="titanic/fit and predict.html">
|
||
|
|
||
|
|
||
|
构建模型进行预测
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.6.1.5" data-path="titanic/tuning.html">
|
||
|
|
||
|
<a href="titanic/tuning.html">
|
||
|
|
||
|
|
||
|
调参
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
|
||
|
</ul>
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.6.2" >
|
||
|
|
||
|
<span>
|
||
|
|
||
|
|
||
|
使用强化学习玩乒乓球游戏
|
||
|
|
||
|
</span>
|
||
|
|
||
|
|
||
|
|
||
|
<ul class="articles">
|
||
|
|
||
|
|
||
|
<li class="chapter " data-level="1.6.2.1" data-path="pingpong/what is reinforce learning.html">
|
||
|
|
||
|
<a href="pingpong/what is reinforce learning.html">
|
||
|
|
||
|
|
||
|
什么是强化学习
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.6.2.2" data-path="pingpong/Policy Gradient.html">
|
||
|
|
||
|
<a href="pingpong/Policy Gradient.html">
|
||
|
|
||
|
|
||
|
Policy Gradient原理
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.6.2.3" data-path="pingpong/coding.html">
|
||
|
|
||
|
<a href="pingpong/coding.html">
|
||
|
|
||
|
|
||
|
使用Policy Gradient玩乒乓球游戏
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
|
|
||
|
</ul>
|
||
|
|
||
|
</li>
|
||
|
|
||
|
|
||
|
</ul>
|
||
|
|
||
|
</li>
|
||
|
|
||
|
<li class="chapter " data-level="1.7" data-path="recommand.html">
|
||
|
|
||
|
<a href="recommand.html">
|
||
|
|
||
|
|
||
|
实训推荐
|
||
|
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
</li>
|
||
|
|
||
6 years ago
|
|
||
|
|
||
|
|
||
|
<li class="divider"></li>
|
||
|
|
||
|
<li>
|
||
|
<a href="https://www.gitbook.com" target="_blank" rel="noopener noreferrer" class="gitbook-link">
|
||
|
Published with GitBook
|
||
|
</a>
|
||
|
</li>
|
||
|
</ul>
|
||
|
|
||
|
|
||
|
</nav>
|
||
|
|
||
|
|
||
|
</div>
|
||
|
|
||
|
<div class="book-body">
|
||
|
|
||
|
<div class="body-inner">
|
||
|
|
||
|
|
||
|
|
||
|
<div class="book-header" role="navigation">
|
||
|
|
||
|
|
||
|
<!-- Title -->
|
||
|
<h1>
|
||
|
<i class="fa fa-circle-o-notch fa-spin"></i>
|
||
|
<a href="." >使用sklearn进行机器学习</a>
|
||
|
</h1>
|
||
|
</div>
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
<div class="page-wrapper" tabindex="-1" role="main">
|
||
|
<div class="page-inner">
|
||
|
|
||
|
<div id="book-search-results">
|
||
|
<div class="search-noresults">
|
||
|
|
||
|
<section class="normal markdown-section">
|
||
|
|
||
|
<h1 id="使用sklearn进行机器学习">使用sklearn进行机器学习</h1>
|
||
|
<h2 id="写在前面">写在前面</h2>
|
||
|
<p>这是一个 sklearn 的 hello world 级教程,想要更加系统更加全面的学习 sklearn 建议查阅 sklearn 的<a href="https://scikit-learn.org/stable/" target="_blank" rel="noopener noreferrer">官方网站</a>。</p>
|
||
|
<h2 id="sklearn简介">sklearn简介</h2>
|
||
|
<p>scikit-learn(简记sklearn),是用 python 实现的机器学习算法库。sklearn 可以实现数据预处理、分类、回归、降维、模型选择等常用的机器学习算法。基本上只需要知道一些 python 的基础语法知识就能学会怎样使用 sklearn 了,所以 sklearn 是一款非常好用的 python 机器学习库。</p>
|
||
|
<h2 id="sklearn的安装">sklearn的安装</h2>
|
||
|
<p>和安装其他第三方库一样简单,只需要在命令行中输入 <code>pip install scikit-learn</code> 即可。</p>
|
||
|
<h2 id="sklearn的目录结构">sklearn的目录结构</h2>
|
||
|
<p>sklearn 提供的接口都封装在不同的目录下的不同的 py 文件中,所以对 sklearn 的目录结构有一个大致的了解,有助于我们更加深刻地理解 sklearn 。目录结构如下:</p>
|
||
|
<p><img src="img/29.jpg" alt="sklearn 的目录结构"></p>
|
||
|
<p>其实从目录名字就可以看出目录中的 py 文件是做什么的。比如 cluster 目录下都是聚类算法的接口,ensemble 目录下都是集成学习算法的接口。</p>
|
||
|
<h2 id="使用sklearn识别手写数字">使用sklearn识别手写数字</h2>
|
||
|
<p>接下来不如通过一个实例来感受一下 sklearn 的强大。</p>
|
||
|
<p>想要识别手写数字,首先需要有数据。sklearn 中已经为我们准备好了一些比较经典且质量较高的数据集,其中就包括手写数字数据集。该数据集有 1797 个样本,每个样本包括一张 8×8 像素的图像(实际上一条样本有 64 个特征,每个像素看成是一个特征,每个特征都是 float 类型的数值)和一个取值为 [0, 9] 的整数标签。比如下图的标签是 2 :</p>
|
||
|
<p><img src="img/31.jpg" alt="标签为 2 的 8×8 手写数字图像"></p>
|
||
|
<p>想要使用这个数据很简单,代码如下:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn <span class="hljs-keyword">import</span> datasets
|
||
|
|
||
|
<span class="hljs-comment"># 加载手写数字数据集</span>
|
||
|
digits = datasets.load_digits()
|
||
|
|
||
|
<span class="hljs-comment"># X表示特征,即1797行64列的矩阵</span>
|
||
|
X = digits.data
|
||
|
<span class="hljs-comment"># y表示标签,即1797个元素的一维数组</span>
|
||
|
y = digits.target
|
||
|
</code></pre>
|
||
|
<p>得到 X,y 数据之后,我们还需要将这些数据进行划分,划分成两个部分,一部分是训练集,另一部分是测试集。因为如果没有测试集的话,我们并不知道我们的手写数字识别程序识别得准不准。数据集划分代码如下:</p>
|
||
|
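<pre><code class="lang-python"><span class="hljs-comment"># 导入划分训练集与测试集的接口</span>
<span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> train_test_split

<span class="hljs-comment"># 将X,y划分成训练集和测试集,其中训练集的比例为80%,测试集的比例为20%</span>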
|
||
|
<span class="hljs-comment"># X_train表示训练集的特征,X_test表示测试集的特征,y_train表示训练集的标签,y_test表示测试集的标签</span>
|
||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=<span class="hljs-number">0.2</span>)
|
||
|
</code></pre>
|
||
|
<p>接下来,可以使用机器学习算法来实现手写数字识别了,例如想要使用随机森林来进行识别,那么首先要导入随机森林算法接口。</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-comment"># 由于是分类问题,所以导入的是RandomForestClassifier</span>
|
||
|
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier
|
||
|
</code></pre>
|
||
|
<p>导入好接口后,就可以创建随机森林对象了。随机森林对象有用来训练的函数 <code>fit</code> 和用来预测的函数 <code>predict</code>。<code>fit</code>函数需要训练集的特征和训练集的标签作为输入,<code>predict</code>函数需要测试集的特征作为输入。所以代码如下:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-comment"># 创建一个有50棵决策树的随机森林, n_estimators表示决策树的数量</span>
|
||
|
clf = RandomForestClassifier(n_estimators=<span class="hljs-number">50</span>)
|
||
|
<span class="hljs-comment"># 用训练集训练</span>
|
||
|
clf.fit(X_train, y_train)
|
||
|
<span class="hljs-comment"># 用测试集测试,result为预测结果</span>
|
||
|
result = clf.predict(X_test)
|
||
|
</code></pre>
|
||
|
<p>得到预测结果后,我们需要将其与测试集的真实答案进行比对,计算出预测的准确率。sklearn 已经为我们提供了计算准确率的接口,使用代码如下:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-comment"># 导入计算准确率的接口</span>
|
||
|
<span class="hljs-keyword">from</span> sklearn.metrics <span class="hljs-keyword">import</span> accuracy_score
|
||
|
|
||
|
<span class="hljs-comment"># 计算预测准确率</span>
|
||
|
acc = accuracy_score(y_test, result)
|
||
|
<span class="hljs-comment"># 打印准确率</span>
|
||
|
print(acc)
|
||
|
</code></pre>
|
||
|
<p>此时您会发现,我们用短短几行代码实现的手写数字识别程序,准确率就已经高于 <strong>0.95</strong>。</p>
|
||
|
<p>而且我们不仅可以使用随机森林来实现手写数字识别,我们还可以使用别的机器学习算法实现,比如逻辑回归,代码如下:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn.linear_model <span class="hljs-keyword">import</span> LogisticRegression
|
||
|
|
||
|
<span class="hljs-comment"># 创建一个逻辑回归对象</span>
|
||
|
clf = LogisticRegression()
|
||
|
<span class="hljs-comment"># 用训练集训练</span>
|
||
|
clf.fit(X_train, y_train)
|
||
|
<span class="hljs-comment"># 用测试集测试,result为预测结果</span>
|
||
|
result = clf.predict(X_test)
|
||
|
</code></pre>
|
||
|
<p>细心的您可能已经发现,不管使用哪种分类算法来进行手写数字识别,不同的只是创建的算法对象不一样而已。有了算法对象后,就可以<code>fit</code>,<code>predict</code>大法了。</p>
|
||
|
<p>下面是使用随机森林识别手写数字的完整代码:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn <span class="hljs-keyword">import</span> datasets
|
||
|
<span class="hljs-comment"># 由于是分类问题,所以导入的是RandomForestClassifier</span>
|
||
|
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier
|
||
|
<span class="hljs-comment"># 导入计算准确率的接口</span>
|
||
|
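<span class="hljs-keyword">from</span> sklearn.metrics <span class="hljs-keyword">import</span> accuracy_score
<span class="hljs-comment"># 导入划分训练集与测试集的接口</span>
<span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> train_test_split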
|
||
|
|
||
|
<span class="hljs-comment"># 加载手写数字数据集</span>
|
||
|
digits = datasets.load_digits()
|
||
|
|
||
|
<span class="hljs-comment"># X表示特征,即1797行64列的矩阵</span>
|
||
|
X = digits.data
|
||
|
<span class="hljs-comment"># y表示标签,即1797个元素的一维数组</span>
|
||
|
y = digits.target
|
||
|
|
||
|
<span class="hljs-comment"># 将X,y划分成训练集和测试集,其中训练集的比例为80%,测试集的比例为20%</span>
|
||
|
<span class="hljs-comment"># X_train表示训练集的特征,X_test表示测试集的特征,y_train表示训练集的标签,y_test表示测试集的标签</span>
|
||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=<span class="hljs-number">0.2</span>)
|
||
|
|
||
|
<span class="hljs-comment"># 创建一个有50棵决策树的随机森林, n_estimators表示决策树的数量</span>
|
||
|
clf = RandomForestClassifier(n_estimators=<span class="hljs-number">50</span>)
|
||
|
<span class="hljs-comment"># 用训练集训练</span>
|
||
|
clf.fit(X_train, y_train)
|
||
|
<span class="hljs-comment"># 用测试集测试,result为预测结果</span>
|
||
|
result = clf.predict(X_test)
|
||
|
|
||
|
<span class="hljs-comment"># 计算预测准确率</span>
|
||
|
acc = accuracy_score(y_test, result)
|
||
|
<span class="hljs-comment"># 打印准确率</span>
|
||
|
print(acc)
|
||
|
</code></pre>
|
||
|
<h2 id="更好地验证算法性能">更好地验证算法性能</h2>
|
||
|
<p>在划分训练集与测试集时可能会出现这样的情况:模型对数字 1 的识别准确率比较低,而测试集中恰好没有多少数字为 1 的样本,于是用测试集测试后得到的准确率高达 0.96。这时您可能会觉得:哎呀,我的模型已经很厉害了!但其实并非如此,因为这样的测试集让您对模型的性能产生了误解。那有没有更加公正的验证算法性能的方法呢?有,那就是<strong>k-折验证</strong>!</p>
|
||
|
<p><strong>k-折验证</strong>的大体思路是将整个数据集分成 k 份,然后让每一份子集都轮流作为一次测试集,共循环 k 次,最后计算这 k 次模型性能的平均值作为性能的估计。一般来说 k 的值为 5 或者 10。</p>
|
||
|
<p><strong>k-折验证</strong>的流程如下:</p>
|
||
|
<ol>
|
||
|
<li>不重复抽样将整个数据集随机拆分成 k 份</li>
|
||
|
<li>每一次挑选其中 1 份作为测试集,剩下的 k-1 份作为训练集:
<ol>
<li>在每个训练集上训练后得到一个模型</li>
<li>用这个模型在相应的测试集上测试,计算并保存模型的评估指标</li>
</ol>
</li>
|
||
|
<li>重复第 2 步 k 次,这样每份都有一次机会作为测试集,其他机会作为训练集</li>
|
||
|
<li>计算 k 组测试结果的平均值作为算法性能的估计。</li>
|
||
|
</ol>
|
||
|
<p>sklearn 为我们提供了将数据集划分成 k 份的类 KFold,使用示例如下:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-comment"># 导入KFold</span>
|
||
|
<span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> KFold
|
||
|
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier
|
||
|
<span class="hljs-keyword">from</span> sklearn.metrics <span class="hljs-keyword">import</span> accuracy_score
|
||
|
|
||
|
<span class="hljs-comment"># 创建一个将数据集随机划分成5份的KFold对象</span>
|
||
|
kf = KFold(n_splits = <span class="hljs-number">5</span>, shuffle = <span class="hljs-literal">True</span>)
|
||
|
|
||
|
mean_acc = <span class="hljs-number">0</span>
|
||
|
|
||
|
<span class="hljs-comment"># 将整个数据集划分成5份</span>
|
||
|
<span class="hljs-comment"># train_index表示从5份中挑出来4份所拼出来的训练集的索引</span>
|
||
|
<span class="hljs-comment"># test_index表示剩下的一份作为测试集的索引</span>
|
||
|
<span class="hljs-keyword">for</span> train_index, test_index <span class="hljs-keyword">in</span> kf.split(X):
|
||
|
X_train, y_train = X[train_index], y[train_index]
|
||
|
X_test, y_test = X[test_index], y[test_index]
|
||
|
rf = RandomForestClassifier()
|
||
|
rf.fit(X_train, y_train)
|
||
|
result = rf.predict(X_test)
|
||
|
mean_acc += accuracy_score(y_test, result)
|
||
|
|
||
|
<span class="hljs-comment"># 打印5折验证的平均准确率</span>
|
||
|
print(mean_acc/<span class="hljs-number">5</span>)
|
||
|
</code></pre>
|
||
|
<p>完整代码如下:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn <span class="hljs-keyword">import</span> datasets
|
||
|
<span class="hljs-comment"># 由于是分类问题,所以导入的是RandomForestClassifier</span>
|
||
|
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier
|
||
|
<span class="hljs-comment"># 导入计算准确率的接口</span>
|
||
|
<span class="hljs-keyword">from</span> sklearn.metrics <span class="hljs-keyword">import</span> accuracy_score
|
||
|
<span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> KFold
|
||
|
|
||
|
<span class="hljs-comment"># 加载手写数字数据集</span>
|
||
|
digits = datasets.load_digits()
|
||
|
|
||
|
<span class="hljs-comment"># X表示特征,即1797行64列的矩阵</span>
|
||
|
X = digits.data
|
||
|
<span class="hljs-comment"># y表示标签,即1797个元素的一维数组</span>
|
||
|
y = digits.target
|
||
|
|
||
|
<span class="hljs-comment"># 创建一个将数据集随机划分成5份的KFold对象</span>
|
||
|
kf = KFold(n_splits = <span class="hljs-number">5</span>, shuffle = <span class="hljs-literal">True</span>)
|
||
|
|
||
|
mean_acc = <span class="hljs-number">0</span>
|
||
|
|
||
|
<span class="hljs-comment"># 将整个数据集划分成5份</span>
|
||
|
<span class="hljs-comment"># train_index表示从5份中挑出来4份所拼出来的训练集的索引</span>
|
||
|
<span class="hljs-comment"># test_index表示剩下的一份作为测试集的索引</span>
|
||
|
<span class="hljs-keyword">for</span> train_index, test_index <span class="hljs-keyword">in</span> kf.split(X):
|
||
|
X_train, y_train = X[train_index], y[train_index]
|
||
|
X_test, y_test = X[test_index], y[test_index]
|
||
|
rf = RandomForestClassifier()
|
||
|
rf.fit(X_train, y_train)
|
||
|
result = rf.predict(X_test)
|
||
|
mean_acc += accuracy_score(y_test, result)
|
||
|
|
||
|
<span class="hljs-comment"># 打印5折验证的平均准确率</span>
|
||
|
print(mean_acc/<span class="hljs-number">5</span>)
|
||
|
</code></pre>
|
||
|
|
||
|
|
||
|
</section>
|
||
|
|
||
|
</div>
|
||
|
<div class="search-results">
|
||
|
<div class="has-results">
|
||
|
|
||
|
<h1 class="search-results-title"><span class='search-results-count'></span> results matching "<span class='search-query'></span>"</h1>
|
||
|
<ul class="search-results-list"></ul>
|
||
|
|
||
|
</div>
|
||
|
<div class="no-results">
|
||
|
|
||
|
<h1 class="search-results-title">No results matching "<span class='search-query'></span>"</h1>
|
||
|
|
||
|
</div>
|
||
|
</div>
|
||
|
</div>
|
||
|
|
||
|
</div>
|
||
|
</div>
|
||
|
|
||
|
</div>
|
||
|
|
||
|
|
||
|
|
||
|
<a href="cluster_metrics.html" class="navigation navigation-prev navigation-unique" aria-label="Previous page: 聚类性能评估指标">
|
||
|
<i class="fa fa-angle-left"></i>
|
||
|
</a>
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
</div>
|
||
|
|
||
|
<script>
|
||
|
var gitbook = gitbook || [];
|
||
|
gitbook.push(function() {
|
||
6 years ago
|
gitbook.page.hasChanged({"page":{"title":"使用sklearn进行机器学习","level":"1.5","depth":1,"next":{"title":"综合实战案例","level":"1.6","depth":1,"ref":"","articles":[{"title":"泰坦尼克生还预测","level":"1.6.1","depth":2,"ref":"","articles":[{"title":"简介","level":"1.6.1.1","depth":3,"path":"titanic/introduction.md","ref":"./titanic/introduction.md","articles":[]},{"title":"探索性数据分析(EDA)","level":"1.6.1.2","depth":3,"path":"titanic/EDA.md","ref":"./titanic/EDA.md","articles":[]},{"title":"特征工程","level":"1.6.1.3","depth":3,"path":"titanic/feature engerning.md","ref":"./titanic/feature engerning.md","articles":[]},{"title":"构建模型进行预测","level":"1.6.1.4","depth":3,"path":"titanic/fit and predict.md","ref":"./titanic/fit and predict.md","articles":[]},{"title":"调参","level":"1.6.1.5","depth":3,"path":"titanic/tuning.md","ref":"./titanic/tuning.md","articles":[]}]},{"title":"使用强化学习玩乒乓球游戏","level":"1.6.2","depth":2,"ref":"","articles":[{"title":"什么是强化学习","level":"1.6.2.1","depth":3,"path":"pingpong/what is reinforce learning.md","ref":"./pingpong/what is reinforce learning.md","articles":[]},{"title":"Policy Gradient原理","level":"1.6.2.2","depth":3,"path":"pingpong/Policy Gradient.md","ref":"./pingpong/Policy Gradient.md","articles":[]},{"title":"使用Policy 
Gradient玩乒乓球游戏","level":"1.6.2.3","depth":3,"path":"pingpong/coding.md","ref":"./pingpong/coding.md","articles":[]}]}]},"previous":{"title":"聚类性能评估指标","level":"1.4.3","depth":2,"path":"cluster_metrics.md","ref":"cluster_metrics.md","articles":[]},"dir":"ltr"},"config":{"gitbook":"*","theme":"default","variables":{},"plugins":["katex"],"pluginsConfig":{"katex":{},"highlight":{},"search":{},"lunr":{"maxIndexSize":1000000,"ignoreSpecialCharacters":false},"sharing":{"facebook":true,"twitter":true,"google":false,"weibo":false,"instapaper":false,"vk":false,"all":["facebook","google","twitter","weibo","instapaper"]},"fontsettings":{"theme":"white","family":"sans","size":2},"theme-default":{"styles":{"website":"styles/website.css","pdf":"styles/pdf.css","epub":"styles/epub.css","mobi":"styles/mobi.css","ebook":"styles/ebook.css","print":"styles/print.css"},"showLevel":false}},"structure":{"langs":"LANGS.md","readme":"README.md","glossary":"GLOSSARY.md","summary":"SUMMARY.md"},"pdf":{"pageNumbers":true,"fontSize":12,"fontFamily":"Arial","paperSize":"a4","chapterMark":"pagebreak","pageBreaksBefore":"/","margin":{"right":62,"left":62,"top":56,"bottom":56}},"styles":{"website":"styles/website.css","pdf":"styles/pdf.css","epub":"styles/epub.css","mobi":"styles/mobi.css","ebook":"styles/ebook.css","print":"styles/print.css"}},"file":{"path":"sklearn.md","mtime":"2019-07-04T06:39:44.002Z","type":"markdown"},"gitbook":{"version":"3.2.3","time":"2019-07-06T07:31:21.537Z"},"basePath":".","book":{"language":""}});
|
||
6 years ago
|
});
|
||
|
</script>
|
||
|
</div>
|
||
|
|
||
|
|
||
|
<script src="gitbook/gitbook.js"></script>
|
||
|
<script src="gitbook/theme.js"></script>
|
||
|
|
||
|
|
||
|
<script src="gitbook/gitbook-plugin-search/search-engine.js"></script>
|
||
|
|
||
|
|
||
|
|
||
|
<script src="gitbook/gitbook-plugin-search/search.js"></script>
|
||
|
|
||
|
|
||
|
|
||
|
<script src="gitbook/gitbook-plugin-lunr/lunr.min.js"></script>
|
||
|
|
||
|
|
||
|
|
||
|
<script src="gitbook/gitbook-plugin-lunr/search-lunr.js"></script>
|
||
|
|
||
|
|
||
|
|
||
|
<script src="gitbook/gitbook-plugin-sharing/buttons.js"></script>
|
||
|
|
||
|
|
||
|
|
||
|
<script src="gitbook/gitbook-plugin-fontsettings/fontsettings.js"></script>
|
||
|
|
||
|
|
||
|
|
||
|
</body>
|
||
|
</html>
|
||
|
|