import numpy as np import matplotlib.pyplot as plt def gradient_descent_poly_fit(sampleData, degree, learning_rate=0.001, iterations=50000): def error_function(y_pred, y): # 定义误差函数error_function() return y_pred - y def gradient(degree,error,x_normalized,sampleNum,learning_rate): # 定义梯度函数gradient() # 对每个参数计算梯度 for i in range(degree + 1): # 计算损失函数对参数的偏导数(梯度) gradient = np.dot(error, x_normalized**i) * 2 / sampleNum # 梯度下降更新参数 theta[i] -= learning_rate * gradient return theta x = sampleData[:,0] y = sampleData[:,1] # 初始化参数(多项式系数)为0 theta = np.zeros(degree + 1) sampleNum = len(x) # 样本数量 # 归一化x以改善数值稳定性 x_normalized = (x - x.mean()) / x.std() # 迭代梯度下降 for _ in range(iterations): # 通过当前参数计算多项式的值 y_pred = np.polyval(theta[::-1], x_normalized) # 计算预测值与真实值之间的误差 error = error_function(y_pred, y) # 计算梯度并更新theta theta = gradient(degree,error,x_normalized,sampleNum,learning_rate) return theta, x_normalized # 设置一个较小的学习率和较多的迭代次数 learning_rate = 0.001 iterations = 50000 degree = 4 x = np.array([0, 100, 200, 300, 400, 500, 600, 700, 800, 900]) y = np.array([10, 20, 10, 50, 80, 130, 210, 340, 550, 890]) sampleData = np.array(list(zip(x, y))) # 运行梯度下降算法 theta, x_normalized = gradient_descent_poly_fit(sampleData, degree, learning_rate, iterations) # 用拟合的参数计算多项式的值 x_fit_normalized = np.linspace(x_normalized.min(), x_normalized.max(), 100) y_fit = np.polyval(theta[::-1], x_fit_normalized) # 反归一化x_fit x_fit = x_fit_normalized * x.std() + x.mean() # 绘制原始数据点和拟合的多项式 plt.scatter(x, y, color='red', label='Sample Data') plt.plot(x_fit, y_fit, label='Polynomial Fit') plt.legend() plt.show() # 输出拟合参数 print(theta) # if __name__ == '__main__': # LimitNum = 1000 # sampleData = random_points(20,0, 1000) # # 这里为了便于观察使用设计好的数据 # # x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) # # y = np.array([1, 2, 1, 5, 8, 13, 21, 34, 55, 89]) # # sampleData = np.array(list(zip(x, y))) # 将两个一维数组拼接成二维数组 # m = 3 # theta, covariance_matrix = least_square_method(m, sampleData) # # print(theta) # curveData = compute_curveData(0,1000,1,theta,m) # # print("curveData",curveData) # draw_dots_and_line(curveData,sampleData,0,1000)