You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fitt/fitting.py

157 lines
6.7 KiB

# -*- encoding: utf-8 -*-
"""
@Author: packy945
@FileName: fitting.py
@DateTime: 2023/5/31 14:16
@SoftWare: PyCharm
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import pylab as mpl
from tkinter import *
import mpl_toolkits.axisartist as axisartist
import data as gl_data
11 months ago
# import function
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 解决中文不显示问题
plt.rcParams['axes.unicode_minus'] = False #解决负数坐标显示问题
def draw_axis(low, high, step=250):
fig = plt.figure(figsize=(4.4, 3.2))
ax = axisartist.Subplot(fig, 111)# 使用axisartist.Subplot方法创建一个绘图区对象ax
fig.add_axes(ax)# 将绘图区对象添加到画布中
ax.axis[:].set_visible(False)# 通过set_visible方法设置绘图区所有坐标轴隐藏
ax.axis["x"] = ax.new_floating_axis(0, 0)# 添加新的x坐标轴
ax.axis["x"].set_axisline_style("-|>", size=1.0)# 给x坐标轴加上箭头
ax.axis["y"] = ax.new_floating_axis(1, 0)# 添加新的y坐标轴
ax.axis["y"].set_axisline_style("-|>", size=1.0)# y坐标轴加上箭头
ax.axis["x"].set_axis_direction("bottom")# 设置x、y轴上刻度显示方向
ax.axis["y"].set_axis_direction("left")# 设置x、y轴上刻度显示方向
plt.xlim(low, high) # 把x轴的刻度范围设置
plt.ylim(low, high) # 把y轴的刻度范围设置
ax.set_xticks(np.arange(low, high + 5, step))# 把x轴的刻度间隔设置
ax.set_yticks(np.arange(low, high + 5, step))# 把y轴的刻度间隔设置
def generate_and_plot_sample_data(low, high):
num_samples = 25#设置样本点数量
gl_data.X = np.random.randint(low=low, high=high, size=num_samples)#随机生成x值
gl_data.Y = np.random.randint(low=low, high=high, size=num_samples)#随机生成y值
draw_axis(low, high)#绘制坐标轴
plt.scatter(gl_data.X, gl_data.Y, color='red') # 绘制样本数据点
plt.savefig(r"dot.png", facecolor='w')#保存到本地
plt.close()#清除内存
set_phtot(1)#显示到主界面
# 显示数据集
def selfdata_show(x, y, low, high):# 显示数据集
num_points = len(x) # 样本数据点的数量
colors = np.random.rand(num_points, 3)# 生成随机颜色
draw_axis(low, high)
plt.scatter(x, y, c=colors)# 绘制样本数据点
plt.savefig(r"dot.png", facecolor='w') # 保存为png文件
plt.clf()
set_phtot(1)
11 months ago
def fitting(letters,result):
code_str = '''
def func({}):
return {}
'''.format(letters,result)
return code_str
def create_func(code_str):
# 创建一个空的命名空间
namespace = {}
# 使用exec函数执行字符串代码并指定命名空间为locals
exec(code_str, globals(), namespace)
# 返回命名空间中的函数对象
return namespace['func']
def draw_line(xian_index, low, high, sx, sy):
draw_axis(low, high)#绘制坐标轴
popt = []#初始化curve_fit
pcov = []#初始化curve_fit
11 months ago
letters, result = gl_data.FITT_SAVE['variable'][xian_index], gl_data.FITT_SAVE['function'][xian_index]
code_str = fitting(letters, result)
func = create_func(code_str) # 获取当前函数func
# cur = function.Fun[xian_index] # 装载正在选择的函数
# func = cur.get_fun()#获取当前函数func
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
curve_x = np.arange(low, high)#按照步长生成的一串数字
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
plt.legend()
# plt.show()#显示函数图像
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
plt.close()#清除内存
set_phtot(2)#将图片显示到程序中
def draw_dots_and_line(xian_index, low, high, sx, sy):
11 months ago
if gl_data.Quadrant:
x = sx.tolist()
y = sy.tolist()
negative_x = [i for i, num in enumerate(x) if num < 0]
negative_y = [i for i, num in enumerate(y) if num < 0]
negative = list(set(negative_x + negative_y))
for index in range(len(negative)):
# print(negative[-index-1])
x.pop(negative[-index - 1])
y.pop(negative[-index - 1])
sx = np.array(x)
sy = np.array(y)
print(sx)
############## begin #####################
draw_axis(low, high)
popt = []
pcov = []
11 months ago
letters, result = gl_data.FITT_SAVE['variable'][xian_index], gl_data.FITT_SAVE['function'][xian_index]
code_str = fitting(letters, result)
func = create_func(code_str) # 获取当前函数func
print(code_str)
# cur = function.Fun[xian_index] # 装载正在选择的函数
# func = cur.get_fun()#获取当前函数func
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
positive_mask = sy >= 0.0#给样本点分类
negative_mask = sy < 0.0#给样本点分类
11 months ago
positive_colors = ['red' if x >= 0.0 else 'red' for x in sx]#给样本点分类
negative_colors = ['red' if x >= 0.0 else 'red' for x in sx]#给样本点分类
#根据样本点类型决定样本点颜色
11 months ago
plt.scatter(sx[positive_mask], sy[positive_mask], color=np.array(positive_colors)[positive_mask], lw=1)
plt.scatter(sx[negative_mask], sy[negative_mask], color=np.array(negative_colors)[negative_mask], lw=1)
curve_x = np.arange(low, high)#按照步长生成的一串数字
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
############## end #######################
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
plt.legend()
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
plt.close()#清除内存
set_phtot(2)#将图片显示到程序中
def set_phtot(index1):
# 显示图像
Canvas2 = gl_data.Canvas2
if index1 == 1:
gl_data.Img1=PhotoImage(file=r"dot.png")
# print(gl_data.Img1)
11 months ago
gl_data.Img1 = gl_data.Img1.zoom(2, 2)
set_img1=Canvas2.create_image(2, 10, image=gl_data.Img1, anchor=NW)
11 months ago
# Canvas2.itemconfigure(set_img1,int(420 * gl_data.MAGNIDICATION), int(300 * gl_data.MAGNIDICATION))
Canvas2.update()
print("已输出数据点")
elif index1 == 2:
gl_data.Img2=PhotoImage(file=r"line.png")
# print(gl_data.Img2)
11 months ago
gl_data.Img2 = gl_data.Img2.zoom(2, 2)
set_img2=Canvas2.create_image(2, 10, image=gl_data.Img2, anchor=NW)
11 months ago
# Canvas2.itemconfigure(set_img2, 0,0,int(420*gl_data.MAGNIDICATION), int(300*gl_data.MAGNIDICATION))
Canvas2.update()
print("已输出数据点和曲线")
if __name__ == '__main__':
# random_points()
pass