diff --git a/Three_iris.py b/Three_iris.py new file mode 100644 index 0000000..a49d406 --- /dev/null +++ b/Three_iris.py @@ -0,0 +1,152 @@ +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.naive_bayes import GaussianNB +from sklearn.tree import DecisionTreeClassifier +from sklearn.preprocessing import StandardScaler +from sklearn.neighbors import KNeighborsClassifier +import numpy as np + + +def bayes_iris(x): + """ + 朴素贝叶斯对鸢尾花种类进行预测 + :param x: 预测数据 + :return: + """ + + #1)获取数据集 + iris = load_iris() + # print("查看特征值的名字:\n", iris.feature_names) + # print("查看特征值:\n", iris.data, iris.data.shape) + + #2)划分数据集 + x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=22) + + #3)建立模型 + estimator= GaussianNB(); + + #4)模型训练 + estimator.fit(x_train,y_train) + + #5)模型评估 + #对比真实值: + y_predict = estimator.predict(x_test) + # print("y_predict:\n",y_predict) + # print("直接对比真实值:\n",y_predict==y_test) + + #计算准确率 + score = estimator.score(x_test,y_test) + print("贝叶斯模型准确率为:\n",score) + + #7)做出预测 + #[0]表示setosa,[1]表示versicolor,[2]表示virginica + X_new = np.array(x) + prediction = estimator.predict(X_new) + print("贝叶斯模型预测的目标类别是:{}".format(prediction)) + print("贝叶斯模型预测的目标类别花名是:{}".format(iris['target_names'][prediction])) + print() + + return None + +def decisiontree_iris(x): + """ + 决策树对鸢尾花种类进行预测 + :param x: 预测数据 + :return: + """ + + # 1) 获取数据集 + iris = load_iris() + + # 2)划分数据集 + x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22,test_size=0.3) + + # 3) 决策树预估器 + estimator = DecisionTreeClassifier(criterion="entropy") + estimator.fit(x_train, y_train) + + # 4) 模型评估 + # 1、直接对比真实值和预测值 + y_predict = estimator.predict(x_test) + # print("y_predict:\n", y_predict) + # print("比较:\n", y_test == y_predict) + + # 2、计算准确率 + score = estimator.score(x_test, y_test) + print("决策树模型准确度为:\n", score) + + # 5) 做出预测 + X_new = np.array(x) + prediction = estimator.predict(X_new) + print("决策树模型预测目标类别是:{}".format(prediction)) + print("决策树模型预测目标花名是:", iris["target_names"][prediction]) + print() + + return None + +def knn_iris(x): + """ + KNN近邻对鸢尾花种类进行预测 + :param x:测试数据 + :return: + """ + # 1) 获取数据 + iris = load_iris() + + # 2)划分数据集 + x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22) + + # 3)特征工程:标准化 + transfer = StandardScaler() + x_train = transfer.fit_transform(x_train) + x_test = transfer.transform(x_test) + + # 4) KNN算法预估器 + estimator = KNeighborsClassifier(n_neighbors=3) + estimator.fit(x_train, y_train) + + # 5)模型估计 + # 方法1:直接比对真实值和预测值 + y_predict = estimator.predict(x_test) + # print("y_predict:\n", y_predict) + # print("直接比对真实值和预测值:\n", y_test ==y_predict) + + #方法2:计算准确值 + score = estimator.score(x_test, y_test) + print("KNN模型准确值为:\n", score) + + # 6) 做出预测 + X_new = np.array(x) + prediction = estimator.predict(X_new) + print("KNN模型预测目标类别是:{}".format(prediction)) + print("KNN模型预测目标花名是:", iris["target_names"][prediction]) + print() + + return None + +def input_feature(): + """ + 顺序输入花萼长度、宽度,花瓣长度、宽度 + :return:测试值二维数组 + """ + feature = [0] * 4 + name = ['Sepal Length:','Sepal Width:','Petal Length:','Petal Width:'] + predict = [] + count = 0 + for i in range(4): + try: + feature[i] = float(input(name[i])) + except ValueError: + print("请输入数值") + return None + else: + count += 1 + predict.append(feature) + return predict + +if __name__ == '__main__': + X_new = input_feature() + if X_new != None: + decisiontree_iris(X_new) + knn_iris(X_new) + bayes_iris(X_new) \ No newline at end of file