You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

186 lines
6.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso
from sklearn.svm import LinearSVR
from sklearn.metrics import explained_variance_score, mean_absolute_error, mean_squared_error, median_absolute_error, \
r2_score
import os
# Configure fonts so Chinese labels render correctly in matplotlib.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False  # render minus signs with the CJK font

# ==================== 1. Data exploration & correlation analysis ====================
print("=" * 60)
print("步骤1: 数据探索与相关性分析")
print("=" * 60)
inputfile = './data/data.csv'
data = pd.read_csv(inputfile)
print(f"数据维度: {data.shape}")
print("\n相关系数矩阵:")
# Pairwise Pearson correlations, rounded to 2 decimals for readable console output.
corr_matrix = np.round(data.corr(method='pearson'), 2)
print(corr_matrix)

# Visualize the correlation matrix as a heatmap.
plt.figure(figsize=(12, 10))
plt.imshow(corr_matrix, cmap='coolwarm', aspect='auto')
plt.colorbar()
plt.xticks(range(len(corr_matrix.columns)), corr_matrix.columns, rotation=90)
plt.yticks(range(len(corr_matrix.columns)), corr_matrix.columns)
plt.title('特征相关性热力图')
plt.tight_layout()
# BUGFIX: the 'output/' directory was only created later (step 2), so this
# first savefig crashed on a fresh checkout. Create it before saving.
os.makedirs('output', exist_ok=True)
plt.savefig('output/correlation_heatmap.png', dpi=300)
plt.show()
# ==================== 2. Lasso feature selection ====================
print("\n" + "=" * 60)
print("步骤2: Lasso回归特征选择")
print("=" * 60)

# Split predictors (first 13 columns) from the target column 'y'.
X = data.iloc[:, 0:13]
y = data['y']

# Fit a Lasso regression: the large alpha drives most coefficients to
# exactly zero, which is what performs the feature selection.
lasso = Lasso(alpha=1000, max_iter=10000)
lasso.fit(X, y)

print('Lasso回归系数:')
# NOTE: this print belongs inside the loop; the pasted source had lost
# the indentation here.
for col, coef in zip(X.columns, lasso.coef_):
    print(f' {col}: {coef:.5f}')
print(f'\n非零系数个数: {np.sum(lasso.coef_ != 0)}')

# Keep only the features whose Lasso coefficient is non-zero.
mask = lasso.coef_ != 0
selected_features = X.columns[mask].tolist()
print(f'选中特征: {selected_features}')

# Persist the reduced feature matrix (plus target) for later steps.
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)
new_reg_data = X.iloc[:, mask].copy()
new_reg_data['y'] = y.values
new_reg_data.to_csv(f'{output_dir}/selected_features.csv', index=False)
print(f'\n特征选择后数据维度: {new_reg_data.shape}')
# ==================== 3. Grey prediction GM(1,1) ====================
print("\n" + "=" * 60)
print("步骤3: 灰色预测预测2014-2015年特征值")
print("=" * 60)
from GM11 import GM11  # project-local GM(1,1) grey-model implementation

# Re-index rows as the years 1994-2013 and append two empty rows for the
# years we are about to forecast.
new_reg_data.index = range(1994, 2014)
new_reg_data.loc[2014] = None
new_reg_data.loc[2015] = None

# Forecast each selected feature independently with GM(1,1).
# NOTE: the loop body below was flush-left in the pasted source; the
# indentation has been reconstructed.
for feature in selected_features:
    print(f'正在预测 {feature}...')
    # Historical values for 1994-2013 as a plain numpy array.
    historical_data = new_reg_data.loc[range(1994, 2014), feature].values
    # f: fitted prediction function; a, b: model parameters; x0_0: first
    # original value; C: variance ratio; P: small-error probability
    # (C and P are the standard GM(1,1) accuracy diagnostics).
    f, a, b, x0_0, C, P = GM11(historical_data)
    # With 22 rows (1994-2015), f(21) corresponds to 2014 and f(22) to 2015.
    pred_2014 = f(len(new_reg_data) - 1)
    pred_2015 = f(len(new_reg_data))
    new_reg_data.loc[2014, feature] = pred_2014
    new_reg_data.loc[2015, feature] = pred_2015
    # Keep two decimals for presentation.
    new_reg_data[feature] = new_reg_data[feature].round(2)
    print(f' 模型参数: a={a:.4f}, b={b:.4f}')
    print(f' 方差比C: {C:.4f}, 小残差概率P: {P:.4f}')
    # Conventional GM(1,1) accuracy grades based on C and P thresholds.
    if C < 0.35 and P > 0.95:
        print(' 模型精度: 优秀')
    elif C < 0.5 and P > 0.8:
        print(' 模型精度: 合格')
    elif C < 0.65 and P > 0.7:
        print(' 模型精度: 基本合格')
    else:
        print(' 模型精度: 不合格')

# Save the grey-prediction results (features now filled for 2014-2015).
new_reg_data.to_excel(f'{output_dir}/gm11_prediction.xlsx')
print(f'\n2014-2015年特征预测结果:')
print(new_reg_data.loc[2014:2015, selected_features])
# ==================== 4. SVR model training & prediction ====================
print("\n" + "=" * 60)
print("步骤4: SVR模型训练与财政收入预测")
print("=" * 60)
# Reload the grey-prediction output (features filled through 2015).
data = pd.read_excel(f'{output_dir}/gm11_prediction.xlsx', index_col=0)
# Training data: only the historical years 1994-2013 (2014-2015 have no 'y').
data_train = data.loc[range(1994, 2014)].copy()
# Z-score standardization; mean/std are computed on the training years only
# and reused below so the forecast rows are scaled consistently.
data_mean = data_train.mean()
data_std = data_train.std()
data_train_norm = (data_train - data_mean) / data_std
# Split standardized features from the standardized target.
x_train = data_train_norm[selected_features].values
y_train = data_train_norm['y'].values
# Fit a linear support-vector regressor on the standardized data.
svr = LinearSVR(max_iter=10000)
svr.fit(x_train, y_train)
# Predict every year, including 2014-2015, then invert the target's
# standardization to get predictions back in the original units.
data_norm = (data[selected_features] - data_mean[selected_features]) / data_std[selected_features]
data['y_pred'] = svr.predict(data_norm) * data_std['y'] + data_mean['y']
# Persist the final predictions.
data.to_excel(f'{output_dir}/final_prediction.xlsx')
print(f'预测完成!结果保存至: {output_dir}/final_prediction.xlsx')
# ==================== 5. Model evaluation ====================
print("\n" + "=" * 60)
print("步骤5: 模型评估")
print("=" * 60)
# Evaluate the fit on the historical (training) years 1994-2013 only,
# since 2014-2015 have no ground-truth 'y'.
train_years = range(1994, 2014)
y_true_train = data.loc[train_years, 'y']
y_pred_train = data.loc[train_years, 'y_pred']
print('训练集评估指标:')
# Report each metric with the same label/format as before.
for label, metric_fn in (
    (' 解释方差分', explained_variance_score),
    (' R²分数', r2_score),
    (' 均方误差', mean_squared_error),
    (' 平均绝对误差', mean_absolute_error),
):
    print(f'{label}: {metric_fn(y_true_train, y_pred_train):.4f}')
# ==================== 6. Result visualization ====================
print("\n" + "=" * 60)
print("步骤6: 结果可视化")
print("=" * 60)
# Plot actual vs. predicted fiscal revenue across the full year range.
plt.figure(figsize=(14, 7))
years = data.index
plt.plot(years, data['y'], 'b-o', label='真实值', linewidth=2, markersize=6)
plt.plot(years, data['y_pred'], 'r-*', label='预测值', linewidth=2, markersize=6)
# Dashed vertical line marking where the forecast period begins (after 2013).
plt.axvline(x=2013.5, color='g', linestyle='--', alpha=0.5, label='预测起点')
plt.xlabel('年份', fontsize=12)
plt.ylabel('财政收入', fontsize=12)
plt.title('财政收入真实值 vs 预测值', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
# Save the figure before show(), which may clear the current figure.
plt.savefig(f'{output_dir}/prediction_comparison.png', dpi=300)
plt.show()

# Print the 2014-2015 forecast rows (actual 'y' is NaN for these years).
print("\n最终预测结果:")
print(data.loc[2014:2015, ['y', 'y_pred']])