You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fitt/fitting.py

157 lines
6.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- encoding: utf-8 -*-
"""
@Author: packy945
@FileName: fitting.py
@DateTime: 2023/5/31 14:16
@SoftWare: PyCharm
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import pylab as mpl
from tkinter import *
import mpl_toolkits.axisartist as axisartist
import data as gl_data
# import function
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 解决中文不显示问题
plt.rcParams['axes.unicode_minus'] = False #解决负数坐标显示问题
def draw_axis(low, high, step=250):
fig = plt.figure(figsize=(4.4, 3.2))
ax = axisartist.Subplot(fig, 111)# 使用axisartist.Subplot方法创建一个绘图区对象ax
fig.add_axes(ax)# 将绘图区对象添加到画布中
ax.axis[:].set_visible(False)# 通过set_visible方法设置绘图区所有坐标轴隐藏
ax.axis["x"] = ax.new_floating_axis(0, 0)# 添加新的x坐标轴
ax.axis["x"].set_axisline_style("-|>", size=1.0)# 给x坐标轴加上箭头
ax.axis["y"] = ax.new_floating_axis(1, 0)# 添加新的y坐标轴
ax.axis["y"].set_axisline_style("-|>", size=1.0)# y坐标轴加上箭头
ax.axis["x"].set_axis_direction("bottom")# 设置x、y轴上刻度显示方向
ax.axis["y"].set_axis_direction("left")# 设置x、y轴上刻度显示方向
plt.xlim(low, high) # 把x轴的刻度范围设置
plt.ylim(low, high) # 把y轴的刻度范围设置
ax.set_xticks(np.arange(low, high + 5, step))# 把x轴的刻度间隔设置
ax.set_yticks(np.arange(low, high + 5, step))# 把y轴的刻度间隔设置
def generate_and_plot_sample_data(low, high):
num_samples = 25#设置样本点数量
gl_data.X = np.random.randint(low=low, high=high, size=num_samples)#随机生成x值
gl_data.Y = np.random.randint(low=low, high=high, size=num_samples)#随机生成y值
draw_axis(low, high)#绘制坐标轴
plt.scatter(gl_data.X, gl_data.Y, color='red') # 绘制样本数据点
plt.savefig(r"dot.png", facecolor='w')#保存到本地
plt.close()#清除内存
set_phtot(1)#显示到主界面
# 显示数据集
def selfdata_show(x, y, low, high):# 显示数据集
num_points = len(x) # 样本数据点的数量
colors = np.random.rand(num_points, 3)# 生成随机颜色
draw_axis(low, high)
plt.scatter(x, y, c=colors)# 绘制样本数据点
plt.savefig(r"dot.png", facecolor='w') # 保存为png文件
plt.clf()
set_phtot(1)
def fitting(letters,result):
code_str = '''
def func({}):
return {}
'''.format(letters,result)
return code_str
def create_func(code_str):
# 创建一个空的命名空间
namespace = {}
# 使用exec函数执行字符串代码并指定命名空间为locals
exec(code_str, globals(), namespace)
# 返回命名空间中的函数对象
return namespace['func']
def draw_line(xian_index, low, high, sx, sy):
draw_axis(low, high)#绘制坐标轴
popt = []#初始化curve_fit
pcov = []#初始化curve_fit
letters, result = gl_data.FITT_SAVE['variable'][xian_index], gl_data.FITT_SAVE['function'][xian_index]
code_str = fitting(letters, result)
func = create_func(code_str) # 获取当前函数func
# cur = function.Fun[xian_index] # 装载正在选择的函数
# func = cur.get_fun()#获取当前函数func
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
curve_x = np.arange(low, high)#按照步长生成的一串数字
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
plt.legend()
# plt.show()#显示函数图像
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
plt.close()#清除内存
set_phtot(2)#将图片显示到程序中
def draw_dots_and_line(xian_index, low, high, sx, sy):
if gl_data.Quadrant:
x = sx.tolist()
y = sy.tolist()
negative_x = [i for i, num in enumerate(x) if num < 0]
negative_y = [i for i, num in enumerate(y) if num < 0]
negative = list(set(negative_x + negative_y))
for index in range(len(negative)):
# print(negative[-index-1])
x.pop(negative[-index - 1])
y.pop(negative[-index - 1])
sx = np.array(x)
sy = np.array(y)
print(sx)
############## begin #####################
draw_axis(low, high)
popt = []
pcov = []
letters, result = gl_data.FITT_SAVE['variable'][xian_index], gl_data.FITT_SAVE['function'][xian_index]
code_str = fitting(letters, result)
func = create_func(code_str) # 获取当前函数func
print(code_str)
# cur = function.Fun[xian_index] # 装载正在选择的函数
# func = cur.get_fun()#获取当前函数func
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
positive_mask = sy >= 0.0#给样本点分类
negative_mask = sy < 0.0#给样本点分类
positive_colors = ['red' if x >= 0.0 else 'red' for x in sx]#给样本点分类
negative_colors = ['red' if x >= 0.0 else 'red' for x in sx]#给样本点分类
#根据样本点类型决定样本点颜色
plt.scatter(sx[positive_mask], sy[positive_mask], color=np.array(positive_colors)[positive_mask], lw=1)
plt.scatter(sx[negative_mask], sy[negative_mask], color=np.array(negative_colors)[negative_mask], lw=1)
curve_x = np.arange(low, high)#按照步长生成的一串数字
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
############## end #######################
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
plt.legend()
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
plt.close()#清除内存
set_phtot(2)#将图片显示到程序中
def set_phtot(index1):
# 显示图像
Canvas2 = gl_data.Canvas2
if index1 == 1:
gl_data.Img1=PhotoImage(file=r"dot.png")
# print(gl_data.Img1)
gl_data.Img1 = gl_data.Img1.zoom(2, 2)
set_img1=Canvas2.create_image(2, 10, image=gl_data.Img1, anchor=NW)
# Canvas2.itemconfigure(set_img1,int(420 * gl_data.MAGNIDICATION), int(300 * gl_data.MAGNIDICATION))
Canvas2.update()
print("已输出数据点")
elif index1 == 2:
gl_data.Img2=PhotoImage(file=r"line.png")
# print(gl_data.Img2)
gl_data.Img2 = gl_data.Img2.zoom(2, 2)
set_img2=Canvas2.create_image(2, 10, image=gl_data.Img2, anchor=NW)
# Canvas2.itemconfigure(set_img2, 0,0,int(420*gl_data.MAGNIDICATION), int(300*gl_data.MAGNIDICATION))
Canvas2.update()
print("已输出数据点和曲线")
if __name__ == '__main__':
# random_points()
pass