You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

57 lines
1.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import data as gl_data
import numpy as np
import matplotlib.pyplot as plt
def gradient_descent_poly_fit(x, y, degree, learning_rate, iterations):
# 初始化参数多项式系数为0
theta = np.zeros(degree + 1)
m = len(x) # 样本数量
# 归一化x以改善数值稳定性
x_normalized = (x - x.mean()) / x.std()
# 迭代梯度下降
for _ in range(iterations):
# 通过当前参数计算多项式的值
y_pred = np.polyval(theta[::-1], x_normalized)
# 计算预测值与真实值之间的误差
error = y_pred - y
# 对每个参数计算梯度
for i in range(degree + 1):
# 计算损失函数对参数的偏导数(梯度)
gradient = np.dot(error, x_normalized**i) * 2 / m
# 梯度下降更新参数
theta[i] -= learning_rate * gradient
return theta, x_normalized
# 设置一个较小的学习率和较多的迭代次数
learning_rate = 0.001
iterations = 50000
# x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
x = np.array([0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
y = np.array([1, 2, 1, 5, 8, 13, 21, 34, 55, 89])
degree = 2
# 运行梯度下降算法
theta, x_normalized = gradient_descent_poly_fit(x, y, degree, learning_rate, iterations)
# 用拟合的参数计算多项式的值
x_fit_normalized = np.linspace(x_normalized.min(), x_normalized.max(), 100)
y_fit = np.polyval(theta[::-1], x_fit_normalized)
# 反归一化x_fit
x_fit = x_fit_normalized * x.std() + x.mean()
# 绘制原始数据点和拟合的多项式
plt.scatter(x, y, color='red', label='Sample Data')
plt.plot(x_fit, y_fit, label='Polynomial Fit')
plt.legend()
plt.show()
# 输出拟合参数
print(theta)