|
|
"""
|
|
|
模块功能:对 PostgreSQL 数据库进行数据提取,数据处理,数据保存。
|
|
|
|
|
|
函数:
|
|
|
connect_to_pg(db_parameters, tables): 连接到 PostgreSQL 数据库并获取指定数据表。
|
|
|
merge_dfs(dfs): 合并 DataFrame 对象。
|
|
|
normalize_df(df): 对 DataFrame 对象进行最小最大标准化。
|
|
|
save_as_csv(df, file_path): 将 DataFrame 保存为 CSV 文件。
|
|
|
|
|
|
使用示例:
|
|
|
dfs = connect_to_pg(db_parameters, tables)
|
|
|
df = merge_dfs(dfs)
|
|
|
df = normalize_df(df)
|
|
|
save_as_csv(df, file_path)
|
|
|
|
|
|
注意:
|
|
|
本模块依赖于 psycopg2 库,请确保已正确安装该库。
|
|
|
"""
|
|
|
|
|
|
import pandas as pd
|
|
|
import psycopg2
|
|
|
|
|
|
from typing import Dict, List
|
|
|
|
|
|
|
|
|
def connect_to_pg(db_parameters: Dict[str, str],
|
|
|
tables: List[str]) -> Dict[str, pd.DataFrame]:
|
|
|
"""
|
|
|
连接到 PostgreSQL 数据库并获取指定数据表。
|
|
|
|
|
|
Args:
|
|
|
db_parameters (Dict[str, str]): 数据库连接参数的字典
|
|
|
tables (List[str]): 需要查询的数据表名称列表
|
|
|
|
|
|
Returns:
|
|
|
dfs (Dict[str, DataFrame]):
|
|
|
键是数据表名(去除 'app01_' 前缀),值是对应数据表的 DataFrame 对象。
|
|
|
"""
|
|
|
try:
|
|
|
# 用 with 语句确保数据库连接被正确关闭
|
|
|
with psycopg2.connect(**db_parameters) as conn:
|
|
|
dfs = {
|
|
|
table[6:]: pd.read_sql(f'select * from public.{table};', conn)
|
|
|
for table in tables
|
|
|
}
|
|
|
print('{:*^30}'.format('成功链接 PostgreSQL'))
|
|
|
except Exception as error:
|
|
|
print(f"发现错误:{error}")
|
|
|
exit()
|
|
|
|
|
|
return dfs
|
|
|
|
|
|
|
|
|
def merge_dfs(dfs: Dict[str, pd.DataFrame]) -> pd.DataFrame:
|
|
|
"""
|
|
|
合并 DataFrame 对象。
|
|
|
|
|
|
Args:
|
|
|
dfs (dict): 键为数据表名,值为对应数据表的 DataFrame 对象。
|
|
|
|
|
|
Returns:
|
|
|
df (DataFrame): 合并后的 DataFrame 对象。
|
|
|
"""
|
|
|
# 通过合并所有数据表的日期范围来创建一个新的日期索引
|
|
|
date_range = pd.date_range(start=min(df['date'].min()
|
|
|
for df in dfs.values()),
|
|
|
end=max(df['date'].max()
|
|
|
for df in dfs.values()))
|
|
|
|
|
|
# 创建一个以日期范围为索引的空 DataFrame
|
|
|
df_merged = pd.DataFrame(index=date_range)
|
|
|
# print(df_merged)
|
|
|
|
|
|
# 遍历并合并每个数据表,保留日期索引并丢弃 'id' 列
|
|
|
for df in dfs.values():
|
|
|
df = df.set_index('date').reindex(date_range)
|
|
|
df = df.drop(columns='id') # 删除 'id' 列
|
|
|
df_merged = pd.concat([df_merged, df], axis=1)
|
|
|
|
|
|
# 对缺失值进行线性插值(其他方法?多项插值?)
|
|
|
df_merged = df_merged.interpolate()
|
|
|
|
|
|
# 如果有剩余的NaN值,删除这些行
|
|
|
df_merged.dropna(inplace=True)
|
|
|
|
|
|
return df_merged
|
|
|
|
|
|
|
|
|
def normalize_df(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
"""
|
|
|
对 DataFrame 对象进行最小最大标准化。
|
|
|
|
|
|
Args:
|
|
|
df (DataFrame): 要进行标准化的 DataFrame 对象。
|
|
|
|
|
|
Returns:
|
|
|
df_normalized (DataFrame): 进行最小最大标准化后的 DataFrame 对象。
|
|
|
"""
|
|
|
# 如果列的数据类型是布尔值、有符号整型、无符号整型、浮点数或复数浮点数的话,就进行最大最小标准化,否则保留原列的数据
|
|
|
df_normalized = df.apply(lambda x: (x - x.min()) / (x.max() - x.min())
|
|
|
if x.dtype.kind in 'biufc' else x)
|
|
|
|
|
|
return df_normalized
|
|
|
|
|
|
|
|
|
def save_as_csv(df: pd.DataFrame,
|
|
|
file_path: str = 'data/normalized_df.csv') -> None:
|
|
|
"""
|
|
|
将 DataFrame 保存为 CSV 文件。
|
|
|
|
|
|
Args:
|
|
|
df (DataFrame): 要保存的 DataFrame 对象。
|
|
|
file_path (str): 保存文件的路径。
|
|
|
|
|
|
"""
|
|
|
try:
|
|
|
df.to_csv(file_path, index_label='date')
|
|
|
print(f"成功保存为 {file_path}")
|
|
|
except Exception as e:
|
|
|
print(f"保存文件时出错: {e}")
|