You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

70 lines
2.1 KiB

5 months ago
import numpy as np
import matplotlib.pyplot as plt
def least_square_method(m, sampleData):
    """Fit a degree-``m`` polynomial to 2-D sample points by least squares.

    The x values are standardised (zero mean, unit standard deviation) before
    the Vandermonde matrix is built, which keeps the matrix better conditioned
    for large x ranges. The returned coefficients therefore refer to the
    normalised variable ``(x - x_mean) / x_std``.

    Args:
        m: Polynomial degree (non-negative integer).
        sampleData: Array-like whose column 0 holds x values and column 1
            holds y values.

    Returns:
        ``(theta, covariance_matrix, x_mean, x_std)``:
        ``theta`` holds the ``m + 1`` coefficients in increasing-power order;
        ``covariance_matrix`` is the ``(m + 1, m + 1)`` estimated covariance
        of ``theta`` (filled with NaN when there are no residual degrees of
        freedom, i.e. exactly ``m + 1`` points); ``x_mean`` and ``x_std`` are
        the values needed to normalise new x values the same way. If every
        x value is identical, ``x_std`` is reported as 1.0 to avoid 0/0.

    Raises:
        ValueError: If there are fewer distinct x values than coefficients,
            so the polynomial is not uniquely determined.
    """
    data = np.asarray(sampleData, dtype=float)
    x = data[:, 0]
    y = data[:, 1]
    n_coef = m + 1

    # Standardise x so the Vandermonde columns have comparable magnitudes.
    x_mean = x.mean()
    x_std = x.std()
    if x_std == 0:
        # All x identical: avoid 0/0. Only a constant fit (m == 0) can
        # succeed in that case; the rank check below rejects m >= 1.
        x_std = 1.0
    x_normalized = (x - x_mean) / x_std

    # Design matrix with columns x^0, x^1, ..., x^m (normalised x).
    A = np.vander(x_normalized, n_coef, increasing=True)

    # SVD-based least squares is more stable than inverting A.T @ A.
    theta, _, rank, _ = np.linalg.lstsq(A, y, rcond=None)
    if rank < n_coef:
        raise ValueError(
            f"need at least {n_coef} points with distinct x values to fit "
            f"a degree-{m} polynomial (design matrix rank is {rank})"
        )

    degrees_of_freedom = len(y) - n_coef
    if degrees_of_freedom == 0:
        # Exact interpolation: the residual variance, and hence the
        # coefficient covariance, is undefined.
        covariance_matrix = np.full((n_coef, n_coef), np.nan)
    else:
        residuals = y - A @ theta
        residuals_squared_sum = residuals @ residuals
        # Unbiased estimate of the noise variance.
        mean_squared_error = residuals_squared_sum / degrees_of_freedom
        covariance_matrix = np.linalg.inv(A.T @ A) * mean_squared_error
    return theta, covariance_matrix, x_mean, x_std
def compute_curveData(low, high, step, theta, m, x_mean, x_std):
    """Evaluate a fitted polynomial on an evenly spaced x grid.

    Args:
        low, high, step: Grid specification passed to ``np.arange``; the
            grid starts at ``low`` and stops before ``high``.
        theta: ``m + 1`` polynomial coefficients in increasing-power order,
            expressed in terms of the normalised variable.
        m: Polynomial degree.
        x_mean, x_std: Normalisation constants applied to the grid before
            evaluation, ``(x - x_mean) / x_std``.

    Returns:
        Two-column array: the original (unnormalised) grid x values in
        column 0 and the evaluated polynomial in column 1.
    """
    grid = np.arange(low, high, step)
    # Apply the same standardisation the coefficients were fitted against.
    scaled_grid = (grid - x_mean) / x_std
    design = np.vander(scaled_grid, m + 1, increasing=True)
    values = design @ theta
    return np.column_stack((grid, values))
def draw_dots_and_line(curveData, sampleData, low, high):
    """Plot the fitted curve together with the sample points and show it.

    Args:
        curveData: Array with curve x values in column 0 and curve y values
            in column 1.
        sampleData: Array with sample x values in column 0 and sample y
            values in column 1.
        low, high: Limits applied to the x axis.
    """
    curve_x, curve_y = curveData[:, 0], curveData[:, 1]
    dots_x, dots_y = sampleData[:, 0], sampleData[:, 1]

    plt.plot(curve_x, curve_y, label='FitCurve')
    plt.scatter(dots_x, dots_y, color='red', label='SampleData')

    plt.xlim(low, high)
    plt.legend()
    plt.show()
def random_points(n, low, high, *, fixed=True):
    """Return sample points for the polynomial-fit demo.

    By default (``fixed=True``) a hard-coded set of 10 points is returned and
    ``n``, ``low`` and ``high`` are ignored. This is the data the demo has
    always used. With ``fixed=False``, ``n`` points are returned whose x and
    y coordinates are each drawn uniformly from ``[low, high)``.

    The global NumPy RNG is seeded with 0 in both modes, so the random mode
    is reproducible from call to call.

    Args:
        n: Number of points to generate (only used when ``fixed=False``).
        low, high: Range for the generated coordinates (only used when
            ``fixed=False``).
        fixed: Return the built-in demo data instead of random points.

    Returns:
        Two-column array: x values in column 0, y values in column 1.
    """
    np.random.seed(0)  # For reproducibility
    if fixed:
        x = np.array([0, 100, 200, 300, 400, 500, 600, 700, 800, 900])
        y = np.array([10, 20, 10, 50, 80, 130, 210, 340, 550, 890])
    else:
        x = np.random.uniform(low, high, n)
        y = np.random.uniform(low, high, n)  # This is just an example
    return np.column_stack((x, y))
if __name__ == '__main__':
    # Demo: fit a cubic to the sample points and plot it over [0, 1000).
    m = 3
    low, high = 0, 1000

    sampleData = random_points(20, low, high)
    print(len(sampleData))

    theta, covariance_matrix, x_mean, x_std = least_square_method(m, sampleData)
    curveData = compute_curveData(low, high, 1, theta, m, x_mean, x_std)
    draw_dots_and_line(curveData, sampleData, low, high)