You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fitt/fitting.py

118 lines
5.1 KiB

# -*- encoding: utf-8 -*-
"""
@Author: packy945
@FileName: fitting.py
@DateTime: 2023/5/31 14:16
@SoftWare: PyCharm
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import pylab as mpl
from tkinter import *
import mpl_toolkits.axisartist as axisartist
import data as gl_data
import function
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 解决中文不显示问题
plt.rcParams['axes.unicode_minus'] = False #解决负数坐标显示问题
def draw_axis(low, high, step=250):
fig = plt.figure(figsize=(4.4, 3.2))
ax = axisartist.Subplot(fig, 111)# 使用axisartist.Subplot方法创建一个绘图区对象ax
fig.add_axes(ax)# 将绘图区对象添加到画布中
ax.axis[:].set_visible(False)# 通过set_visible方法设置绘图区所有坐标轴隐藏
ax.axis["x"] = ax.new_floating_axis(0, 0)# 添加新的x坐标轴
ax.axis["x"].set_axisline_style("-|>", size=1.0)# 给x坐标轴加上箭头
ax.axis["y"] = ax.new_floating_axis(1, 0)# 添加新的y坐标轴
ax.axis["y"].set_axisline_style("-|>", size=1.0)# y坐标轴加上箭头
ax.axis["x"].set_axis_direction("bottom")# 设置x、y轴上刻度显示方向
ax.axis["y"].set_axis_direction("left")# 设置x、y轴上刻度显示方向
plt.xlim(low, high) # 把x轴的刻度范围设置
plt.ylim(low, high) # 把y轴的刻度范围设置
ax.set_xticks(np.arange(low, high + 5, step))# 把x轴的刻度间隔设置
ax.set_yticks(np.arange(low, high + 5, step))# 把y轴的刻度间隔设置
def generate_and_plot_sample_data(low, high):
num_samples = 25#设置样本点数量
gl_data.X = np.random.randint(low=low, high=high, size=num_samples)#随机生成x值
gl_data.Y = np.random.randint(low=low, high=high, size=num_samples)#随机生成y值
draw_axis(low, high)#绘制坐标轴
plt.scatter(gl_data.X, gl_data.Y, color='red') # 绘制样本数据点
plt.savefig(r"dot.png", facecolor='w')#保存到本地
plt.close()#清除内存
set_phtot(1)#显示到主界面
# 显示数据集
def selfdata_show(x, y, low, high):# 显示数据集
num_points = len(x) # 样本数据点的数量
colors = np.random.rand(num_points, 3)# 生成随机颜色
draw_axis(low, high)
plt.scatter(x, y, c=colors)# 绘制样本数据点
plt.savefig(r"dot.png", facecolor='w') # 保存为png文件
plt.clf()
set_phtot(1)
def draw_line(xian_index, low, high, sx, sy):
draw_axis(low, high)#绘制坐标轴
popt = []#初始化curve_fit
pcov = []#初始化curve_fit
cur = function.Fun[xian_index] # 装载正在选择的函数
func = cur.get_fun()#获取当前函数func
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
curve_x = np.arange(low, high)#按照步长生成的一串数字
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
plt.legend()
# plt.show()#显示函数图像
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
plt.close()#清除内存
set_phtot(2)#将图片显示到程序中
def draw_dots_and_line(xian_index, low, high, sx, sy):
############## begin #####################
draw_axis(low, high)
popt = []
pcov = []
cur = function.Fun[xian_index] # 装载正在选择的函数
func = cur.get_fun()#获取当前函数func
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
positive_mask = sy >= 0.0#给样本点分类
negative_mask = sy < 0.0#给样本点分类
positive_colors = ['red' if x >= 0.0 else 'blue' for x in sx]#给样本点分类
negative_colors = ['green' if x >= 0.0 else 'purple' for x in sx]#给样本点分类
#根据样本点类型决定样本点颜色
plt.scatter(gl_data.X[positive_mask], gl_data.Y[positive_mask], color=np.array(positive_colors)[positive_mask], lw=1)
plt.scatter(gl_data.X[negative_mask], gl_data.Y[negative_mask], color=np.array(negative_colors)[negative_mask], lw=1)
curve_x = np.arange(low, high)#按照步长生成的一串数字
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
############## end #######################
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
plt.legend()
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
plt.close()#清除内存
set_phtot(2)#将图片显示到程序中
def set_phtot(index1):
# 显示图像
Canvas2 = gl_data.Canvas2
if index1 == 1:
gl_data.Img1=PhotoImage(file=r"dot.png")
# print(gl_data.Img1)
set_img1=Canvas2.create_image(2, 10, image=gl_data.Img1, anchor=NW)
Canvas2.update()
print("已输出数据点")
elif index1 == 2:
gl_data.Img2=PhotoImage(file=r"line.png")
# print(gl_data.Img2)
set_img2=Canvas2.create_image(2, 10, image=gl_data.Img2, anchor=NW)
Canvas2.update()
print("已输出数据点和曲线")
if __name__ == '__main__':
# random_points()
pass