diff --git a/knn_iris.py b/knn_iris.py deleted file mode 100644 index db25417..0000000 --- a/knn_iris.py +++ /dev/null @@ -1,48 +0,0 @@ -from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler -from sklearn.neighbors import KNeighborsClassifier -import numpy as np - -def knn_iris(): - """ - 用KNN算法对鸢尾花进行预测 - :return: - """ - # 1) 获取数据 - iris = load_iris() - - # 2)划分数据集 - x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22) - - # 3)特征工程:标准化 - transfer = StandardScaler() - x_train = transfer.fit_transform(x_train) - x_test = transfer.transform(x_test) - - # 4) KNN算法预估器 - estimator = KNeighborsClassifier(n_neighbors=3) - estimator.fit(x_train, y_train) - - # 5)模型估计 - # 方法1:直接比对真实值和预测值 - y_predict = estimator.predict(x_test) - print("y_predict:\n", y_predict) - print("直接比对真实值和预测值:\n", y_test ==y_predict) - - #方法2:计算准确值 - score = estimator.score(x_test, y_test) - print("准确值为:\n", score) - - # 6) 做出预测 - X_new = np.array([[1.1, 5.9, 1.4, 2.2]]) - prediction = estimator.predict(X_new) - print("预测目标类别是:{}".format(prediction)) - print("预测目标花名是:", iris["target_names"][prediction]) - - return None - -if __name__ =="__main__": - knn_iris() - -