You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

154 lines
7.0 KiB

import tensorflow as tf
import numpy as np
from sklearn import datasets
from matplotlib import pyplot as plt
from tensorflow.python.ops.resource_variable_ops import ResourceVariable
import tkinter as tk
class evaclassy():
def __init__(self):
self.data = datasets.load_iris().data
self.target = datasets.load_iris().target
self.train_loss_list = []
self.test_acc_list = []
self.k1=0
self.b1=0
self.lr=0.1 #学习率设置
self.epoch=200
self.eval_loss_all=0 # 初始化loss_all的值用于记录每轮四个step生成的4个loss的和
def main(self):
#------------------------------------------------------------------------#
# ----------------------数据处理-----------------------------#
np.random.seed(116) # 使用用一个种子 保持输入特征与标签对应
np.random.shuffle(self.data)
np.random.seed(116)
np.random.shuffle(self.target)
tf.random.set_seed(116)
data_train = self.data[:-30] # 使用切片使前120组数据作为训练集后30组数据作为验证集
data_test = self.data[-30:]
target_train = self.target[:-30]
target_test = self.target[-30:]
data_train = tf.cast(data_train, tf.float32)
data_test = tf.cast(data_test, tf.float32)
train_db = tf.data.Dataset.from_tensor_slices((data_train, target_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((data_test, target_test)).batch(32)
k = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))
#---------------------------------------------------------------------------------------------#
for self.epoch in range(self.epoch):
for step, (data_train, target_train) in enumerate(train_db):
with tf.GradientTape() as tape:
dat = tf.matmul(data_train, k) + b
dat = tf.nn.softmax(dat) # 使输出结果符合概率分布
targ = tf.one_hot(target_train, depth=3) # 将标签转化为独热码格式
loss = tf.reduce_mean(tf.square(targ - dat)) # 使用均方差损失函数mse计算损失函数
self.eval_loss_all += loss.numpy()
grads = tape.gradient(loss, [k, b]) # 计算loss对各个参数的梯度
k.assign_sub(self.lr * grads[0])
b.assign_sub(self.lr * grads[1]) # 更新模型偏置量参数b
print("Epoch: {}, loss: {}".format(self.epoch, self.eval_loss_all / 4))
self.train_loss_list.append(self.eval_loss_all / 4) # 记录loss_all均值放入列表
self.eval_loss_all = 0 # 归零便于记录下一个epoch的loss
total_correct, total_number = 0, 0
with open('k.txt','w') as f:
f.write(str(k))
with open('b.txt', 'w') as f:
f.write(str(b))
self.k1=k
self.b1=b
for data_test, target_test in test_db:
dat = tf.matmul(data_test, k) + b
dat = tf.nn.softmax(dat)
pred = tf.argmax(dat, axis=1) # 返回y中最大值的索引即鸢尾花的分类标签
pred = tf.cast(pred, dtype=target_test.dtype) # 转换数据类型
correct = tf.cast(tf.equal(pred, target_test), dtype=tf.int32) # 根据分类是否正确返回布尔 # 值且转换为int型
correct = tf.reduce_sum(correct)
total_correct += int(correct)
total_number += data_test.shape[0]
acc = total_correct / total_number # 总正确次数/总预测次数,计算准确率
self.test_acc_list.append(acc) # 添加准确率数据到列表记录下来
print("acc: ", acc)
def draw(self):
plt.title('Acc Curve') # 图片标题
plt.xlabel('迭代次数', fontproperties='SimHei', fontsize=15)
plt.ylabel('准确率', fontproperties='SimHei', fontsize=15)
plt.plot(self.test_acc_list, label="$Accuracy$")
plt.legend()
plt.savefig('准确率图像')#图片保存
plt.show()
plt.title('Loss Function Curve')
plt.xlabel('迭代次数', fontproperties='SimHei', fontsize=15)
plt.ylabel('损失率', fontproperties='SimHei', fontsize=15)
plt.plot(self.train_loss_list, label="$Loss$")
plt.legend()
plt.savefig('损失率图像')#图片保存
plt.show()
def predict(self,data):
y = tf.matmul(data, self.k1) + self.b1
y = tf.nn.softmax(y)
pred = tf.argmax(y, axis=1)
pred=int(pred)
if pred==0:
print('是山鸢尾花')
return '山鸢尾花'
if pred==1:
print('是变色鸢尾花')
return '是变色鸢尾花'
if pred==2:
print('是维吉尼亚鸢尾花')
return '是维吉尼亚鸢尾花'
#数据处理
def dataloader(list1):
data_train=[]
data_train.append(list1)
data_train = tf.cast(data_train, tf.float32)
return data_train
#窗口展示
class draw1():
def insert_point(self):
var = self.e.get()
c=str(var)
var = list(var)
var=list(map(float,var))
print(var)
self.v1.set('')
var=dataloader(var)
try:
predic=a.predict(var)
except:
self.t.delete('1.0', 'end')
self.t.insert('insert', '请重新输入')
else:
self.t.delete('1.0', 'end')
self.t.insert('insert', c+predic)
def insert_end(self):
self.t.delete('1.0', 'end')
def main(self):
window = tk.Tk()
window.title('classy')
window.geometry('500x500')
self.v1 = tk.StringVar()
self.e = tk.Entry(window, show=None,width=200, textvariable=self.v1)
self.e.pack()
self.k = tk.Text(window, height=2,state = 'disabled')
self.k.pack()
self.k.insert('insert', '输入四个范围为1-6的数字不加任何连接符')
self.t = tk.Text(window, height=2)
self.t.pack()
b1 = tk.Button(window, text='insert point', width=15,
height=2, command=self.insert_point)
b1.pack()
self.v1.set('')
b2 = tk.Button(window, text='insert end',width=15,height=2,
command=self.insert_end)
b2.pack()
window.mainloop()
if __name__ == '__main__':
a = evaclassy()
a.main()
# a.draw() #损失 正确函数画图
b=draw1()
b.main()