You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
3.1 KiB

5 months ago
import json
from typing import List
import numpy as np
# from statsmodels.tsa.api import VAR
import pandas as pd
import statsmodels.api as sm
def convert_timestamp_index(data: pd.DataFrame,
to_period: bool) -> pd.DataFrame:
"""
根据to_period参数选择将数据的时间索引转换为DatetimeIndex或PeriodIndex
Args:
data (pd.DataFrame): 输入的数据
to_period (bool): 如果为True则将DatetimeIndex转换为PeriodIndex
如果为False则将PeriodIndex转换为DatetimeIndex
Returns:
pd.DataFrame: 索引被转换后的数据
"""
if to_period:
data.index = pd.DatetimeIndex(data.index).to_period('D')
else:
data.index = data.index.to_timestamp()
return data
def train_VAR_model(data: pd.DataFrame, max_lags: int = 30):
"""
利用输入的时间序列数据训练VAR模型通过比较BIC值确定最优滞后阶数
Args:
data (pd.DataFrame): 用于模型训练的时间序列数据
max_lags (int, default=30): 最大滞后阶数默认为 30
Returns:
VARResultsWrapper: 训练得到的VAR模型
"""
model = sm.tsa.VAR(data)
criteria = []
lags = range(1, max_lags + 1)
# 通过比较每个滞后阶数模型的BIC值选择最优滞后阶数
for lag in lags:
result = model.fit(maxlags=lag)
criteria.append(result.bic)
# 使用最优滞后阶数再次训练模型
best_lag = lags[criteria.index(min(criteria))]
results = model.fit(maxlags=best_lag)
return results
def VAR_run(input_data: pd.DataFrame,
forecast_target: str,
_: List[str],
steps: int = 20) -> pd.DataFrame:
"""
运行函数执行一系列步骤包括索引转换训练模型数据预测
Args:
input_data (pd.DataFrame): 输入的DataFrame数据
forecast_target (str): 需要被预测的目标变量的列名
_ (List[str]): 占位参数用于保持和其他模型函数的接口一致性
steps (int, default=20): 预测步数
Returns:
pd.DataFrame: 预测结果的DataFrame对象
"""
input_data = input_data.replace([np.inf, -np.inf], np.nan).dropna()
# 将DataFrame对象的时间索引转换为PeriodIndex
input_data = convert_timestamp_index(input_data, to_period=True)
# 添加正则化项以确保协方差矩阵正定
input_data += np.random.normal(0, 1e-10, input_data.shape)
# 训练 VAR 模型
model = train_VAR_model(input_data, max_lags=10)
# 将DataFrame对象的时间索引转回原样
input_data = convert_timestamp_index(input_data, to_period=False)
# 利用VAR模型进行预测
pred = model.forecast(input_data.values[-model.k_ar:], steps=steps)
forecast_df = pd.DataFrame(
pred,
index=pd.date_range(start=input_data.index.max() +
pd.Timedelta(days=1),
periods=steps),
columns=input_data.columns)
return forecast_df[forecast_target]