You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

70 lines
2.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import numpy as np
import matplotlib.pyplot as plt
def least_square_method(m, sampleData):
    """Fit a degree-``m`` polynomial to (x, y) samples by least squares.

    The x values are standardized (zero mean, unit standard deviation) before
    the design matrix is built, which keeps the Vandermonde columns on a
    comparable scale and improves conditioning for larger ``m``.

    Args:
        m: Polynomial degree; ``m + 1`` coefficients are fitted.
        sampleData: Array-like indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y),
            i.e. one sample per row.

    Returns:
        ``(theta, covariance_matrix, x_mean, x_std)``:

        * ``theta`` -- coefficients in increasing-power order with respect to
          the standardized variable ``(x - x_mean) / x_std``.
        * ``covariance_matrix`` -- estimated covariance of ``theta``; all NaN
          when there are no residual degrees of freedom (``n <= m + 1``).
          For a rank-deficient design matrix it is built from the
          pseudoinverse and should be interpreted with care.
        * ``x_mean``, ``x_std`` -- the standardization parameters. ``x_std``
          is reported as 1.0 when all x values are identical so callers can
          reuse it without dividing by zero.
    """
    data = np.asarray(sampleData, dtype=float)
    x = data[:, 0]
    y = data[:, 1]

    # Standardize x.
    x_mean = x.mean()
    x_std = x.std()
    if x_std == 0:
        # All x identical: avoid 0/0 -> NaN; the fit degenerates to a constant.
        x_std = 1.0
    x_normalized = (x - x_mean) / x_std

    # Design matrix with columns 1, z, z**2, ..., z**m.
    A = np.vander(x_normalized, m + 1, increasing=True)

    # Solve via the SVD-based pseudoinverse instead of inverting A^T A:
    # forming the normal equations squares the condition number of A.
    A_pinv = np.linalg.pinv(A)
    theta = A_pinv @ y

    # Residual sum of squares and the noise-variance estimate.
    residuals = y - A @ theta
    residuals_squared_sum = residuals @ residuals
    degrees_of_freedom = len(y) - (m + 1)
    if degrees_of_freedom > 0:
        mean_squared_error = residuals_squared_sum / degrees_of_freedom
    else:
        # No residual degrees of freedom: the variance estimate is undefined.
        mean_squared_error = np.nan

    # For full-column-rank A, pinv(A) @ pinv(A).T == inv(A^T A).
    covariance_matrix = (A_pinv @ A_pinv.T) * mean_squared_error
    return theta, covariance_matrix, x_mean, x_std
def compute_curveData(low, high, step, theta, m, x_mean, x_std):
    """Evaluate the fitted polynomial on the grid ``np.arange(low, high, step)``.

    ``theta`` holds increasing-power coefficients with respect to the
    standardized variable ``(x - x_mean) / x_std``, as produced by the fit.
    Note that ``high`` itself is excluded from the grid (``np.arange``
    semantics).

    Returns:
        A 2-column array whose rows are ``(x, fitted_y)``.
    """
    grid = np.arange(low, high, step)
    design = np.vander((grid - x_mean) / x_std, m + 1, increasing=True)
    return np.column_stack((grid, design @ theta))
def draw_dots_and_line(curveData, sampleData, low, high):
    """Plot the fitted curve as a line and the samples as red dots, then show it.

    Args:
        curveData: 2-D array whose first two columns are the curve's x and y.
        sampleData: 2-D array whose first two columns are the samples' x and y.
        low: Left x-axis limit.
        high: Right x-axis limit.
    """
    curve_x, curve_y = curveData[:, 0], curveData[:, 1]
    dots_x, dots_y = sampleData[:, 0], sampleData[:, 1]

    # Draw order is kept (line first, then dots) so the color cycle and
    # layering stay the same.
    plt.plot(curve_x, curve_y, label='FitCurve')
    plt.scatter(dots_x, dots_y, color='red', label='SampleData')
    plt.xlim(low, high)
    plt.legend()
    plt.show()
def random_points(n, low, high):
    """Return the fixed demo sample set as a (10, 2) integer array of (x, y) rows.

    NOTE(review): despite the name and signature, ``n``, ``low`` and ``high``
    are currently ignored -- the sample points are hard-coded. The global
    NumPy RNG is still reseeded with 0 on every call.
    """
    np.random.seed(0)  # reseeds the global RNG for reproducibility
    xs = np.arange(0, 1000, 100)
    ys = np.array([10, 20, 10, 50, 80, 130, 210, 340, 550, 890])
    return np.column_stack((xs, ys))
if __name__ == '__main__':
    # Fit a cubic to the demo samples over [0, 1000) and plot the result.
    lower, upper, degree = 0, 1000, 3
    samples = random_points(20, lower, upper)
    print(len(samples))
    coeffs, cov, centre, scale = least_square_method(degree, samples)
    curve = compute_curveData(lower, upper, 1, coeffs, degree, centre, scale)
    draw_dots_and_line(curve, samples, lower, upper)