You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

73 lines
2.7 KiB

"""
models包功能
该包主要实现了4种主要的时间序列预测模型向量自回归(VAR)自回归移动平均(ARIMA)
季节性自回归移动平均 (SARIMA) 以及随机森林Random Forest
modules:
- utils.py: 包含用于数据读取和数据保存的通用函数
- VAR_model.py: 实现向量自回归 (VAR)预测模型
- ARIMA_model.py: 实现自回归移动平均 (ARIMA)预测模型
- SARIMA_model.py: 实现季节性自回归移动平均 (SARIMA)预测模型
- RF_model.py: 实现随机森林Random Forest预测模型
函数
- run(): 是整个预测过程的主函数该函数首先读取数据然后调用所有模型进行预测
最后把预测结果保存到文件中
使用示例:
from models import run
run_result = run(forecast_target=target_column, exog_columns=exog_columns_list, models=[VAR_Forecasting])
"""
from typing import List, Type
from . import utils, VAR_Forecasting, ARIMA_Forecasting, SARIMA_Forecasting, RF_Forecasting
__all__ = [
'utils', 'VAR_Forecasting', 'ARIMA_Forecasting', 'SARIMA_Forecasting',
'RF_Forecasting'
]
def run(
forecast_target: str,
exog_columns: List[str],
steps: int = 20,
file_path: str = 'data/normalized_df.csv',
models: List[Type] = [
VAR_Forecasting, ARIMA_Forecasting, SARIMA_Forecasting, RF_Forecasting
],
) -> None:
"""
执行数据读取预处理模型训练预测并保存预测结果等一系列步骤的主函数
Args:
forecast_target (str): 需要被预测的目标变量的列名
exog_columns (List[str]): 用于预测的特征变量的列名列表
steps (int, default=20): 需要进行预测的步数
file_path (str, default='data/normalized_df.csv'): 数据文件的路径
models (List[Type]) : 需要运行的预测模型列表默认包括VARARIMASARIMA和Random Forest模型
Returns:
None
"""
# 载入数据
input_df = utils.read_csv(file_path)
# 使用每个模型进行预测并保存结果
for model in models:
try:
model_name = model.__name__.split('.')[-1]
print(f"正在执行 {model_name} 模型进行预测...")
# 调用模型进行预测
model_df = model.run(input_df, forecast_target, exog_columns,
steps)
# 保存预测结果
utils.save_csv(model_df, f'data/{model_name}_df.csv')
print(f"{model_name} 模型的预测结果已保存至 data/{model_name}_df.csv")
except Exception as e:
print(f"{model_name} 模型预测过程出现错误: {e}")
print("所有模型预测都已完成。")