|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
# Time : 2023/8/9 10:20
|
|
|
|
|
# Author : lirunsheng
|
|
|
|
|
# User : l'r's
|
|
|
|
|
# Software: PyCharm
|
|
|
|
|
# File : X3.py
|
|
|
|
|
|
|
|
|
|
import mpl_toolkits.axisartist as axisartist
|
|
|
|
|
from scipy.optimize import curve_fit
|
|
|
|
|
from pandas import DataFrame
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pylab as mpl
|
|
|
|
|
from tkinter import *
|
|
|
|
|
import data as gl_data
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import tkinter
|
|
|
|
|
import sys
|
|
|
|
|
import tkinter as tk
|
|
|
|
|
from tkinter import filedialog
|
|
|
|
|
|
|
|
|
|
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 解决matplotlib中文不显示问题
|
|
|
|
|
plt.rcParams['axes.unicode_minus'] = False # 解决matplotlib负数坐标显示问题
|
|
|
|
|
|
|
|
|
|
# 创建一个二维数组sampleData,用于存储样本数据
|
|
|
|
|
sampleData = []
|
|
|
|
|
|
|
|
|
|
def askfile():# 从本地选择一个文件,并返回文件的路径
|
|
|
|
|
global sampleData
|
|
|
|
|
# 弹出文件选择对话框,选择要打开的文本文件
|
|
|
|
|
file_path = filedialog.askopenfilename(filetypes=[('Text Files', '*.txt')])
|
|
|
|
|
|
|
|
|
|
# 如果没有选择文件,直接返回
|
|
|
|
|
if not file_path:
|
|
|
|
|
return
|
|
|
|
|
# 步骤1:打开样本数据集文件并读取每行数据
|
|
|
|
|
with open(file_path, 'r') as file:
|
|
|
|
|
lines = file.readlines()
|
|
|
|
|
|
|
|
|
|
# 步骤2:遍历所有行,将每行数据解析为sx和sy值,转换为浮点数并存储到sampleData中
|
|
|
|
|
for line in lines:
|
|
|
|
|
sx, sy = line.strip().split(' ')
|
|
|
|
|
sx = float(sx)
|
|
|
|
|
sy = float(sy)
|
|
|
|
|
sampleData.append([sx, sy])
|
|
|
|
|
|
|
|
|
|
# # 步骤3:使用循环遍历样本数据并打印出来
|
|
|
|
|
# for data in sampleData:
|
|
|
|
|
# print(data)
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# askfile()
|
|
|
|
|
def draw_axis(low, high, step=250):
|
|
|
|
|
fig = plt.figure(figsize=(4.4, 3.2)) # 设置显示大小
|
|
|
|
|
ax = axisartist.Subplot(fig, 111) # 使用axisartist.Subplot方法创建一个绘图区对象ax
|
|
|
|
|
fig.add_axes(ax) # 将绘图区对象添加到画布中
|
|
|
|
|
ax.axis[:].set_visible(False)# 通过set_visible方法设置绘图区所有坐标轴隐藏
|
|
|
|
|
ax.axis["x"] = ax.new_floating_axis(0, 0) # 添加新的x坐标轴
|
|
|
|
|
ax.axis["x"].set_axisline_style("-|>", size=1.0) # 给x坐标轴加上箭头
|
|
|
|
|
ax.axis["y"] = ax.new_floating_axis(1, 0) # 添加新的y坐标轴
|
|
|
|
|
ax.axis["y"].set_axisline_style("-|>", size=1.0) # y坐标轴加上箭头
|
|
|
|
|
ax.axis["x"].set_axis_direction("bottom") # 设置x、y轴上刻度显示方向
|
|
|
|
|
ax.axis["y"].set_axis_direction("left") # 设置x、y轴上刻度显示方向
|
|
|
|
|
plt.xlim(low, high) # 把x轴的刻度范围设置
|
|
|
|
|
plt.ylim(low, high) # 把y轴的刻度范围设置
|
|
|
|
|
ax.set_xticks(np.arange(low, high + 5, step)) # 把x轴的刻度间隔设置
|
|
|
|
|
ax.set_yticks(np.arange(low, high + 5, step)) # 把y轴的刻度间隔设置
|
|
|
|
|
|
|
|
|
|
def generate_and_plot_sample_data(low, high):
|
|
|
|
|
num_samples = 25 #设置样本点数量
|
|
|
|
|
gl_data.X = np.random.randint(low=low, high=high, size=num_samples)#随机生成x值
|
|
|
|
|
print(gl_data.X)
|
|
|
|
|
gl_data.Y = np.random.randint(low=low, high=high, size=num_samples)#随机生成y值
|
|
|
|
|
# 提取样本数据的x坐标和y坐标
|
|
|
|
|
# x_values = [point[0] for point in sampleData]
|
|
|
|
|
# print(x_values)
|
|
|
|
|
# y_values = [point[1] for point in sampleData]
|
|
|
|
|
draw_axis(low, high) #绘制坐标轴
|
|
|
|
|
plt.scatter(gl_data.X,gl_data.Y, color='red') # 绘制样本数据点
|
|
|
|
|
plt.savefig(r"dot2.png", facecolor='w') # 保存到本地
|
|
|
|
|
plt.close() # 清除内存
|
|
|
|
|
# set_phtot(1) #显示到主界面
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# askfile()
|
|
|
|
|
# # generate_and_plot_sample_data(gl_data.LOW, gl_data.HIGH)
|
|
|
|
|
def quadratic_function(x, a, b, c):#构造二次函数y = a * X^2 + b * X + C
|
|
|
|
|
return a * x ** 2 + b * x + c
|
|
|
|
|
def draw_dots_and_line(low, high, sx, sy):
|
|
|
|
|
draw_axis(low, high)
|
|
|
|
|
popt, pcov = curve_fit(quadratic_function, sx, sy)# 用curve_fit来对点进行拟合
|
|
|
|
|
positive_mask = sy >= 0.0#给样本点分类
|
|
|
|
|
negative_mask = sy < 0.0#给样本点分类
|
|
|
|
|
positive_colors = ['red' if x >= 0.0 else 'blue' for x in sx]#给样本点分类
|
|
|
|
|
negative_colors = ['green' if x >= 0.0 else 'purple' for x in sx]#给样本点分类
|
|
|
|
|
#根据样本点类型决定样本点颜色
|
|
|
|
|
plt.scatter(gl_data.X[positive_mask], gl_data.Y[positive_mask], color=np.array(positive_colors)[positive_mask], lw=1)
|
|
|
|
|
plt.scatter(gl_data.X[negative_mask], gl_data.Y[negative_mask], color=np.array(negative_colors)[negative_mask], lw=1)
|
|
|
|
|
curve_x = np.arange(low, high)#按照步长生成的一串数字
|
|
|
|
|
curve_y = [quadratic_function (i, * popt) for i in curve_x]# 根据x来计算y1值
|
|
|
|
|
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
|
|
|
|
|
plt.savefig(r"dot5.png", facecolor='w') # 保存到本地
|
|
|
|
|
plt.legend()
|
|
|
|
|
plt.show()
|
|
|
|
|
return popt
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# askfile()
|
|
|
|
|
# gl_data.x = np.array([point[0] for point in sampleData]) # 为二次函数.txt
|
|
|
|
|
# gl_data.y = np.array([point[1] for point in sampleData])
|
|
|
|
|
# popt = draw_dots_and_line(gl_data.low, gl_data.high, gl_data.x, gl_data.y)
|
|
|
|
|
# a,b,c = popt
|
|
|
|
|
# print(a,b,c)
|
|
|
|
|
|
|
|
|
|
def input_num(root_tk):
|
|
|
|
|
global top
|
|
|
|
|
top = tk.Toplevel(root_tk)
|
|
|
|
|
top.geometry("300x50")
|
|
|
|
|
top.title('坐标点个数')
|
|
|
|
|
label1 = Label(top, text="坐标点个数")
|
|
|
|
|
label1.grid(row=0) # 这里的side可以赋值为LEFT RTGHT TOP BOTTOM
|
|
|
|
|
num1 = IntVar()
|
|
|
|
|
entry1 = Entry(top, textvariable=num1)
|
|
|
|
|
num1.set(0)
|
|
|
|
|
entry1.grid(row=0, column=1)
|
|
|
|
|
Label(top, text=" ").grid(row=0, column=3)
|
|
|
|
|
Button(top, text="确定", command=lambda: input_data(root_tk, int(entry1.get()))).grid(row=0, column=3)
|
|
|
|
|
top.mainloop()
|
|
|
|
|
|
|
|
|
|
def add_sample_data():
|
|
|
|
|
global sample_x, sample_y, sample_data
|
|
|
|
|
global entry_x,entry_y,label_status, numx
|
|
|
|
|
try:
|
|
|
|
|
x = float(entry_x.get())
|
|
|
|
|
y = float(entry_y.get())
|
|
|
|
|
except:
|
|
|
|
|
label_status.config(text="输入不合法")
|
|
|
|
|
return
|
|
|
|
|
entry_x.delete(0, tk.END)
|
|
|
|
|
entry_y.delete(0, tk.END)
|
|
|
|
|
if min(x, y) < gl_data.LOW or max(x, y) > gl_data.HIGH:
|
|
|
|
|
label_status.config(text="输入超过范围")
|
|
|
|
|
return
|
|
|
|
|
elif len(sample_data) < numx:
|
|
|
|
|
label_status.config(text="点对已添加")
|
|
|
|
|
sample_data.append((x, y))
|
|
|
|
|
sample_x.append(x)
|
|
|
|
|
sample_y.append(y)
|
|
|
|
|
else:
|
|
|
|
|
label_status.config(text="已达到最大数量")
|
|
|
|
|
def check_sample_data():
|
|
|
|
|
global label_status,numx,sample_x,sample_y,sample_data
|
|
|
|
|
if len(sample_data) == numx:
|
|
|
|
|
label_status.config(text="已达到最大数量")
|
|
|
|
|
gl_data.X = np.array(sample_x)
|
|
|
|
|
gl_data.Y = np.array(sample_y)
|
|
|
|
|
print('已添加', sample_data)
|
|
|
|
|
sys.exit()
|
|
|
|
|
else:
|
|
|
|
|
label_status.config(text="还需输入{}个点对".format(numx - len(sample_data)))
|
|
|
|
|
print(sample_data)
|
|
|
|
|
|
|
|
|
|
def input_data(root_tk, num):
|
|
|
|
|
global top
|
|
|
|
|
global sample_x,sample_y,sample_data
|
|
|
|
|
global numx
|
|
|
|
|
numx = num
|
|
|
|
|
sample_x = []
|
|
|
|
|
sample_y = []
|
|
|
|
|
sample_data = []
|
|
|
|
|
top.destroy()
|
|
|
|
|
top = tk.Toplevel(root_tk)
|
|
|
|
|
top.geometry("300x200")
|
|
|
|
|
top.title('坐标')
|
|
|
|
|
global entry_x, entry_y, label_status
|
|
|
|
|
label_x = tk.Label(top, text="X 值:")
|
|
|
|
|
label_x.pack()
|
|
|
|
|
entry_x = tk.Entry(top)
|
|
|
|
|
entry_x.pack()
|
|
|
|
|
label_y = tk.Label(top, text="Y 值:")
|
|
|
|
|
label_y.pack()
|
|
|
|
|
entry_y = tk.Entry(top)
|
|
|
|
|
entry_y.pack()
|
|
|
|
|
button_add = tk.Button(top, text="添加", command=add_sample_data)
|
|
|
|
|
button_add.pack()
|
|
|
|
|
button_check = tk.Button(top, text="检查", command=check_sample_data)
|
|
|
|
|
button_check.pack()
|
|
|
|
|
label_status = tk.Label(top, text="")
|
|
|
|
|
label_status.pack()
|
|
|
|
|
top.mainloop()
|
|
|
|
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# root = tk.Tk()
|
|
|
|
|
# root.withdraw()
|
|
|
|
|
# input_num(root)
|
|
|
|
|
|
|
|
|
|
# 定义误差函数
|
|
|
|
|
def error_func(x_data, y_data, a, b, c):
|
|
|
|
|
y_pred = exponential_function(x_data, a, b, c)
|
|
|
|
|
error = np.sum((y_data - y_pred) ** 2)
|
|
|
|
|
residual = y_data - y_pred
|
|
|
|
|
ss_res = np.sum(residual ** 2)
|
|
|
|
|
ss_tot = np.sum((y_data - np.mean(y_data)) ** 2)
|
|
|
|
|
r_squared = 1 - (ss_res / ss_tot)
|
|
|
|
|
return error,r_squared
|
|
|
|
|
|
|
|
|
|
# 给出一个二次函数拟合的例子,读者可自己写拟合函数
|
|
|
|
|
def exponential_function(x, a, b, c):
|
|
|
|
|
return a * x ** 2 + b * x + c
|
|
|
|
|
|
|
|
|
|
def fit(sample_data): # 随机生成样本数据
|
|
|
|
|
x_data = np.array([data[0] for data in sample_data])
|
|
|
|
|
y_data = np.array([data[1] for data in sample_data]) # 拟合样本数据
|
|
|
|
|
popt, pcov = curve_fit(exponential_function, x_data, y_data) # 计算拟合结果
|
|
|
|
|
# 计算拟合误差
|
|
|
|
|
error,r_squared = error_func(x_data, y_data, *popt)
|
|
|
|
|
fit_function_name = exponential_function.__name__# 打印拟合函数的形式、系数、误差和优度
|
|
|
|
|
print("拟合函数形式:{}".format(fit_function_name))
|
|
|
|
|
print("拟合系数:{}".format(popt))
|
|
|
|
|
print("误差:{:.4f}".format(error))
|
|
|
|
|
print("拟合优度(R^2):{:.4f}".format(r_squared))
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# askfile()
|
|
|
|
|
# fit(sampleData) # 对当前x、y值进行拟合
|
|
|
|
|
# # 给出一个指数函数拟合的例子,读者可自己写拟合函数
|
|
|
|
|
# def fit_function(x,a,b,c):
|
|
|
|
|
# return a * np.exp(b * x) + c
|
|
|
|
|
#
|
|
|
|
|
# def fit(sample_data): # 随机生成样本数据
|
|
|
|
|
# x_data = np.array([data[0] for data in sample_data])
|
|
|
|
|
# y_data = np.array([data[1] for data in sample_data]) # 拟合样本数据
|
|
|
|
|
# popt, pcov = curve_fit(fit_function, x_data, y_data) # 计算拟合结果
|
|
|
|
|
# y_fit = fit_function(x_data, *popt)
|
|
|
|
|
# residuals = y_data - y_fit
|
|
|
|
|
# ss_res = np.sum(residuals**2)
|
|
|
|
|
# ss_tot = np.sum((y_data - np.mean(y_data))**2)
|
|
|
|
|
# r_squared = 1 - (ss_res / ss_tot)
|
|
|
|
|
# fit_function_name = fit_function.__name__# 打印拟合函数的形式、系数、误差和优度
|
|
|
|
|
# print("拟合函数形式:{}".format(fit_function_name))
|
|
|
|
|
# print("拟合系数:{}".format(popt))
|
|
|
|
|
# print("误差:{:.4f}".format(np.sqrt(np.mean(residuals**2))))
|
|
|
|
|
# print("拟合优度(R^2):{:.4f}".format(r_squared))
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# askfile()
|
|
|
|
|
# fit(sampleData) # 对当前x、y值进行拟合
|
|
|
|
|
|
|
|
|
|
def change_Q(no):
|
|
|
|
|
gl_data.Quadrant = no #更改全局变量的象限显示
|
|
|
|
|
if no: #若为一象限,则修改显示下限为0
|
|
|
|
|
gl_data.LOW = 0
|
|
|
|
|
else: #若为四象限,则修改显示下限为-gl_data.MAXV
|
|
|
|
|
gl_data.LOW = -gl_data.MAXV
|
|
|
|
|
q_button() #更新象限显示面板
|
|
|
|
|
|
|
|
|
|
def q_button():
|
|
|
|
|
r = 7.5
|
|
|
|
|
rr = 2.5
|
|
|
|
|
for widget in q_root.winfo_children():
|
|
|
|
|
widget.destroy()
|
|
|
|
|
q_cv = tk.Canvas(q_root, width=450, height=500)
|
|
|
|
|
q_cv.place(x=0, y=0)
|
|
|
|
|
l = tk.Label(q_root, text='坐标轴', bd=0, font=("微软雅黑", 16) , anchor=W)
|
|
|
|
|
l.place(x=20, y=0, width=80, height=50,)
|
|
|
|
|
# 四象b限按钮
|
|
|
|
|
b1 = tk.Button(q_root, text='四象限', bd=0, font=("微软雅黑", 16)
|
|
|
|
|
, command=lambda: change_Q(0), anchor=W)
|
|
|
|
|
b1.place(x=170, y=0, width=80, height=50,)
|
|
|
|
|
# 一象限按钮
|
|
|
|
|
b2 = tk.Button(q_root, text='一象限', bd=0, font=("微软雅黑", 16)
|
|
|
|
|
, command=lambda: change_Q(1), anchor=W)
|
|
|
|
|
b2.place(x=320, y=0, width=80, height=50,)
|
|
|
|
|
# 绘制标记框
|
|
|
|
|
q_cv.create_oval(140 - r, 25 - r, 140 + r, 25 + r, fill="white", width=1, outline="black")
|
|
|
|
|
q_cv.create_oval(290 - r, 25 - r, 290 + r, 25 + r , fill="white", width=1, outline="black")
|
|
|
|
|
# 根据当前的象限选择值来填充标记框
|
|
|
|
|
if gl_data.Quadrant == 0:
|
|
|
|
|
q_cv.create_oval(140 - rr, 25 - rr, 140 + rr, 25 + rr, fill="black", width=1, outline="black")# 绘制象限
|
|
|
|
|
# Q_cv.create_rectangle(80, 80, 270, 270, fill="white") # 第一象限
|
|
|
|
|
# Q_cv.create_rectangle(270, 80, 500, 270, fill="red") # 第二象限
|
|
|
|
|
# Q_cv.create_rectangle(80, 270, 270, 480, fill="blue") # 第三象限
|
|
|
|
|
# Q_cv.create_rectangle(270, 270, 500, 480, fill="green") # 第四象限
|
|
|
|
|
else:
|
|
|
|
|
q_cv.create_oval(290 - rr, 25 - rr, 290 + rr, 25 + rr, fill="black", width=1, outline="black")
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# # 象限选择相关界面
|
|
|
|
|
# q_root = tk.Tk()
|
|
|
|
|
# q_root.geometry("500x550") # 设置窗口的大小和位置
|
|
|
|
|
# q_button()
|
|
|
|
|
# q_root.mainloop()
|
|
|
|
|
|
|
|
|
|
# 定义模型为y=ax+b
|
|
|
|
|
def model(x, a, b):
|
|
|
|
|
return a*x + b
|
|
|
|
|
|
|
|
|
|
# 定义误差函数error_function()
|
|
|
|
|
def error_function(x,y,a, b):
|
|
|
|
|
y_pred = model(x, a, b)
|
|
|
|
|
error = np.mean((y - y_pred)**2)
|
|
|
|
|
return error
|
|
|
|
|
|
|
|
|
|
# 定义梯度函数gradient()
|
|
|
|
|
def gradient(a, b,x,y):
|
|
|
|
|
y_pred = model(x, a, b)
|
|
|
|
|
gradient_a = 2*np.mean((y_pred - y)*x)
|
|
|
|
|
gradient_b = 2*np.mean(y_pred - y)
|
|
|
|
|
return gradient_a, gradient_b
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
# 读取sampledata值
|
|
|
|
|
sampledata = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
|
|
|
|
x = np.array(sampledata)
|
|
|
|
|
y = np.array([-2, 1, 0, 1, 2, 3, 4, 5, 6, 8])
|
|
|
|
|
# 初始化参数
|
|
|
|
|
a = 0
|
|
|
|
|
b = 0
|
|
|
|
|
learning_rate = 0.01
|
|
|
|
|
num_iterations = 1000
|
|
|
|
|
# 梯度下降算法
|
|
|
|
|
for i in range(num_iterations):
|
|
|
|
|
grad_a, grad_b = gradient(a, b,x,y)
|
|
|
|
|
a -= learning_rate * grad_a
|
|
|
|
|
b -= learning_rate * grad_b
|
|
|
|
|
# 输出最终拟合结果
|
|
|
|
|
final_error = error_function(x,y,a, b)
|
|
|
|
|
# gl_data.X = x # 为二次函数.txt
|
|
|
|
|
# gl_data.Y = y
|
|
|
|
|
# draw_axis(-100, 100)
|
|
|
|
|
# positive_mask = x >= 0.0 # 给样本点分类
|
|
|
|
|
# negative_mask = y < 0.0 # 给样本点分类
|
|
|
|
|
# positive_colors = ['red' if x >= 0.0 else 'blue' for x in x] # 给样本点分类
|
|
|
|
|
# negative_colors = ['green' if x >= 0.0 else 'purple' for x in x] # 给样本点分类
|
|
|
|
|
# # 根据样本点类型决定样本点颜色
|
|
|
|
|
# plt.scatter(gl_data.X[positive_mask], gl_data.Y[positive_mask], color=np.array(positive_colors)[positive_mask],
|
|
|
|
|
# lw=1)
|
|
|
|
|
# plt.scatter(gl_data.X[negative_mask], gl_data.Y[negative_mask], color=np.array(negative_colors)[negative_mask],
|
|
|
|
|
# lw=1)
|
|
|
|
|
# curve_x = np.arange(-1000, 1000) # 按照步长生成的一串数字
|
|
|
|
|
# curve_y = [model(i, a, b) for i in curve_x] # 根据x来计算y1值
|
|
|
|
|
# plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve') # 绘制拟合曲线
|
|
|
|
|
# plt.legend()
|
|
|
|
|
# plt.show()
|
|
|
|
|
print("拟合函数形式: y = ax + b")
|
|
|
|
|
print("系数 a:{:.2f}".format(a))
|
|
|
|
|
print("系数 b:{:.2f}".format( b))
|
|
|
|
|
print("误差计算方法: 均方误差")
|
|
|
|
|
print("最终误差:{:.2f}".format(final_error))
|