You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
URL/案例11 铁路旅客客流量预测/code/9.4 构建模型并预测节假日客流量.py

319 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# 代码9-6
#ARIMA模型
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import math
import itertools
import numpy as np
import statsmodels.api as sm
from statsmodels.graphics.tsaplots import plot_acf
from statsmodels.stats.diagnostic import acorr_ljungbox
from statsmodels.tsa.stattools import adfuller as ADF
from statsmodels.tsa.arima_model import ARIMA
import matplotlib.ticker as ticker
# Load the preprocessed daily series; the first 426 rows are the training
# set and everything after them is held out for testing.
on_h = pd.read_csv('../tmp/on_h.csv', index_col=0, encoding='utf-8')
split_at = 426
train = pd.DataFrame(on_h.iloc[:split_at, 0])
test = pd.DataFrame(on_h.iloc[split_at:, 0])

# Make matplotlib render Chinese labels and minus signs correctly.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Time-series plot of the training set, one major tick every 70 points.
x, y = train.index, train.on_man
ticker_spacing = 70
fig, ax = plt.subplots(1, 1)
ax.plot(x, y)
ax.xaxis.set_major_locator(ticker.MultipleLocator(ticker_spacing))
plt.setp(ax.get_xticklabels(), rotation=45)
ax.set_xlabel('日期')
ax.set_ylabel('客流量')
ax.set_title('训练集时序图')
plt.tight_layout()  # keep the rotated tick labels inside the figure
plt.show()

# Autocorrelation plot of the training set (lags up to 400).
plot_acf(train, lags=400)
plt.xlabel('日期索引')
plt.ylabel('自相关系数')
plt.title('训练集自相关图')
plt.tight_layout()
plt.show()

# Stationarity (ADF) and white-noise (Ljung-Box, lag 1) tests.
on_man_train = train['on_man']
print('原始序列的ADF检验结果为', ADF(on_man_train))
print('白噪声检验结果为:', acorr_ljungbox(on_man_train, lags=1))
# Code 9-7: pick the ARIMA(p, 1, q) order by BIC, fit it, and forecast the
# test period.
import statsmodels.api as sm
train['on_man'] = train['on_man'].astype(float)
# Order selection: bic_matrix[p][q] is the BIC of ARIMA(p, 1, q); the
# smallest entry gives the chosen (p, q).
bic_matrix = []
for p in range(11):
    tmp = []
    for q in range(5):
        # Some orders fail to estimate; record them as missing instead of
        # aborting the search. `except Exception` (not a bare except) keeps
        # Ctrl-C and SystemExit working.
        try:
            tmp.append(sm.tsa.ARIMA(train, order=(p, 1, q)).fit().bic)
        except Exception:
            tmp.append(None)
    bic_matrix.append(tmp)
# dtype=float turns failed fits into NaN even when a whole column failed, so
# the matrix stays numeric.
bic_matrix = pd.DataFrame(bic_matrix, dtype=float)
# stack() flattens the matrix (dropping NaN) and idxmin() returns the
# (row, column) = (p, q) position of the smallest BIC.
p, q = bic_matrix.stack().idxmin()
print('BIC最小的p值和q值为%s%s' % (p, q))
model = sm.tsa.ARIMA(train, order=(p, 1, q)).fit()  # ARIMA(p, 1, q) with the selected orders
summary = model.summary()  # model report
print(summary)
# Forecast exactly as many steps as the test set has rows, so the
# predictions line up with it. get_forecast() also yields the standard
# errors and confidence intervals that the message below announces.
steps = len(test)
forecast_result = model.get_forecast(steps)
print('%d天的预测结果、标准误差和置信区间分别为\n' % steps,
      forecast_result.summary_frame())
forecast = forecast_result.predicted_mean
# Re-index onto the test dates explicitly; the model's own forecast index
# need not match test.index.
pre = pd.DataFrame({'predict': forecast.values}, index=test.index)
test['pre'] = pre['predict']
plt.plot(test.index, test.on_man)
plt.plot(test.index, test.pre, linestyle=':')
plt.xticks(rotation=45)
plt.legend(['真实', '预测'])
plt.title('预测结果和实际结果的对比')
plt.xlabel('日期')
plt.ylabel('上车人数')
plt.tight_layout()  # keep the rotated tick labels inside the figure
plt.show()
# Code 9-8: time-series plot after removing the holiday ('小长假') days.
on_nh = on_h[on_h.iloc[:, 2] != '小长假']  # drop holiday rows
x = on_nh.index
y = on_nh.on_man
fig, ax = plt.subplots(1, 1)
ax.plot(x, y)
# Label every 70th point with its date, setting positions and labels as one
# matched pair. Per-point labels combined with a MultipleLocator would pair
# the i-th label with the i-th *visible* tick, i.e. wrong dates.
# Assumes the string index makes the x axis categorical (positions
# 0..n-1), as the plotted positions imply — TODO confirm if the index type
# ever changes.
ticker_spacing = 70
tick_positions = list(range(0, len(on_nh), ticker_spacing))
plt.xticks(tick_positions, list(on_nh['date.1'].iloc[::ticker_spacing]), rotation=45)
plt.title('剔除节假日后ST111-01站点客流量')
plt.xlabel('日期')
plt.ylabel('客流量(人)')
plt.tight_layout()  # keep the rotated tick labels inside the figure
plt.show()
# Code 9-9: time-series plot after removing holidays AND the day before each
# holiday.
# Work on a copy: plain assignment would alias on_h, so the index reset and
# relabelling below would silently modify on_h too.
new_h = on_h.copy()  # holds the updated holiday labels
new_h.index = range(len(new_h))
# Mark the day before every holiday as a holiday as well. Row i is only
# rewritten while processing row i + 1 (later), so each check sees the
# row's original label and the marking does not cascade backwards.
for i in range(1, len(new_h)):
    if new_h.iloc[i, 2] == '小长假':
        new_h.iloc[i - 1, 2] = '小长假'
new_nh = new_h[new_h.iloc[:, 2] != '小长假']  # drop holidays and their eves
plt.plot(new_nh['date.1'], new_nh['on_man'], color='green')
plt.gca().xaxis.set_major_locator(ticker.MultipleLocator(70))
plt.xticks(rotation=45)
plt.title('剔除节假日及其前一天后ST111-01站点客流量')
plt.legend(['on_man'], loc=2)
plt.xlabel('日期')
plt.ylabel('客流量(人)')
plt.tight_layout()  # keep the rotated tick labels inside the figure
plt.savefig('../tmp/9-10.png', dpi=1080)
plt.show()
# Code 9-10: grid-search a seasonal ARIMA (season length 7) on the
# holiday-free series by AIC, then forecast the held-out days.
new_nh.index = range(len(new_nh))
train1 = pd.DataFrame(new_nh.iloc[0:378, 0])
test1 = pd.DataFrame(new_nh.iloc[378:, 0])
p = q = range(4)
d = range(2)
pdq = list(itertools.product(p, d, q))
seasonal_pdq = [(sp, sd, sq, 7) for sp, sd, sq in itertools.product(p, d, q)]
print('Examples of parameter combinations for Seasonal ARIMA...')
print('SARIMAX: {} x {}'.format(pdq[1], seasonal_pdq[1]))
print('SARIMAX: {} x {}'.format(pdq[1], seasonal_pdq[2]))
print('SARIMAX: {} x {}'.format(pdq[2], seasonal_pdq[3]))
print('SARIMAX: {} x {}'.format(pdq[2], seasonal_pdq[4]))
warnings.filterwarnings("ignore")  # ignore warning messages from the many fits
# One (aic, order, seasonal_order) tuple per successfully fitted candidate.
sa = []
for param in pdq:
    for param_seasonal in seasonal_pdq:
        try:
            mod = sm.tsa.statespace.SARIMAX(train1,
                                            order=param,
                                            seasonal_order=param_seasonal,
                                            enforce_stationarity=False,
                                            enforce_invertibility=False)
            # disp=False: otherwise every fit prints its optimizer trace.
            results = mod.fit(disp=False)
        except Exception:
            # Some orders cannot be estimated; skip them (Ctrl-C still works).
            continue
        print('ARIMA{}x{} - AIC:{}'.format(param, param_seasonal, results.aic))
        # A NaN/inf AIC cannot be ranked, so leave it out of the selection.
        if math.isfinite(results.aic):
            sa.append((results.aic, param, param_seasonal))
if not sa:
    raise RuntimeError('No SARIMAX candidate could be fitted to train1.')
AIC_min, param, param_seasonal = min(sa)
mod = sm.tsa.statespace.SARIMAX(train1,
                                order=param,
                                seasonal_order=param_seasonal,
                                enforce_stationarity=False,
                                enforce_invertibility=False)
print('模型最终定阶为:', (param, param_seasonal))
results = mod.fit(disp=False)
print(results.summary().tables[1])
fig = plt.figure(figsize=(15, 12))
results.plot_diagnostics(figsize=(15, 12), fig=fig)
plt.show()
# Out-of-sample forecast covering exactly the held-out rows.
n_train, n_test = len(train1), len(test1)
pre_10 = results.predict(start=n_train, end=n_train + n_test - 1, dynamic=True)
out_pre = pd.DataFrame(np.zeros([n_test, 3]), columns=['real', 'pre', 'error'])
out_pre['real'] = list(test1['on_man'])
out_pre['pre'] = list(pre_10)
# Relative error of each forecast day.
error_seasonal = (out_pre.loc[:, 'pre'] - out_pre.loc[:, 'real']) / out_pre.loc[:, 'real']
out_pre['error'] = error_seasonal
# Mean absolute relative error.
error_mean = abs(error_seasonal).mean()
print('预测平均相对误差为:', error_mean)
# Code 9-11: holiday passenger-flow patterns.
# Earlier steps modified on_h, so rebuild it from the raw data here.
holiday = pd.read_csv('../tmp/holiday.csv', index_col=0, encoding='utf-8')
Train_Station = pd.read_csv('../tmp/Train_Station.csv', index_col=0, encoding='utf-8')
Train_ST111_01 = Train_Station[Train_Station.iloc[:, 0] == 'ST111-01']
on_h = pd.DataFrame(Train_ST111_01.groupby('date')['on_man'].sum())
# Attach each day's type from holiday.csv (first column: date, second
# column: day type, e.g. '小长假'); days not listed there get 0.
# One dict lookup per day replaces the len(holiday) * len(on_h) nested
# scan; if a date is listed twice, the later row wins.
holiday_type = dict(zip(holiday.iloc[:, 0], holiday.iloc[:, 1]))
# Whole-column assignment also avoids writing strings cell by cell into an
# int-initialized column, which pandas >= 2.1 deprecates.
on_h['date'] = [day if day in holiday_type else 0 for day in on_h.index]
on_h['holiday'] = [holiday_type.get(day, 0) for day in on_h.index]
def _plot_holiday_window(segments, title):
    """Plot consecutive date ranges of on_h['on_man'] on one 12x6 figure.

    segments: sequence of (start, end, is_holiday) label-slice bounds, drawn
    in order; holiday ranges are red dotted, the others solid blue.
    Shows the figure and returns (figure, axes).
    """
    figure = plt.figure(figsize=(12, 6))  # canvas
    axes = figure.add_subplot(1, 1, 1)
    for start, end, is_holiday in segments:
        window = on_h.loc[start:end, 'on_man']
        if is_holiday:
            axes.plot(window, color='red', linestyle=':')
        else:
            axes.plot(window, color='blue')
    plt.xlabel('日期')
    plt.ylabel('上车人数')
    plt.title(title)
    plt.legend(['工作日', '节假日'])
    plt.xticks(rotation=45)
    plt.show()
    return figure, axes


# 2015 Spring Festival
fig, ax = _plot_holiday_window([
    ('2015-01-19', '2015-02-18', False),
    ('2015-02-18', '2015-02-25', True),
    ('2015-02-25', '2015-03-01', False),
], '2015春节客流量')
# 2015 Labour Day
fig1, ax1 = _plot_holiday_window([
    ('2015-04-27', '2015-05-01', False),
    ('2015-05-01', '2015-05-04', True),
    ('2015-05-04', '2015-05-11', False),
], '五一客流量')
# Mid-Autumn Festival and National Day
fig2, ax2 = _plot_holiday_window([
    ('2015-09-21', '2015-09-26', False),
    ('2015-09-26', '2015-09-28', True),
    ('2015-09-28', '2015-09-30', False),
    ('2015-09-30', '2015-10-08', True),
    ('2015-10-08', '2015-10-12', False),
], '中秋和国庆客流量')
# Code 9-12: compare the 2015 and 2016 Spring Festival passenger flow.
# Both windows span 22 days and are compared row-by-row by position; they are
# presumably chosen so each is aligned relative to that year's festival.
compare = pd.DataFrame(on_h.loc['2015-02-05':'2015-02-26', 'on_man'])
compare['2016'] = list(on_h.loc['2016-01-25':'2016-02-15', 'on_man'])
compare.columns = ['2015', '2016']
compare.index = range(len(compare))
plt.plot(compare.index, compare['2015'], linestyle=':')
plt.plot(compare.index, compare['2016'])
plt.legend(['2015', '2016'])
plt.xlabel('日期')
plt.ylabel('客流量')
plt.title('2015和2016春节客流量比较')
# Show now, like every other section; otherwise this figure stays open and
# pops up together with the next section's figure.
plt.show()
# Code 9-13: fluctuation coefficients of the 2015 Spring Festival period,
# i.e. each day's flow divided by a baseline (ceil of the mean of the last 30
# non-holiday days before the period).
# Baseline for the festival holiday and the days after it (up to 2015-02-17).
M = on_h.loc['2015-01-01':'2015-02-17']
M = M[M.loc[:, 'holiday'] != '小长假']
# Baseline for the 10 days before the festival (up to 2015-02-07).
M1 = on_h.loc['2015-01-01':'2015-02-07']
M1 = M1[M1.loc[:, 'holiday'] != '小长假']
# Average only on_man: DataFrame.mean() would also hit the text columns
# 'date'/'holiday', which raises TypeError on pandas >= 2.0. The baselines
# are loop-invariant, so compute them once.
M1_base = math.ceil(M1['on_man'].iloc[-30:].mean())
M_base = math.ceil(M['on_man'].iloc[-30:].mean())
B_coef = []
# 10 days before the festival
for i in on_h.loc['2015-02-08':'2015-02-17', :].index:
    B_coef.append('%.2f' % (on_h.loc[i, 'on_man'] / M1_base))
# Festival holiday and the days after it
for i in on_h.loc['2015-02-18':'2015-02-26', :].index:
    B_coef.append('%.2f' % (on_h.loc[i, 'on_man'] / M_base))
B_coef = pd.DataFrame(B_coef)
B_coef.columns = ['on_man']
B_coef['on_man'] = B_coef['on_man'].astype(float)  # '%.2f' strings -> floats rounded to 2 dp
fig3 = plt.figure(figsize=(8, 6))  # canvas
ax3 = fig3.add_subplot(1, 1, 1)
# Give the line a label so plt.legend() below has an entry to show.
ax3.plot(B_coef.iloc[:, 0], color='blue', label=B_coef.columns[0])
plt.xlabel('Index')
plt.ylabel('系数')
plt.title('2015春节客流量波动系数')
# Annotate every point with its coefficient value.
for a, b in zip(B_coef.index, B_coef.iloc[:, 0]):
    plt.text(a, b, b, ha='center', va='bottom', fontsize=20)
plt.legend()
plt.show()
# Predict the 2016 Spring Festival flow: scale the 2015 coefficients
# (B_coef) by 2016 baselines built the same way (ceil of the mean of the
# last 30 non-holiday days before each sub-period).
MM = on_h.loc['2015-01-01':'2016-02-06', :]
MM = MM[MM.loc[:, 'holiday'] != '小长假']
# Average only on_man; DataFrame.mean() over the text columns raises on pandas >= 2.0.
MM_mean = math.ceil(MM['on_man'].iloc[-30:].mean())
MM1 = on_h.loc['2015-01-01':'2016-01-27', :]
MM1 = MM1[MM1.loc[:, 'holiday'] != '小长假']
MM1_mean = math.ceil(MM1['on_man'].iloc[-30:].mean())
pre_2016_b = B_coef.iloc[0:10, 0] * MM1_mean  # the 10 days before the festival
pre_2016_a = B_coef.iloc[10:, 0] * MM_mean    # the festival holiday and after
pre_2016 = pd.DataFrame(on_h.loc['2016-01-28':'2016-02-15', 'on_man'])
# The index holds date strings, so .loc[0:10] cannot select rows by
# position (pandas raises TypeError). Fill the column positionally instead:
# first the 10 pre-festival values, then the remaining ones.
pre_2016['pre'] = list(pre_2016_b) + list(pre_2016_a)
pre_2016.columns = ['real', 'pre']
plt.plot(pre_2016.index, pre_2016.real)
plt.plot(pre_2016.index, pre_2016.pre, linestyle=':')
plt.xticks(rotation=45)
plt.legend(['real', 'pre'])
plt.xlabel('日期')
plt.ylabel('客流量')
plt.title('预测2016年春节客流量')
# Relative error of each predicted day.
error_pre = (pre_2016.loc[:, 'pre'] - pre_2016.loc[:, 'real']) / pre_2016.loc[:, 'real']
# Mean absolute relative error.
error_pre_mean = abs(error_pre).mean()
print('预测的平均相对误差为:', error_pre_mean)
plt.show()