You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1.8 KiB

6.5 识别花朵

鸢尾花数据

鸢尾花数据集是一类多重变量分析的数据集,一共有150个样本,通过花萼长度花萼宽度花瓣长度花瓣宽度 4个特征预测鸢尾花卉属于(SetosaVersicolourVirginica)三个种类中的哪一类。

数据集中部分数据如下所示:

花萼长度 花萼宽度 花瓣长度 花瓣宽度
5.1 3.5 1.4 0.2
4.9 3.2 1.4 0.2
4.7 3.1 1.3 0.2

其中每一行代表一个鸢尾花样本各个属性的值。

数据集中部分标签如下图所示:

标签
0
1
2

标签中的值0,1,2分别代表鸢尾花三种不同的类别。

我们可以直接使用sklearn直接对数据进行加载,代码如下:

from sklearn.datasets import load_iris
#加载鸢尾花数据集
iris = load_iris()
#获取数据特征与标签
x,y = iris.data.astype(int),iris.target

然后我们划分出训练集与测试集,训练集用来训练模型,测试集用来检测模型性能。代码如下:

from sklearn.model_selection import train_test_split
#划分训练集测试集其中测试集样本数为整个数据集的20%
train_feature,test_feature,train_label,test_label = train_test_split(x,y,test_size=0.2,random_state=666)

进行分类

然后我们再使用实现的决策树分类方法就可以对测试集数据进行分类:

predict = dt_clf(train_feature,train_label,test_feature)
predict
>>>array([1, 2, 1, 2, 0, 1, 1, 2, 1, 1, 1, 0, 0, 0, 2, 1, 0, 2, 2, 2, 1, 0,2, 0, 1, 1, 0, 1, 2, 2])

再根据测试集标签,可以计算出正确率:

acc = np.mean(predict==test_label)
acc
>>>1.0

可以看到,使用决策树对鸢尾花进行分类,正确率可以达到100%