You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
2.8 KiB
96 lines
2.8 KiB
import pandas as pd
|
|
from streamlit_echarts import st_echarts
|
|
|
|
|
|
def draw_echarts(model_name: str, target: str, target_name: str,
|
|
history_data: pd.DataFrame, pred_data: pd.DataFrame):
|
|
"""
|
|
构造 ECharts 图表的配置并在 Streamlit 应用中展示。
|
|
|
|
Args:
|
|
model_name (str): 模型的名称
|
|
target (str): 目标值的列名
|
|
target_name (str): 目标值的显示名称
|
|
historical_data (pd.DataFrame): 历史数据
|
|
predicted_data (pd.DataFrame): 预测数据
|
|
|
|
Returns:
|
|
dict: ECharts 图表的配置
|
|
"""
|
|
# 数据处理,将历史数据和预测数据添加 None 值以适应图表的 x 轴
|
|
history_values = history_data[target].values.tolist() + [
|
|
None for _ in range(len(pred_data))
|
|
]
|
|
pred_values = [None for _ in range(len(history_data))
|
|
] + pred_data[target].values.tolist()
|
|
|
|
# 定义ECharts的配置
|
|
option = {
|
|
"title": {
|
|
"text": f"{model_name}模型",
|
|
"x": "auto"
|
|
},
|
|
# 配置提示框组件
|
|
"tooltip": {
|
|
"trigger": "axis"
|
|
},
|
|
# 配置图例组件
|
|
"legend": {
|
|
"data": [f"{target_name}历史数据", f"{target_name}预测数据"],
|
|
"left": "right"
|
|
},
|
|
# 配置x轴和y轴
|
|
"xAxis": {
|
|
"type":
|
|
"category",
|
|
"data":
|
|
history_data.index.astype(str).to_list() +
|
|
pred_data.index.astype(str).to_list()
|
|
},
|
|
"yAxis": {
|
|
"type": "value"
|
|
},
|
|
# 配置数据区域缩放组件
|
|
"dataZoom": [{
|
|
"type": "inside",
|
|
"start": 0,
|
|
"end": 100
|
|
}],
|
|
"series": []
|
|
}
|
|
|
|
# 添加历史数据系列
|
|
if any(history_values):
|
|
option["series"].append({
|
|
"name": f"{target_name}历史数据",
|
|
"type": "line",
|
|
"data": history_values,
|
|
"smooth": "true"
|
|
})
|
|
|
|
# 添加预测数据的系列
|
|
if any(pred_values):
|
|
option["series"].append({
|
|
"name": f"{target_name}预测数据",
|
|
"type": "line",
|
|
"data": pred_values,
|
|
"smooth": "true",
|
|
"lineStyle": {
|
|
"type": "dashed"
|
|
}
|
|
})
|
|
|
|
# 在Streamlit应用中展示ECharts图表
|
|
st_echarts(options=option)
|
|
return option
|
|
|
|
|
|
# history_data = pd.read_csv('data/normalized_df.csv',
|
|
# index_col="date",
|
|
# parse_dates=["date"])
|
|
# pred_data = pd.read_csv('data/VAR_Forecasting_df.csv',
|
|
# index_col="date",
|
|
# parse_dates=["date"])
|
|
# draw_echarts('VAR_Forecasting', 'liugan_index', '流感指数', history_data,
|
|
# pred_data)
|