You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

50 lines
1.4 KiB

5 months ago
import numpy as np
from scipy.optimize import approx_fprime
import matplotlib.pyplot as plt
# 调整梯度下降法以动态适应参数数量
def gradient_descent_adjusted(func, sx, sy, learning_rate=0.001, iterations=5000):
# 动态确定func需要的参数数量减去sx
params_count = func.__code__.co_argcount - 1
# 基于参数数量初始化theta
theta = np.random.randn(params_count) * 0.01
# 定义损失函数和梯度的计算
def loss_grad(theta):
pred = func(sx, *theta)
error = pred - sy
loss = np.mean(error ** 2)
grad = approx_fprime(theta, lambda t: np.mean((func(sx, *t) - sy) ** 2), epsilon=1e-6)
return loss, grad
# 梯度下降循环
for _ in range(iterations):
_, grad = loss_grad(theta)
theta -= learning_rate * grad
return theta
# 示例函数:简单的线性函数
def example_func(sx,c, a, b):
return c*sx**2 + a * sx + b
# 生成模拟数据
np.random.seed(0)
sx = np.linspace(-1, 1, 100)
sy = 0.1*sx**2+ 3 * sx + 2 + np.random.randn(100) * 0.5
# 使用调整后的梯度下降法拟合模型
params_adjusted = gradient_descent_adjusted(example_func, sx, sy)
# 绘制调整后的结果
plt.scatter(sx, sy, label="Data Points")
plt.plot(sx, example_func(sx, *params_adjusted), color="red", label="Fitted Line Adjusted")
plt.legend()
plt.xlabel("sx")
plt.ylabel("sy")
plt.title("Fitting with Adjusted Gradient Descent")
plt.show()