ADD file via upload

main
hut22412030120 5 months ago
parent 0b6fb819e7
commit 175caf2cf5

@ -0,0 +1,319 @@
# 代码9-6
# ARIMA模型
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import math
import itertools
import numpy as np
import statsmodels.api as sm
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.stats.diagnostic import acorr_ljungbox
from statsmodels.tsa.stattools import adfuller as ADF
from statsmodels.tsa.arima.model import ARIMA
import matplotlib.ticker as ticker
# Load the passenger-flow table from disk.
print("正在读取数据...")
on_h = pd.read_csv('E:/案例11 铁路旅客客流量预测/tmp/on_h.csv', encoding='utf-8')
print(f"数据形状: {on_h.shape}")

# Make sure the target column is called 'on_man': when it is missing, the
# first column whose lower-cased name contains 'on_man' or 'total' is renamed.
if 'on_man' not in on_h.columns:
    candidate = next(
        (name for name in on_h.columns
         if 'on_man' in name.lower() or 'total' in name.lower()),
        None,
    )
    if candidate is not None:
        on_h = on_h.rename(columns={candidate: 'on_man'})

# Train/test split: the first `train_size` rows train the model and the rest
# is held out; with too few rows everything goes to training.
train_size = 426
if len(on_h) <= train_size:
    train = pd.DataFrame(on_h['on_man'])
    test = pd.DataFrame()
else:
    train = pd.DataFrame(on_h.iloc[:train_size, :]['on_man'])
    test = pd.DataFrame(on_h.iloc[train_size:, :]['on_man'])
print(f"训练集大小: {len(train)}, 测试集大小: {len(test)}")

# Time-series plot of the training set.
ts_fig, ts_ax = plt.subplots(figsize=(12, 6))
ts_ax.plot(train.index, train['on_man'])
ts_ax.xaxis.set_major_locator(ticker.MultipleLocator(70))
# SimHei is selected so the Chinese labels render; unicode_minus is disabled
# alongside it for the minus sign on the axes.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.xticks(rotation=45)
ts_ax.set_xlabel('日期索引')
ts_ax.set_ylabel('客流量')
ts_ax.set_title('训练集时序图')
plt.tight_layout()
plt.show()

# Autocorrelation plot, capped at 40 lags or the series length minus one.
acf_fig, acf_ax = plt.subplots(figsize=(12, 6))
max_lag = min(40, len(train) - 1)
plot_acf(train['on_man'], lags=max_lag, ax=acf_ax)
acf_ax.set_xlabel('滞后阶数')
acf_ax.set_ylabel('自相关系数')
acf_ax.set_title('训练集自相关图')
plt.tight_layout()
plt.show()

# Stationarity (ADF) and white-noise (Ljung-Box, lag 1) tests on the
# training series with missing values dropped.
train_series = train['on_man'].dropna()
print('原始序列的ADF检验结果为', ADF(train_series))
print('白噪声检验结果为:', acorr_ljungbox(train_series, lags=1))
# Listing 9-7
# ARIMA order selection by BIC, followed by a short out-of-sample forecast.
print("\n开始ARIMA模型定阶...")
train['on_man'] = train['on_man'].astype(float)

# Grid-search p in [0, 3] and q in [0, 2] with d fixed at 1 and record the BIC
# of every fit. Failed fits are stored as +inf so they can never be selected.
bic_matrix = []
p_range = range(0, 4)  # kept small to limit fitting time
q_range = range(0, 3)
for p in p_range:
    tmp = []
    for q in q_range:
        try:
            results = ARIMA(train['on_man'], order=(p, 1, q)).fit()
        except Exception as e:
            tmp.append(np.inf)
            print(f'ARIMA({p},1,{q}) - 失败: {e}')
        else:
            tmp.append(results.bic)
            print(f'ARIMA({p},1,{q}) - BIC: {results.bic:.2f}')
    bic_matrix.append(tmp)
bic_matrix = pd.DataFrame(bic_matrix, index=p_range, columns=q_range)
print("\nBIC矩阵:")
print(bic_matrix)

# Locate the smallest BIC. If no candidate produced a finite BIC there is no
# meaningful order to pick, so report that instead of silently using (0, 0).
min_bic = bic_matrix.min().min()
if not np.isfinite(min_bic):
    print('模型拟合失败: 所有候选ARIMA模型均未能拟合')
else:
    p, q = bic_matrix.stack().idxmin()
    # p and q are printed with a separator so that e.g. (1, 2) is not shown
    # as the ambiguous "12".
    print(f'BIC最小的p值和q值为:{p}、{q}')

    # Fit the final model with the selected order.
    try:
        model = ARIMA(train['on_man'], order=(p, 1, q))
        model_fit = model.fit()
        print("\n模型摘要:")
        print(model_fit.summary())

        # Forecast at most 10 steps, and never more than the test set holds.
        if len(test) > 0:
            forecast_steps = min(10, len(test))
            forecast = model_fit.forecast(steps=forecast_steps)
            forecast_index = test.index[:forecast_steps]

            # Index the predictions like the test rows so they align on
            # assignment; test rows beyond the horizon receive NaN.
            pre = pd.DataFrame({
                'predict': forecast.values
            }, index=forecast_index)
            test_with_pred = test.copy()
            test_with_pred['pre'] = pre['predict']

            # Plot actual vs. predicted values.
            plt.figure(figsize=(12, 6))
            plt.plot(test_with_pred.index, test_with_pred['on_man'], label='真实值')
            plt.plot(test_with_pred.index, test_with_pred['pre'], linestyle='--', label='预测值')
            plt.xticks(rotation=45)
            plt.legend()
            plt.title('预测结果和实际结果的对比')
            plt.xlabel('日期索引')
            plt.ylabel('上车人数')
            plt.tight_layout()
            plt.show()

            # Mean absolute percentage error; rows without a prediction are
            # NaN and are skipped by the pandas mean.
            mape = np.mean(np.abs((test_with_pred['on_man'] - test_with_pred['pre']) / test_with_pred['on_man'])) * 100
            print(f'预测平均相对误差(MAPE): {mape:.2f}%')
    except Exception as e:
        print(f"模型拟合失败: {e}")
# Listing 9-8
# Passenger flow with the holiday rows removed.
print("\n分析剔除节假日后的数据...")

# Make sure the holiday flag column is called 'holiday': when it is missing,
# the first column whose lower-cased name contains 'holiday' or '假期' is
# renamed.
if 'holiday' not in on_h.columns:
    holiday_col = next(
        (name for name in on_h.columns
         if 'holiday' in name.lower() or '假期' in name.lower()),
        None,
    )
    if holiday_col is not None:
        on_h = on_h.rename(columns={holiday_col: 'holiday'})

# Keep only the rows that are not flagged as a short holiday.
not_holiday = on_h['holiday'] != '小长假'
on_nh = on_h[not_holiday].copy()

nh_fig, nh_ax = plt.subplots(figsize=(12, 6))
nh_ax.plot(on_nh.index, on_nh['on_man'])
nh_ax.xaxis.set_major_locator(ticker.MultipleLocator(70))
plt.xticks(rotation=45)
nh_ax.set_title('剔除节假日后ST111-01站点客流量')
nh_ax.set_xlabel('日期索引')
nh_ax.set_ylabel('客流量(人)')
plt.tight_layout()
plt.show()
# Listing 9-9
# Passenger flow with holidays AND the day before each holiday removed.
print("\n分析剔除节假日及前一天后的数据...")
new_h = on_h.copy()
new_h.index = range(len(new_h))

# Flag the row immediately preceding every holiday row as a holiday too.
# Shifting the holiday mask up by one position marks exactly those rows. The
# mask is computed from the original flags, so the marking does not cascade
# further back (the same result as the former row-by-row loop).
is_holiday = new_h['holiday'] == '小长假'
day_before_holiday = is_holiday.shift(-1, fill_value=False)
new_h.loc[day_before_holiday, 'holiday'] = '小长假'

# Drop holidays together with the days that precede them.
new_nh = new_h[new_h['holiday'] != '小长假'].copy()

plt.figure(figsize=(12, 6))
# Use the 'date' column for the x axis when present, otherwise the row index.
x_values = new_nh['date'] if 'date' in new_nh.columns else new_nh.index
plt.plot(x_values, new_nh['on_man'], color='green')
plt.gca().xaxis.set_major_locator(ticker.MultipleLocator(70))
plt.xticks(rotation=45)
plt.title('剔除节假日及其前一天后ST111-01站点客流量')
plt.legend(['客流量'], loc='upper right')
plt.xlabel('日期')
plt.ylabel('客流量(人)')
plt.tight_layout()
plt.show()
# Listing 9-10
# Seasonal ARIMA (SARIMAX, period 7) on the holiday-free series.
print("\n开始季节性ARIMA分析...")
if len(new_nh) > 400:
    train1_size = 378
    train1 = pd.DataFrame(new_nh.iloc[0:train1_size]['on_man'])
    test1 = pd.DataFrame(new_nh.iloc[train1_size:]['on_man'])
else:
    # Too few rows to hold anything back: train on everything.
    train1 = pd.DataFrame(new_nh['on_man'])
    test1 = pd.DataFrame()
print(f"季节性分析 - 训练集: {len(train1)}, 测试集: {len(test1)}")

# Candidate orders: p, q in [0, 2] with d fixed at 1; the seasonal orders
# reuse the same triples with a seasonal period of 7.
p = q = range(0, 3)
d = range(1, 2)
pdq = list(itertools.product(p, d, q))
seasonal_pdq = [(x[0], x[1], x[2], 7) for x in pdq]
print('参数组合示例:')
for i in range(min(3, len(pdq))):
    print(f'SARIMAX: {pdq[i]} x {seasonal_pdq[i]}')

warnings.filterwarnings("ignore")
best_aic = np.inf
best_order = None
best_seasonal_order = None
best_results = None

# Only the first three orders and first three seasonal orders are searched,
# and never more than `max_combinations` fits in total.
max_combinations = 10
combinations_tried = 0
search_space = itertools.islice(
    itertools.product(pdq[:3], seasonal_pdq[:3]), max_combinations
)
for param, param_seasonal in search_space:
    try:
        mod = sm.tsa.statespace.SARIMAX(train1['on_man'],
                                        order=param,
                                        seasonal_order=param_seasonal,
                                        enforce_stationarity=False,
                                        enforce_invertibility=False)
        results = mod.fit(disp=False)
        current_aic = results.aic
        print(f'ARIMA{param}x{param_seasonal} - AIC:{current_aic:.2f}')
        if current_aic < best_aic:
            best_aic = current_aic
            best_order = param
            best_seasonal_order = param_seasonal
            # Keep the fitted result so the winner need not be refitted.
            best_results = results
    except Exception as e:
        # Report why the fit failed, matching the non-seasonal search.
        print(f'ARIMA{param}x{param_seasonal} - 失败: {e}')
    combinations_tried += 1

if best_order is not None:
    print(f'\n最佳模型: ARIMA{best_order}x{best_seasonal_order} - AIC: {best_aic:.2f}')
    # Reuse the fit obtained during the search for the best configuration.
    results = best_results

    # Forecast at most 10 steps, and never more than the test set holds.
    if len(test1) > 0:
        forecast_steps = min(10, len(test1))
        pre_10 = results.get_forecast(steps=forecast_steps)
        forecast_mean = pre_10.predicted_mean

        # Mean absolute relative error between forecast and actual values.
        out_pre = pd.DataFrame({
            'real': test1['on_man'].values[:forecast_steps],
            'pre': forecast_mean.values
        })
        error_seasonal = (out_pre['pre'] - out_pre['real']) / out_pre['real']
        error_mean = abs(error_seasonal).mean()
        print(f'季节性模型预测平均相对误差: {error_mean:.4f}')

    # Diagnostic plots for the selected model.
    results.plot_diagnostics(figsize=(12, 8))
    plt.tight_layout()
    plt.show()
# Listing 9-11
# Passenger-flow pattern around the 2015 Spring Festival.
print("\n分析节假日客流规律...")
# The source tables are read again here rather than reusing earlier frames.
try:
    holiday = pd.read_csv('E:/案例11 铁路旅客客流量预测/tmp/holiday.csv', encoding='utf-8')
    Train_Station = pd.read_csv('E:/案例11 铁路旅客客流量预测/tmp/Train_Station.csv', encoding='utf-8')

    # Keep the rows whose first column equals 'ST111-01' and sum 'on_man'
    # per date.
    Train_ST111_01 = Train_Station[Train_Station.iloc[:, 0] == 'ST111-01']
    on_h_new = Train_ST111_01.groupby('date')['on_man'].sum().reset_index()

    # Attach the holiday table (first column = join key, second column =
    # label); dates without a match are labelled as working days.
    on_h_new = pd.merge(on_h_new, holiday, left_on='date', right_on=holiday.columns[0], how='left')
    on_h_new['holiday'] = on_h_new[holiday.columns[1]].fillna('工作日')

    # Select the 2015 Spring Festival window. The dates are parsed once and
    # compared as timestamps, because comparing the raw strings is only
    # correct for zero-padded ISO dates.
    all_dates = pd.to_datetime(on_h_new['date'])
    in_window = all_dates.between(pd.Timestamp('2015-01-19'), pd.Timestamp('2015-03-01'))
    spring_festival_2015 = on_h_new[in_window]

    if not spring_festival_2015.empty:
        plt.figure(figsize=(12, 6))
        dates = all_dates[in_window]
        plt.plot(dates, spring_festival_2015['on_man'])

        # Highlight the holiday days in red on top of the line.
        is_holiday = spring_festival_2015['holiday'] == '小长假'
        holiday_dates = dates[is_holiday]
        holiday_values = spring_festival_2015.loc[is_holiday, 'on_man']
        plt.scatter(holiday_dates, holiday_values, color='red', s=50)

        plt.xlabel('日期')
        plt.ylabel('上车人数')
        plt.title('2015春节客流量')
        plt.legend(['客流量', '节假日'])
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()
except Exception as e:
    # Best-effort section: report the failure and let the script finish.
    print(f"节假日分析失败: {e}")
print("所有分析完成!")
Loading…
Cancel
Save