You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

48 lines
1.5 KiB

5 months ago
import numpy as np
import matplotlib.pyplot as plt
# 示例函数定义,比如二次多项式
def func(x, theta):
# 二次多项式: theta[0] + theta[1]*x + theta[2]*x^2
return theta[0] + theta[1]*x + theta[2]*x**2
# 修改后的最小二乘法函数
def least_square_method(func, m, sampleData):
x = sampleData[:,0]
y = sampleData[:,1]
# 构造 A 矩阵,使用 func 函数
A = np.array([func(xi, np.arange(m + 1)) for xi in x])
print(A)
# (2) 多项式阶数m这里以3阶多项式为例
# 使用x的幂次来构造A矩阵
A = np.vander(x, m + 1)
# 计算拟合系数theta
theta = np.linalg.inv(A.T @ A) @ A.T @ y
# 其他步骤保持不变...
plt.scatter(x, y, color='red')
x_fit = np.linspace(min(x), max(x), 100)
y_fit = np.array([func(xi, theta) for xi in x_fit])
plt.plot(x_fit, y_fit)
plt.show()
# 计算残差平方和等...
residuals = y - np.dot(A, theta)
residuals_squared_sum = np.dot(residuals.T, residuals)
degrees_of_freedom = len(y) - (m + 1)
mean_squared_error = residuals_squared_sum / degrees_of_freedom
covariance_matrix = np.linalg.inv(A.T @ A) * mean_squared_error
return theta, covariance_matrix
# 示例用法
if __name__ == '__main__':
sampleData = np.array([[0, 1], [1, 2], [2, 1], [3, 5], [4, 8], [5, 13], [6, 21], [7, 34], [8, 55], [9, 89]])
m = 2 # 为二次多项式
theta, covariance_matrix = least_square_method(func, m, sampleData)
print("Theta:", theta)
print("Covariance Matrix:", covariance_matrix)