You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

50 lines
1.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import inspect

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import approx_fprime
# Gradient descent that adapts automatically to the number of model parameters.
def gradient_descent_adjusted(func, sx, sy, learning_rate=0.001, iterations=5000,
                              n_params=None):
    """Fit ``func(sx, *theta)`` to ``sy`` by minimizing the mean squared error.

    The gradient is estimated numerically with ``scipy.optimize.approx_fprime``,
    so ``func`` only needs to be evaluable, not analytically differentiable.

    Args:
        func: Model callable invoked as ``func(sx, p1, p2, ...)``. Plain
            functions, lambdas, bound methods and ``functools.partial``
            objects are supported.
        sx: Input values, passed as the first argument to ``func``.
        sy: Observed targets, compared element-wise with ``func(sx, *theta)``.
        learning_rate: Step size of each gradient-descent update.
        iterations: Number of update steps (there is no early stopping).
        n_params: Number of parameters to fit. When ``None``, it is inferred
            as the number of positional parameters of ``func`` minus one (for
            ``sx``). Pass it explicitly for models taking ``*args`` or whose
            signature cannot be introspected.

    Returns:
        numpy.ndarray of shape ``(n_params,)`` holding the fitted parameters.

    Raises:
        ValueError: If fewer than one parameter would be fitted.
    """
    sx = np.asarray(sx)
    sy = np.asarray(sy, dtype=float)

    if n_params is None:
        # inspect.signature (unlike func.__code__.co_argcount) excludes the
        # bound ``self`` of methods and understands partials/callable objects.
        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                            inspect.Parameter.POSITIONAL_OR_KEYWORD)
        n_params = sum(
            1 for param in inspect.signature(func).parameters.values()
            if param.kind in positional_kinds
        ) - 1
    if n_params < 1:
        raise ValueError(
            "func must accept at least one parameter besides sx "
            f"(got n_params={n_params}); pass n_params explicitly if func uses *args"
        )

    def mse(theta):
        """Mean squared error of the model at parameter vector ``theta``."""
        return np.mean((func(sx, *theta) - sy) ** 2)

    # Small random start; drawn from NumPy's global RNG so that a prior
    # np.random.seed(...) makes the fit reproducible.
    theta = np.random.randn(n_params) * 0.01
    for _ in range(iterations):
        grad = approx_fprime(theta, mse, epsilon=1e-6)
        theta -= learning_rate * grad
    return theta
# Example model: a quadratic polynomial in sx.
def example_func(sx, c, a, b):
    """Evaluate the quadratic ``c * sx**2 + a * sx + b``.

    Args:
        sx: Input value(s); scalars or NumPy arrays (element-wise) both work.
        c: Coefficient of the quadratic term.
        a: Coefficient of the linear term.
        b: Constant offset.

    Returns:
        The polynomial evaluated at ``sx`` (same shape as ``sx``).
    """
    return c * sx**2 + a * sx + b
if __name__ == "__main__":
    # Generate synthetic data: y = 0.1*x**2 + 3*x + 2 plus Gaussian noise.
    # Seeding the global RNG also makes the random parameter initialization
    # inside gradient_descent_adjusted reproducible.
    np.random.seed(0)
    sx = np.linspace(-1, 1, 100)
    sy = 0.1 * sx**2 + 3 * sx + 2 + np.random.randn(100) * 0.5

    # Fit the model. With the default learning_rate (0.001), 5000 iterations
    # leave the quadratic/constant directions far from converged on this data.
    # A larger step size (0.1) still converges stably here.
    params_adjusted = gradient_descent_adjusted(
        example_func, sx, sy, learning_rate=0.1
    )

    # Plot the data points and the fitted curve.
    plt.scatter(sx, sy, label="Data Points")
    plt.plot(sx, example_func(sx, *params_adjusted), color="red", label="Fitted Line Adjusted")
    plt.legend()
    plt.xlabel("sx")
    plt.ylabel("sy")
    plt.title("Fitting with Adjusted Gradient Descent")
    plt.show()