You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
3.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import json
from typing import List
import numpy as np
# from statsmodels.tsa.api import VAR
import pandas as pd
import statsmodels.api as sm
def convert_timestamp_index(data: pd.DataFrame,
to_period: bool) -> pd.DataFrame:
"""
根据to_period参数选择将数据的时间索引转换为DatetimeIndex或PeriodIndex。
Args:
data (pd.DataFrame): 输入的数据。
to_period (bool): 如果为True则将DatetimeIndex转换为PeriodIndex
如果为False则将PeriodIndex转换为DatetimeIndex。
Returns:
pd.DataFrame: 索引被转换后的数据。
"""
if to_period:
data.index = pd.DatetimeIndex(data.index).to_period('D')
else:
data.index = data.index.to_timestamp()
return data
def train_VAR_model(data: pd.DataFrame, max_lags: int = 30):
"""
利用输入的时间序列数据训练VAR模型通过比较BIC值确定最优滞后阶数。
Args:
data (pd.DataFrame): 用于模型训练的时间序列数据。
max_lags (int, default=30): 最大滞后阶数,默认为 30。
Returns:
VARResultsWrapper: 训练得到的VAR模型。
"""
model = sm.tsa.VAR(data)
criteria = []
lags = range(1, max_lags + 1)
# 通过比较每个滞后阶数模型的BIC值选择最优滞后阶数
for lag in lags:
result = model.fit(maxlags=lag)
criteria.append(result.bic)
# 使用最优滞后阶数再次训练模型
best_lag = lags[criteria.index(min(criteria))]
results = model.fit(maxlags=best_lag)
return results
def VAR_run(input_data: pd.DataFrame,
forecast_target: str,
_: List[str],
steps: int = 20) -> pd.DataFrame:
"""
运行函数,执行一系列步骤,包括索引转换、训练模型、数据预测。
Args:
input_data (pd.DataFrame): 输入的DataFrame数据。
forecast_target (str): 需要被预测的目标变量的列名。
_ (List[str]): 占位参数,用于保持和其他模型函数的接口一致性。
steps (int, default=20): 预测步数。
Returns:
pd.DataFrame: 预测结果的DataFrame对象。
"""
input_data = input_data.replace([np.inf, -np.inf], np.nan).dropna()
# 将DataFrame对象的时间索引转换为PeriodIndex
input_data = convert_timestamp_index(input_data, to_period=True)
# 添加正则化项以确保协方差矩阵正定
input_data += np.random.normal(0, 1e-10, input_data.shape)
# 训练 VAR 模型
model = train_VAR_model(input_data, max_lags=10)
# 将DataFrame对象的时间索引转回原样
input_data = convert_timestamp_index(input_data, to_period=False)
# 利用VAR模型进行预测
pred = model.forecast(input_data.values[-model.k_ar:], steps=steps)
forecast_df = pd.DataFrame(
pred,
index=pd.date_range(start=input_data.index.max() +
pd.Timedelta(days=1),
periods=steps),
columns=input_data.columns)
return forecast_df[forecast_target]