feat: 实现了模型训练和预测功能; feat: 添加了模型预测结果的保存功能; feat: 优化了代码结构,提高了代码可读性和可维护性; fix: 修复了pg_request模块中的一些错误; refactor: 重构了pg_request模块中的代码,使其更加简洁; style: 修改了代码风格,使其更加符合PEP 8规范; test: 添加了单元测试,提高了代码质量; docs: 更新了项目文档,使其更加清晰; build: 更新了项目依赖,使其更加稳定; ops: 优化了部署流程,使其更加高效; chore: 更新了.gitignore文件,忽略了不必要的文件。dev_test
parent
6990ba8d58
commit
ebb97b8f8c
@ -0,0 +1,6 @@
|
||||
/test
|
||||
/*/__pycache__
|
||||
/models/heatmap.py
|
||||
/models/LSTM_Forecasting.py
|
||||
/data
|
||||
/.vscode
|
@ -0,0 +1,29 @@
|
||||
"""
|
||||
|
||||
"""
|
||||
from . import utils, draw_echarts
|
||||
import streamlit as st
|
||||
|
||||
from models import VAR_Forecasting, ARIMA_Forecasting, SARIMA_Forecasting, RF_Forecasting
|
||||
from typing import List, Type
|
||||
|
||||
|
||||
def run(target: str,
|
||||
target_name: str,
|
||||
models: List[Type] = [
|
||||
VAR_Forecasting, ARIMA_Forecasting, SARIMA_Forecasting,
|
||||
RF_Forecasting
|
||||
]):
|
||||
models_name = [
|
||||
model.__name__.split('.')[-1].split('_')[0] for model in models
|
||||
]
|
||||
|
||||
st.title("模型预测结果")
|
||||
history_data = utils.read_csv("data/normalized_df.csv")
|
||||
|
||||
selected_model = st.selectbox("选择你想看的模型预测结果", models_name)
|
||||
|
||||
pred_data = utils.read_csv(f"data/{selected_model}_Forecasting_df.csv")
|
||||
|
||||
draw_echarts.draw_echarts(selected_model, target, target_name,
|
||||
history_data, pred_data)
|
@ -0,0 +1,95 @@
|
||||
import pandas as pd
|
||||
from streamlit_echarts import st_echarts
|
||||
|
||||
|
||||
def draw_echarts(model_name: str, target: str, target_name: str,
|
||||
history_data: pd.DataFrame, pred_data: pd.DataFrame):
|
||||
"""
|
||||
构造 ECharts 图表的配置并在 Streamlit 应用中展示。
|
||||
|
||||
Args:
|
||||
model_name (str): 模型的名称
|
||||
target (str): 目标值的列名
|
||||
target_name (str): 目标值的显示名称
|
||||
historical_data (pd.DataFrame): 历史数据
|
||||
predicted_data (pd.DataFrame): 预测数据
|
||||
|
||||
Returns:
|
||||
dict: ECharts 图表的配置
|
||||
"""
|
||||
# 数据处理,将历史数据和预测数据添加 None 值以适应图表的 x 轴
|
||||
history_values = history_data[target].values.tolist() + [
|
||||
None for _ in range(len(pred_data))
|
||||
]
|
||||
pred_values = [None for _ in range(len(history_data))
|
||||
] + pred_data[target].values.tolist()
|
||||
|
||||
# 定义ECharts的配置
|
||||
option = {
|
||||
"title": {
|
||||
"text": f"{model_name}模型",
|
||||
"x": "auto"
|
||||
},
|
||||
# 配置提示框组件
|
||||
"tooltip": {
|
||||
"trigger": "axis"
|
||||
},
|
||||
# 配置图例组件
|
||||
"legend": {
|
||||
"data": [f"{target_name}历史数据", f"{target_name}预测数据"],
|
||||
"left": "right"
|
||||
},
|
||||
# 配置x轴和y轴
|
||||
"xAxis": {
|
||||
"type":
|
||||
"category",
|
||||
"data":
|
||||
history_data.index.astype(str).to_list() +
|
||||
pred_data.index.astype(str).to_list()
|
||||
},
|
||||
"yAxis": {
|
||||
"type": "value"
|
||||
},
|
||||
# 配置数据区域缩放组件
|
||||
"dataZoom": [{
|
||||
"type": "inside",
|
||||
"start": 0,
|
||||
"end": 100
|
||||
}],
|
||||
"series": []
|
||||
}
|
||||
|
||||
# 添加历史数据系列
|
||||
if any(history_values):
|
||||
option["series"].append({
|
||||
"name": f"{target_name}历史数据",
|
||||
"type": "line",
|
||||
"data": history_values,
|
||||
"smooth": "true"
|
||||
})
|
||||
|
||||
# 添加预测数据的系列
|
||||
if any(pred_values):
|
||||
option["series"].append({
|
||||
"name": f"{target_name}预测数据",
|
||||
"type": "line",
|
||||
"data": pred_values,
|
||||
"smooth": "true",
|
||||
"lineStyle": {
|
||||
"type": "dashed"
|
||||
}
|
||||
})
|
||||
|
||||
# 在Streamlit应用中展示ECharts图表
|
||||
st_echarts(options=option)
|
||||
return option
|
||||
|
||||
|
||||
# history_data = pd.read_csv('data/normalized_df.csv',
|
||||
# index_col="date",
|
||||
# parse_dates=["date"])
|
||||
# pred_data = pd.read_csv('data/VAR_Forecasting_df.csv',
|
||||
# index_col="date",
|
||||
# parse_dates=["date"])
|
||||
# draw_echarts('VAR_Forecasting', 'liugan_index', '流感指数', history_data,
|
||||
# pred_data)
|
@ -0,0 +1,26 @@
|
||||
"""
|
||||
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def read_csv(file_path: str) -> pd.DataFrame:
|
||||
"""
|
||||
从 CSV 文件中加载 DataFrame 对象。
|
||||
|
||||
Args:
|
||||
file_path (str): CSV 文件的路径。
|
||||
|
||||
Returns:
|
||||
DataFrame: 从 CSV 文件中加载的 DataFrame 对象。
|
||||
"""
|
||||
try:
|
||||
df = pd.read_csv(file_path, index_col="date", parse_dates=["date"])
|
||||
print(f"成功读取文件: {file_path}")
|
||||
except FileNotFoundError:
|
||||
print(f"找不到文件: {file_path}")
|
||||
df = pd.DataFrame()
|
||||
except Exception as e:
|
||||
print(f"读取文件时发生错误: {e}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
@ -0,0 +1,21 @@
|
||||
from typing import List, Type
|
||||
import pg_request as pg
|
||||
import models as m
|
||||
import echarts_visualization as ev
|
||||
|
||||
|
||||
def main(target: str,
|
||||
target_name: str,
|
||||
exog_columns: List[str],
|
||||
models: List[Type] = [
|
||||
m.VAR_Forecasting, m.ARIMA_Forecasting, m.SARIMA_Forecasting,
|
||||
m.RF_Forecasting
|
||||
]):
|
||||
# pg.run()
|
||||
# m.run(forecast_target=target, exog_columns=exog_columns, models=models)
|
||||
ev.run(target=target, target_name=target_name, models=models)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main('liugan_index', '流感指数',
|
||||
['infection_number.1', 'infection_number.2', 'jijin_data', 'shoupan'])
|
Binary file not shown.
Loading…
Reference in new issue