You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Influenza_fund_linkage_system/models/VAR_Forecasting.py

99 lines
3.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
模块功能:
本模块核心功能是实现向量自回归VAR预测模型的训练和运行。
函数:
- convert_timestamp_index(data, to_period): 转换时间序列数据的时间索引至DatetimeIndex或PeriodIndex。
- train_VAR_model(data, max_lags): 利用输入的时间序列数据训练 VAR 模型。
- run(input_data, steps): 执行时间索引转换、模型训练和数据预测等步骤,并返回预测结果。
"""
import pandas as pd
import statsmodels.api as sm
from typing import List
def convert_timestamp_index(data: pd.DataFrame,
to_period: bool) -> pd.DataFrame:
"""
根据to_period参数选择将数据的时间索引转换为DatetimeIndex或PeriodIndex。
Args:
data (pd.DataFrame): 输入的数据。
to_period (bool): 如果为True则将DatetimeIndex转换为PeriodIndex
如果为False则将PeriodIndex转换为DatetimeIndex。
Returns:
pd.DataFrame: 索引被转换后的数据。
"""
if to_period:
data.index = pd.DatetimeIndex(data.index).to_period('D')
else:
data.index = data.index.to_timestamp()
return data
def train_VAR_model(data: pd.DataFrame, max_lags: int = 30):
"""
利用输入的时间序列数据训练VAR模型通过比较BIC值确定最优滞后阶数。
Args:
data (pd.DataFrame): 用于模型训练的时间序列数据。
max_lags (int, default=30): 最大滞后阶数,默认为 30。
Returns:
VARResultsWrapper: 训练得到的VAR模型。
"""
model = sm.tsa.VAR(data)
criteria = []
lags = range(1, max_lags + 1)
# 通过比较每个滞后阶数模型的BIC值选择最优滞后阶数
for lag in lags:
result = model.fit(maxlags=lag)
criteria.append(result.bic)
# 使用最优滞后阶数再次训练模型
best_lag = lags[criteria.index(min(criteria))]
results = model.fit(maxlags=best_lag)
return results
def run(input_data: pd.DataFrame,
forecast_target: str,
_: List[str],
steps: int = 20) -> pd.DataFrame:
"""
运行函数,执行一系列步骤,包括索引转换、训练模型、数据预测。
Args:
input_data (pd.DataFrame): 输入的DataFrame数据。
forecast_target (str): 需要被预测的目标变量的列名。
_ (List[str]): 占位参数,用于保持和其他模型函数的接口一致性。
steps (int, default=20): 预测步数。
Returns:
pd.DataFrame: 预测结果的DataFrame对象。
"""
# 将DataFrame对象的时间索引转换为PeriodIndex
input_data = convert_timestamp_index(input_data, to_period=True)
# 训练 VAR 模型
model = train_VAR_model(input_data, max_lags=60)
# 将DataFrame对象的时间索引转回原样
input_data = convert_timestamp_index(input_data, to_period=False)
# 利用VAR模型进行预测
pred = model.forecast(input_data.values[-model.k_ar:], steps=steps)
forecast_df = pd.DataFrame(
pred,
index=pd.date_range(start=input_data.index.max() +
pd.Timedelta(days=1),
periods=steps),
columns=input_data.columns)
return forecast_df[forecast_target]