# -*- encoding: utf-8 -*- """ @Author: packy945 @FileName: fitting.py @DateTime: 2023/5/31 14:16 @SoftWare: PyCharm """ import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit import pylab as mpl from tkinter import * import mpl_toolkits.axisartist as axisartist import data as gl_data # import function mpl.rcParams['font.sans-serif'] = ['SimHei'] # 解决中文不显示问题 plt.rcParams['axes.unicode_minus'] = False #解决负数坐标显示问题 def draw_axis(low, high, step=250): fig = plt.figure(figsize=(4.4, 3.2)) ax = axisartist.Subplot(fig, 111)# 使用axisartist.Subplot方法创建一个绘图区对象ax fig.add_axes(ax)# 将绘图区对象添加到画布中 ax.axis[:].set_visible(False)# 通过set_visible方法设置绘图区所有坐标轴隐藏 ax.axis["x"] = ax.new_floating_axis(0, 0)# 添加新的x坐标轴 ax.axis["x"].set_axisline_style("-|>", size=1.0)# 给x坐标轴加上箭头 ax.axis["y"] = ax.new_floating_axis(1, 0)# 添加新的y坐标轴 ax.axis["y"].set_axisline_style("-|>", size=1.0)# y坐标轴加上箭头 ax.axis["x"].set_axis_direction("bottom")# 设置x、y轴上刻度显示方向 ax.axis["y"].set_axis_direction("left")# 设置x、y轴上刻度显示方向 plt.xlim(low, high) # 把x轴的刻度范围设置 plt.ylim(low, high) # 把y轴的刻度范围设置 ax.set_xticks(np.arange(low, high + 5, step))# 把x轴的刻度间隔设置 ax.set_yticks(np.arange(low, high + 5, step))# 把y轴的刻度间隔设置 def generate_and_plot_sample_data(low, high): num_samples = 25#设置样本点数量 gl_data.X = np.random.randint(low=low, high=high, size=num_samples)#随机生成x值 gl_data.Y = np.random.randint(low=low, high=high, size=num_samples)#随机生成y值 draw_axis(low, high)#绘制坐标轴 plt.scatter(gl_data.X, gl_data.Y, color='red') # 绘制样本数据点 plt.savefig(r"dot.png", facecolor='w')#保存到本地 plt.close()#清除内存 set_phtot(1)#显示到主界面 # 显示数据集 def selfdata_show(x, y, low, high):# 显示数据集 num_points = len(x) # 样本数据点的数量 colors = np.random.rand(num_points, 3)# 生成随机颜色 draw_axis(low, high) plt.scatter(x, y, c=colors)# 绘制样本数据点 plt.savefig(r"dot.png", facecolor='w') # 保存为png文件 plt.clf() set_phtot(1) def fitting(letters,result): code_str = ''' def func({}): return {} '''.format(letters,result) return code_str def create_func(code_str): # 创建一个空的命名空间 namespace = {} # 使用exec函数执行字符串代码,并指定命名空间为locals exec(code_str, globals(), namespace) # 返回命名空间中的函数对象 return namespace['func'] def draw_line(xian_index, low, high, sx, sy): draw_axis(low, high)#绘制坐标轴 popt = []#初始化curve_fit pcov = []#初始化curve_fit letters, result = gl_data.FITT_SAVE['variable'][xian_index], gl_data.FITT_SAVE['function'][xian_index] code_str = fitting(letters, result) func = create_func(code_str) # 获取当前函数func # cur = function.Fun[xian_index] # 装载正在选择的函数 # func = cur.get_fun()#获取当前函数func popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合 curve_x = np.arange(low, high)#按照步长生成的一串数字 curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值 plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线 plt.legend() # plt.show()#显示函数图像 plt.savefig(r"line.png", facecolor='w')#将图片保存到本地 plt.close()#清除内存 set_phtot(2)#将图片显示到程序中 def draw_dots_and_line(xian_index, low, high, sx, sy): if gl_data.Quadrant: x = sx.tolist() y = sy.tolist() negative_x = [i for i, num in enumerate(x) if num < 0] negative_y = [i for i, num in enumerate(y) if num < 0] negative = list(set(negative_x + negative_y)) for index in range(len(negative)): # print(negative[-index-1]) x.pop(negative[-index - 1]) y.pop(negative[-index - 1]) sx = np.array(x) sy = np.array(y) print(sx) ############## begin ##################### draw_axis(low, high) popt = [] pcov = [] letters, result = gl_data.FITT_SAVE['variable'][xian_index], gl_data.FITT_SAVE['function'][xian_index] code_str = fitting(letters, result) func = create_func(code_str) # 获取当前函数func print(code_str) # cur = function.Fun[xian_index] # 装载正在选择的函数 # func = cur.get_fun()#获取当前函数func popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合 positive_mask = sy >= 0.0#给样本点分类 negative_mask = sy < 0.0#给样本点分类 positive_colors = ['red' if x >= 0.0 else 'red' for x in sx]#给样本点分类 negative_colors = ['red' if x >= 0.0 else 'red' for x in sx]#给样本点分类 #根据样本点类型决定样本点颜色 plt.scatter(sx[positive_mask], sy[positive_mask], color=np.array(positive_colors)[positive_mask], lw=1) plt.scatter(sx[negative_mask], sy[negative_mask], color=np.array(negative_colors)[negative_mask], lw=1) curve_x = np.arange(low, high)#按照步长生成的一串数字 curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值 ############## end ####################### plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线 plt.legend() plt.savefig(r"line.png", facecolor='w')#将图片保存到本地 plt.close()#清除内存 set_phtot(2)#将图片显示到程序中 def set_phtot(index1): # 显示图像 Canvas2 = gl_data.Canvas2 if index1 == 1: gl_data.Img1=PhotoImage(file=r"dot.png") # print(gl_data.Img1) gl_data.Img1 = gl_data.Img1.zoom(2, 2) set_img1=Canvas2.create_image(2, 10, image=gl_data.Img1, anchor=NW) # Canvas2.itemconfigure(set_img1,int(420 * gl_data.MAGNIDICATION), int(300 * gl_data.MAGNIDICATION)) Canvas2.update() print("已输出数据点") elif index1 == 2: gl_data.Img2=PhotoImage(file=r"line.png") # print(gl_data.Img2) gl_data.Img2 = gl_data.Img2.zoom(2, 2) set_img2=Canvas2.create_image(2, 10, image=gl_data.Img2, anchor=NW) # Canvas2.itemconfigure(set_img2, 0,0,int(420*gl_data.MAGNIDICATION), int(300*gl_data.MAGNIDICATION)) Canvas2.update() print("已输出数据点和曲线") if __name__ == '__main__': # random_points() pass