You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

76 lines
3.1 KiB

5 months ago
from typing import List, Union
import numpy as np
import pandas as pd
import pmdarima as pm
def train_ARIMA_model(endog: Union[np.ndarray, pd.Series],
exog: Union[np.ndarray, pd.DataFrame] = None,
exog_pred: Union[np.ndarray, pd.DataFrame] = None,
steps: int = 20,
information_criterion: str = 'aic') -> np.ndarray:
"""
使用ARIMA模型对时间序列数据进行预测
Args:
endog (Union[np.ndarray, pd.Series]): 要分析的时间序列数据
exog (Union[np.ndarray, pd.DataFrame], optional): 用于改进ARIMA模型的外生变量默认为None
exog_pred (Union[np.ndarray, pd.DataFrame], optional): 预测期间的外生变量必须与训练期间的外生变量列数一致默认为None
steps (int, optional, default=20): 预测期的长度
information_criterion (str, optional, default='aic'): 选择模型的信息准则'aic''bic'
Returns:
np.ndarray: 预测结果
"""
model = pm.auto_arima(endog,
X=exog,
seasonal=False,
information_criterion=information_criterion)
pred = model.predict(n_periods=steps, X=exog_pred)
return pred
def ARIMA_run(input_data: pd.DataFrame,
forecast_target: str,
exog_columns: List[str],
steps: int = 20) -> pd.DataFrame:
"""
主运行函数用以读取数据训练模型预测数据
Args:
input_data (pd.DataFrame): 输入的时间序列数据
forecast_target (str): 需要被预测的目标变量的列名
exog_columns (List[str]): 外生变量的列名列表
steps (int, optional, default=20): 预测步长
Returns:
pd.DataFrame: 预测结果的DataFrame对象
"""
# 创建一个未来日期的索引,用于保存预测数据
future_index = pd.date_range(start=input_data.index.max() +
pd.Timedelta(days=1),
periods=steps)
# 创建一个用于保存预测外生变量的空数据帧
df_exog = pd.DataFrame(index=future_index)
# 循环每个外生变量使用ARIMA模型进行训练和预测然后将预测值保存到df_exog中
for exog in exog_columns:
pred = train_ARIMA_model(endog=input_data[exog], steps=steps)
df_exog[exog] = pred
# 使用ARIMA模型对目标变量进行训练和预测注意这里将df_exog作为预测阶段的外生变量传入
pred = train_ARIMA_model(endog=input_data[forecast_target],
exog=input_data[exog_columns],
exog_pred=df_exog[exog_columns],
steps=steps,
information_criterion='bic')
# 根据预测值创建一个新的数据帧,用于保存预测的目标变量
forecast_df = pd.DataFrame(pred,
index=future_index,
columns=[forecast_target])
return forecast_df