You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

73 lines
2.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
models包功能
该包主要实现了4种主要的时间序列预测模型向量自回归(VAR)、自回归移动平均(ARIMA)、
季节性自回归移动平均 (SARIMA) 以及随机森林Random Forest
modules:
- utils.py: 包含用于数据读取和数据保存的通用函数。
- VAR_model.py: 实现向量自回归 (VAR)预测模型。
- ARIMA_model.py: 实现自回归移动平均 (ARIMA)预测模型。
- SARIMA_model.py: 实现季节性自回归移动平均 (SARIMA)预测模型。
- RF_model.py: 实现随机森林Random Forest预测模型。
函数:
- run(): 是整个预测过程的主函数。该函数首先读取数据,然后调用所有模型进行预测,
最后把预测结果保存到文件中。
使用示例:
from models import run
run_result = run(forecast_target=target_column, exog_columns=exog_columns_list, models=[VAR_Forecasting])
"""
from typing import List, Type
from . import utils, VAR_Forecasting, ARIMA_Forecasting, SARIMA_Forecasting, RF_Forecasting
__all__ = [
'utils', 'VAR_Forecasting', 'ARIMA_Forecasting', 'SARIMA_Forecasting',
'RF_Forecasting'
]
def run(
forecast_target: str,
exog_columns: List[str],
steps: int = 20,
file_path: str = 'data/normalized_df.csv',
models: List[Type] = [
VAR_Forecasting, ARIMA_Forecasting, SARIMA_Forecasting, RF_Forecasting
],
) -> None:
"""
执行数据读取、预处理、模型训练、预测并保存预测结果等一系列步骤的主函数。
Args:
forecast_target (str): 需要被预测的目标变量的列名。
exog_columns (List[str]): 用于预测的特征变量的列名列表。
steps (int, default=20): 需要进行预测的步数。
file_path (str, default='data/normalized_df.csv'): 数据文件的路径。
models (List[Type]) : 需要运行的预测模型列表默认包括VAR、ARIMA、SARIMA和Random Forest模型。
Returns:
None
"""
# 载入数据
input_df = utils.read_csv(file_path)
# 使用每个模型进行预测并保存结果
for model in models:
try:
model_name = model.__name__.split('.')[-1]
print(f"正在执行 {model_name} 模型进行预测...")
# 调用模型进行预测
model_df = model.run(input_df, forecast_target, exog_columns,
steps)
# 保存预测结果
utils.save_csv(model_df, f'data/{model_name}_df.csv')
print(f"{model_name} 模型的预测结果已保存至 data/{model_name}_df.csv")
except Exception as e:
print(f"{model_name} 模型预测过程出现错误: {e}")
print("所有模型预测都已完成。")