You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
118 lines
5.1 KiB
118 lines
5.1 KiB
# -*- encoding: utf-8 -*-
|
|
"""
|
|
@Author: packy945
|
|
@FileName: fitting.py
|
|
@DateTime: 2023/5/31 14:16
|
|
@SoftWare: PyCharm
|
|
"""
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from scipy.optimize import curve_fit
|
|
import pylab as mpl
|
|
from tkinter import *
|
|
import mpl_toolkits.axisartist as axisartist
|
|
import data as gl_data
|
|
import function
|
|
|
|
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 解决中文不显示问题
|
|
plt.rcParams['axes.unicode_minus'] = False #解决负数坐标显示问题
|
|
|
|
def draw_axis(low, high, step=250):
|
|
fig = plt.figure(figsize=(4.4, 3.2))
|
|
ax = axisartist.Subplot(fig, 111)# 使用axisartist.Subplot方法创建一个绘图区对象ax
|
|
fig.add_axes(ax)# 将绘图区对象添加到画布中
|
|
ax.axis[:].set_visible(False)# 通过set_visible方法设置绘图区所有坐标轴隐藏
|
|
ax.axis["x"] = ax.new_floating_axis(0, 0)# 添加新的x坐标轴
|
|
ax.axis["x"].set_axisline_style("-|>", size=1.0)# 给x坐标轴加上箭头
|
|
ax.axis["y"] = ax.new_floating_axis(1, 0)# 添加新的y坐标轴
|
|
ax.axis["y"].set_axisline_style("-|>", size=1.0)# y坐标轴加上箭头
|
|
ax.axis["x"].set_axis_direction("bottom")# 设置x、y轴上刻度显示方向
|
|
ax.axis["y"].set_axis_direction("left")# 设置x、y轴上刻度显示方向
|
|
plt.xlim(low, high) # 把x轴的刻度范围设置
|
|
plt.ylim(low, high) # 把y轴的刻度范围设置
|
|
ax.set_xticks(np.arange(low, high + 5, step))# 把x轴的刻度间隔设置
|
|
ax.set_yticks(np.arange(low, high + 5, step))# 把y轴的刻度间隔设置
|
|
|
|
def generate_and_plot_sample_data(low, high):
|
|
num_samples = 25#设置样本点数量
|
|
gl_data.X = np.random.randint(low=low, high=high, size=num_samples)#随机生成x值
|
|
gl_data.Y = np.random.randint(low=low, high=high, size=num_samples)#随机生成y值
|
|
draw_axis(low, high)#绘制坐标轴
|
|
plt.scatter(gl_data.X, gl_data.Y, color='red') # 绘制样本数据点
|
|
plt.savefig(r"dot.png", facecolor='w')#保存到本地
|
|
plt.close()#清除内存
|
|
set_phtot(1)#显示到主界面
|
|
|
|
# 显示数据集
|
|
def selfdata_show(x, y, low, high):# 显示数据集
|
|
num_points = len(x) # 样本数据点的数量
|
|
colors = np.random.rand(num_points, 3)# 生成随机颜色
|
|
draw_axis(low, high)
|
|
plt.scatter(x, y, c=colors)# 绘制样本数据点
|
|
plt.savefig(r"dot.png", facecolor='w') # 保存为png文件
|
|
plt.clf()
|
|
set_phtot(1)
|
|
|
|
def draw_line(xian_index, low, high, sx, sy):
|
|
draw_axis(low, high)#绘制坐标轴
|
|
popt = []#初始化curve_fit
|
|
pcov = []#初始化curve_fit
|
|
cur = function.Fun[xian_index] # 装载正在选择的函数
|
|
func = cur.get_fun()#获取当前函数func
|
|
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
|
|
curve_x = np.arange(low, high)#按照步长生成的一串数字
|
|
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
|
|
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
|
|
plt.legend()
|
|
# plt.show()#显示函数图像
|
|
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
|
|
plt.close()#清除内存
|
|
set_phtot(2)#将图片显示到程序中
|
|
|
|
def draw_dots_and_line(xian_index, low, high, sx, sy):
|
|
############## begin #####################
|
|
draw_axis(low, high)
|
|
popt = []
|
|
pcov = []
|
|
cur = function.Fun[xian_index] # 装载正在选择的函数
|
|
func = cur.get_fun()#获取当前函数func
|
|
popt, pcov = curve_fit(func, sx, sy)# 用curve_fit来对点进行拟合
|
|
positive_mask = sy >= 0.0#给样本点分类
|
|
negative_mask = sy < 0.0#给样本点分类
|
|
positive_colors = ['red' if x >= 0.0 else 'blue' for x in sx]#给样本点分类
|
|
negative_colors = ['green' if x >= 0.0 else 'purple' for x in sx]#给样本点分类
|
|
#根据样本点类型决定样本点颜色
|
|
plt.scatter(gl_data.X[positive_mask], gl_data.Y[positive_mask], color=np.array(positive_colors)[positive_mask], lw=1)
|
|
plt.scatter(gl_data.X[negative_mask], gl_data.Y[negative_mask], color=np.array(negative_colors)[negative_mask], lw=1)
|
|
curve_x = np.arange(low, high)#按照步长生成的一串数字
|
|
curve_y = [func(i, *popt) for i in curve_x]# 根据x0(按照步长生成的一串数字)来计算y1值
|
|
############## end #######################
|
|
plt.plot(curve_x, curve_y, color='blue', label='Fitted Curve')#绘制拟合曲线
|
|
plt.legend()
|
|
plt.savefig(r"line.png", facecolor='w')#将图片保存到本地
|
|
plt.close()#清除内存
|
|
set_phtot(2)#将图片显示到程序中
|
|
|
|
|
|
|
|
def set_phtot(index1):
|
|
# 显示图像
|
|
Canvas2 = gl_data.Canvas2
|
|
if index1 == 1:
|
|
gl_data.Img1=PhotoImage(file=r"dot.png")
|
|
# print(gl_data.Img1)
|
|
set_img1=Canvas2.create_image(2, 10, image=gl_data.Img1, anchor=NW)
|
|
Canvas2.update()
|
|
print("已输出数据点")
|
|
elif index1 == 2:
|
|
gl_data.Img2=PhotoImage(file=r"line.png")
|
|
# print(gl_data.Img2)
|
|
set_img2=Canvas2.create_image(2, 10, image=gl_data.Img2, anchor=NW)
|
|
Canvas2.update()
|
|
print("已输出数据点和曲线")
|
|
if __name__ == '__main__':
|
|
# random_points()
|
|
|
|
pass
|