|
|
|
|
@ -0,0 +1,319 @@
|
|
|
|
|
# 代码9-6
|
|
|
|
|
# ARIMA模型
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import warnings
|
|
|
|
|
import math
|
|
|
|
|
import itertools
|
|
|
|
|
import numpy as np
|
|
|
|
|
import statsmodels.api as sm
|
|
|
|
|
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
|
|
|
|
|
from statsmodels.stats.diagnostic import acorr_ljungbox
|
|
|
|
|
from statsmodels.tsa.stattools import adfuller as ADF
|
|
|
|
|
from statsmodels.tsa.arima.model import ARIMA
|
|
|
|
|
import matplotlib.ticker as ticker
|
|
|
|
|
|
|
|
|
|
# Read the passenger-flow table that every later step of the script works on.
print("正在读取数据...")
on_h = pd.read_csv('E:/案例11 铁路旅客客流量预测/tmp/on_h.csv', encoding='utf-8')
print(f"数据形状: {on_h.shape}")

# The rest of the script addresses the target column as 'on_man'. When the CSV
# lacks that exact header, adopt the first column whose lower-cased name
# contains 'on_man' or 'total'. If nothing matches, the frame is left as-is.
if 'on_man' not in on_h.columns:
    candidate = next(
        (header for header in on_h.columns
         if 'on_man' in header.lower() or 'total' in header.lower()),
        None,
    )
    if candidate is not None:
        on_h = on_h.rename(columns={candidate: 'on_man'})
|
|
|
|
|
|
|
|
|
|
# Split into training and test sets: the first `train_size` rows are used for
# fitting and whatever follows is held out. If the table has no more than
# `train_size` rows, everything becomes training data and the test frame is
# left empty.
train_size = 426
if len(on_h) > train_size:
    target = on_h['on_man']
    train = pd.DataFrame(target.iloc[:train_size])
    test = pd.DataFrame(target.iloc[train_size:])
else:
    train = pd.DataFrame(on_h['on_man'])
    test = pd.DataFrame()

print(f"训练集大小: {len(train)}, 测试集大小: {len(test)}")
|
|
|
|
|
|
|
|
|
|
# Time-series plot of the training set (row index on x, 'on_man' on y).
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(train.index, train['on_man'])
# One major tick every 70 index positions keeps the x axis readable.
ax.xaxis.set_major_locator(ticker.MultipleLocator(70))
# SimHei is needed to render the Chinese labels; with a CJK font the minus
# sign must fall back to the ASCII hyphen.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
plt.xticks(rotation=45)
ax.set_xlabel('日期索引')
ax.set_ylabel('客流量')
ax.set_title('训练集时序图')
fig.tight_layout()
plt.show()
|
|
|
|
|
|
|
|
|
|
# Autocorrelation plot of the training series.
fig, ax = plt.subplots(figsize=(12, 6))
# Show up to 40 lags, but never more than the series length allows.
max_lag = min(40, len(train) - 1)
plot_acf(train['on_man'], lags=max_lag, ax=ax)
ax.set_xlabel('滞后阶数')
ax.set_ylabel('自相关系数')
ax.set_title('训练集自相关图')
fig.tight_layout()
plt.show()
|
|
|
|
|
|
|
|
|
|
# Stationarity (ADF) and white-noise (Ljung-Box, lag 1) checks on the training
# series. Missing values are dropped before either test is run.
train_series = train['on_man'].dropna()
adf_result = ADF(train_series)
print('原始序列的ADF检验结果为:', adf_result)
ljungbox_result = acorr_ljungbox(train_series, lags=1)
print('白噪声检验结果为:', ljungbox_result)
|
|
|
|
|
|
|
|
|
|
# Listing 9-7: choose the ARIMA order by BIC.
print("\n开始ARIMA模型定阶...")
train['on_man'] = train['on_man'].astype(float)

# Candidate AR (p) and MA (q) orders; the differencing order is fixed at 1.
# The ranges are kept small so the grid search stays fast.
p_range = range(0, 4)
q_range = range(0, 3)

# Rows are p, columns are q. Every cell starts at +inf so that a combination
# whose fit raises can never be selected as the minimum.
bic_matrix = pd.DataFrame(np.inf, index=p_range, columns=q_range)
for p, q in itertools.product(p_range, q_range):
    try:
        fitted = ARIMA(train['on_man'], order=(p, 1, q)).fit()
        bic_matrix.loc[p, q] = fitted.bic
        print(f'ARIMA({p},1,{q}) - BIC: {fitted.bic:.2f}')
    except Exception as e:
        print(f'ARIMA({p},1,{q}) - 失败: {e}')

print("\nBIC矩阵:")
print(bic_matrix)

# Locate the smallest BIC: stack() flattens the grid to a (p, q)-indexed
# series, so idxmin() returns the winning pair directly.
min_bic = bic_matrix.min().min()
p, q = bic_matrix.stack().idxmin()
print(f'BIC最小的p值和q值为:{p}、{q}')
|
|
|
|
|
|
|
|
|
|
# Fit the final ARIMA(p, 1, q) with the BIC-selected orders and evaluate it on
# the held-out rows. Any exception raised inside the block (fitting,
# forecasting or plotting) is caught and reported instead of aborting.
try:
    model = ARIMA(train['on_man'], order=(p, 1, q))
    model_fit = model.fit()
    print("\n模型摘要:")
    print(model_fit.summary())

    if len(test) > 0:
        # Forecast at most 10 steps, and never more than the test set holds.
        forecast_steps = min(10, len(test))
        forecast = model_fit.forecast(steps=forecast_steps)
        forecast_index = test.index[:forecast_steps]

        # Re-label the forecast with the test index. assign() aligns on the
        # index, so rows beyond the forecast horizon get NaN in 'pre'.
        pre = pd.DataFrame({'predict': forecast.values}, index=forecast_index)
        test_with_pred = test.assign(pre=pre['predict'])

        # Actual versus predicted values.
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(test_with_pred.index, test_with_pred['on_man'], label='真实值')
        ax.plot(test_with_pred.index, test_with_pred['pre'], linestyle='--', label='预测值')
        plt.xticks(rotation=45)
        ax.legend()
        ax.set_title('预测结果和实际结果的对比')
        ax.set_xlabel('日期索引')
        ax.set_ylabel('上车人数')
        fig.tight_layout()
        plt.show()

        # Mean absolute percentage error; rows without a prediction are NaN
        # and are skipped by the pandas mean.
        actual = test_with_pred['on_man']
        relative_error = (actual - test_with_pred['pre']) / actual
        mape = relative_error.abs().mean() * 100
        print(f'预测平均相对误差(MAPE): {mape:.2f}%')

except Exception as e:
    print(f"模型拟合失败: {e}")
|
|
|
|
|
|
|
|
|
|
# Listing 9-8: passenger flow with the holiday rows removed.
print("\n分析剔除节假日后的数据...")

# The filter below expects a column called 'holiday'. When it is absent, adopt
# the first column whose lower-cased name contains 'holiday' or '假期'.
if 'holiday' not in on_h.columns:
    candidate = next(
        (header for header in on_h.columns
         if any(tag in header.lower() for tag in ('holiday', '假期'))),
        None,
    )
    if candidate is not None:
        on_h = on_h.rename(columns={candidate: 'holiday'})

# Keep every row that is not labelled '小长假'.
not_holiday = on_h['holiday'].ne('小长假')
on_nh = on_h.loc[not_holiday].copy()

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(on_nh.index, on_nh['on_man'])
ax.xaxis.set_major_locator(ticker.MultipleLocator(70))
plt.xticks(rotation=45)
ax.set_title('剔除节假日后ST111-01站点客流量')
ax.set_xlabel('日期索引')
ax.set_ylabel('客流量(人)')
fig.tight_layout()
plt.show()
|
|
|
|
|
|
|
|
|
|
# Listing 9-9: passenger flow with holidays AND the day before each holiday
# removed.
print("\n分析剔除节假日及前一天后的数据...")
# Work on an independent copy with a fresh 0..n-1 index.
new_h = on_h.reset_index(drop=True)

# Flag the row immediately before every '小长假' row as '小长假' too.
# Shifting the mask up by one position lines each row up with its successor;
# the last row has no successor and is therefore never flagged by this step.
is_holiday = new_h['holiday'] == '小长假'
precedes_holiday = is_holiday.shift(-1, fill_value=False)
new_h.loc[precedes_holiday, 'holiday'] = '小长假'

# Drop holidays together with the newly flagged preceding days.
new_nh = new_h.loc[new_h['holiday'] != '小长假'].copy()

plt.figure(figsize=(12, 6))
# Prefer the 'date' column for the x axis when present, else the row index.
x_values = new_nh['date'] if 'date' in new_nh.columns else new_nh.index
plt.plot(x_values, new_nh['on_man'], color='green')

axis = plt.gca()
axis.xaxis.set_major_locator(ticker.MultipleLocator(70))
plt.xticks(rotation=45)
axis.set_title('剔除节假日及其前一天后ST111-01站点客流量')
axis.legend(['客流量'], loc='upper right')
axis.set_xlabel('日期')
axis.set_ylabel('客流量(人)')
plt.tight_layout()
plt.show()
|
|
|
|
|
|
|
|
|
|
# Listing 9-10: seasonal ARIMA (SARIMAX) on the holiday-free series.
print("\n开始季节性ARIMA分析...")

# Hold out everything after the first 378 rows when there is enough data
# (more than 400 rows); otherwise train on all rows and skip the evaluation.
if len(new_nh) > 400:
    train1_size = 378
    train1 = pd.DataFrame(new_nh.iloc[0:train1_size]['on_man'])
    test1 = pd.DataFrame(new_nh.iloc[train1_size:]['on_man'])
else:
    train1 = pd.DataFrame(new_nh['on_man'])
    test1 = pd.DataFrame()

print(f"季节性分析 - 训练集: {len(train1)}, 测试集: {len(test1)}")

# Candidate orders: p and q in {0, 1, 2}, d fixed at 1. The seasonal orders
# reuse the same triples with a seasonal period of 7.
p = q = range(0, 3)
d = range(1, 2)
pdq = list(itertools.product(p, d, q))
seasonal_pdq = [(x[0], x[1], x[2], 7) for x in pdq]

print('参数组合示例:')
for order_example, seasonal_example in zip(pdq[:3], seasonal_pdq[:3]):
    print(f'SARIMAX: {order_example} x {seasonal_example}')

# NOTE: this silences every warning for the remainder of the process, not
# only those emitted during the grid search below.
warnings.filterwarnings("ignore")
best_aic = np.inf
best_order = None
best_seasonal_order = None

# Only the first three non-seasonal and first three seasonal candidates are
# searched, with a hard cap on the number of fits. islice() applies the cap to
# the flattened search space, so the loop stops as soon as it is reached.
max_combinations = 10
combinations_tried = 0
search_space = itertools.islice(
    itertools.product(pdq[:3], seasonal_pdq[:3]),
    max_combinations,
)

for combinations_tried, (param, param_seasonal) in enumerate(search_space, start=1):
    try:
        mod = sm.tsa.statespace.SARIMAX(train1['on_man'],
                                        order=param,
                                        seasonal_order=param_seasonal,
                                        enforce_stationarity=False,
                                        enforce_invertibility=False)
        results = mod.fit(disp=False)
        current_aic = results.aic
        print(f'ARIMA{param}x{param_seasonal} - AIC:{current_aic:.2f}')

        if current_aic < best_aic:
            best_aic = current_aic
            best_order = param
            best_seasonal_order = param_seasonal

    except Exception as e:
        # Report why the fit failed, matching the non-seasonal ARIMA search,
        # instead of discarding the exception.
        print(f'ARIMA{param}x{param_seasonal} - 失败: {e}')

if best_order is not None:
    print(f'\n最佳模型: ARIMA{best_order}x{best_seasonal_order} - AIC: {best_aic:.2f}')

    # Refit the winning configuration so `results` belongs to the best model
    # rather than to the last combination tried.
    mod = sm.tsa.statespace.SARIMAX(train1['on_man'],
                                    order=best_order,
                                    seasonal_order=best_seasonal_order,
                                    enforce_stationarity=False,
                                    enforce_invertibility=False)
    results = mod.fit(disp=False)

    if len(test1) > 0:
        # Forecast at most 10 steps, and never more than the test set holds.
        forecast_steps = min(10, len(test1))
        pre_10 = results.get_forecast(steps=forecast_steps)
        forecast_mean = pre_10.predicted_mean

        # Pair actuals and predictions by position; .values drops both
        # indexes so they cannot misalign.
        out_pre = pd.DataFrame({
            'real': test1['on_man'].values[:forecast_steps],
            'pre': forecast_mean.values
        })

        error_seasonal = (out_pre['pre'] - out_pre['real']) / out_pre['real']
        error_mean = abs(error_seasonal).mean()
        print(f'季节性模型预测平均相对误差: {error_mean:.4f}')

    # Residual diagnostics for the refitted best model. Kept inside the guard
    # so it never runs against a stale or undefined `results`.
    results.plot_diagnostics(figsize=(12, 8))
    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
|
|
|
|
# Listing 9-11: passenger flow around the 2015 Spring Festival.
print("\n分析节假日客流规律...")
# Rebuild the per-date series from the raw files. Any failure in this section
# (missing file, missing column, unparsable date, plotting error) is reported
# and the script carries on.
try:
    holiday = pd.read_csv('E:/案例11 铁路旅客客流量预测/tmp/holiday.csv', encoding='utf-8')
    Train_Station = pd.read_csv('E:/案例11 铁路旅客客流量预测/tmp/Train_Station.csv', encoding='utf-8')

    # Keep the rows whose first column equals 'ST111-01', then total 'on_man'
    # per 'date'.
    Train_ST111_01 = Train_Station[Train_Station.iloc[:, 0] == 'ST111-01']
    on_h_new = Train_ST111_01.groupby('date')['on_man'].sum().reset_index()

    # Attach the holiday label: the first column of the holiday table is the
    # join key and the second is the label. Dates with no match in the holiday
    # table are labelled '工作日'.
    on_h_new = pd.merge(on_h_new, holiday, left_on='date', right_on=holiday.columns[0], how='left')
    on_h_new['holiday'] = on_h_new[holiday.columns[1]].fillna('工作日')

    # Select 2015-01-19 .. 2015-03-01 (inclusive). Compare parsed timestamps,
    # not the raw CSV strings: string comparison is lexicographic and is only
    # right for zero-padded ISO dates.
    parsed_dates = pd.to_datetime(on_h_new['date'])
    in_window = parsed_dates.between(pd.Timestamp('2015-01-19'), pd.Timestamp('2015-03-01'))
    spring_festival_2015 = on_h_new[in_window]

    if not spring_festival_2015.empty:
        plt.figure(figsize=(12, 6))
        dates = parsed_dates[in_window]
        plt.plot(dates, spring_festival_2015['on_man'])

        # Highlight the days labelled '小长假' with red markers.
        is_festival_day = spring_festival_2015['holiday'] == '小长假'
        holiday_dates = spring_festival_2015.loc[is_festival_day, 'date']
        holiday_values = spring_festival_2015.loc[is_festival_day, 'on_man']
        plt.scatter(pd.to_datetime(holiday_dates), holiday_values, color='red', s=50)

        plt.xlabel('日期')
        plt.ylabel('上车人数')
        plt.title('2015春节客流量')
        plt.legend(['客流量', '节假日'])
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

except Exception as e:
    print(f"节假日分析失败: {e}")

print("所有分析完成!")
|