Delete 'knn_iris.py'

文档及讲解
Q8xg5nefi 4 years ago
parent 9ddf1302e5
commit 7836bff813

@ -1,48 +0,0 @@
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
def knn_iris():
"""
用KNN算法对鸢尾花进行预测
:return:
"""
# 1) 获取数据
iris = load_iris()
# 2划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)
# 3特征工程标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
# 4) KNN算法预估器
estimator = KNeighborsClassifier(n_neighbors=3)
estimator.fit(x_train, y_train)
# 5模型估计
# 方法1直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test ==y_predict)
#方法2计算准确值
score = estimator.score(x_test, y_test)
print("准确值为:\n", score)
# 6) 做出预测
X_new = np.array([[1.1, 5.9, 1.4, 2.2]])
prediction = estimator.predict(X_new)
print("预测目标类别是:{}".format(prediction))
print("预测目标花名是:", iris["target_names"][prediction])
return None
if __name__ =="__main__":
knn_iris()
Loading…
Cancel
Save