You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

121 lines
3.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
模块功能:对 PostgreSQL 数据库进行数据提取,数据处理,数据保存。
函数:
connect_to_pg(db_parameters, tables): 连接到 PostgreSQL 数据库并获取指定数据表。
merge_dfs(dfs): 合并 DataFrame 对象。
normalize_df(df): 对 DataFrame 对象进行最小最大标准化。
save_as_csv(df, file_path): 将 DataFrame 保存为 CSV 文件。
使用示例:
dfs = connect_to_pg(db_parameters, tables)
df = merge_dfs(dfs)
df = normalize_df(df)
save_as_csv(df, file_path)
注意:
本模块依赖于 psycopg2 库,请确保已正确安装该库。
"""
import pandas as pd
import psycopg2
from typing import Dict, List
def connect_to_pg(db_parameters: Dict[str, str],
tables: List[str]) -> Dict[str, pd.DataFrame]:
"""
连接到 PostgreSQL 数据库并获取指定数据表。
Args:
db_parameters (Dict[str, str]): 数据库连接参数的字典
tables (List[str]): 需要查询的数据表名称列表
Returns:
dfs (Dict[str, DataFrame]):
键是数据表名(去除 'app01_' 前缀),值是对应数据表的 DataFrame 对象。
"""
try:
# 用 with 语句确保数据库连接被正确关闭
with psycopg2.connect(**db_parameters) as conn:
dfs = {
table[6:]: pd.read_sql(f'select * from public.{table};', conn)
for table in tables
}
print('{:*^30}'.format('成功链接 PostgreSQL'))
except Exception as error:
print(f"发现错误:{error}")
exit()
return dfs
def merge_dfs(dfs: Dict[str, pd.DataFrame]) -> pd.DataFrame:
"""
合并 DataFrame 对象。
Args:
dfs (dict): 键为数据表名,值为对应数据表的 DataFrame 对象。
Returns:
df (DataFrame): 合并后的 DataFrame 对象。
"""
# 通过合并所有数据表的日期范围来创建一个新的日期索引
date_range = pd.date_range(start=min(df['date'].min()
for df in dfs.values()),
end=max(df['date'].max()
for df in dfs.values()))
# 创建一个以日期范围为索引的空 DataFrame
df_merged = pd.DataFrame(index=date_range)
# print(df_merged)
# 遍历并合并每个数据表,保留日期索引并丢弃 'id' 列
for df in dfs.values():
df = df.set_index('date').reindex(date_range)
df = df.drop(columns='id') # 删除 'id' 列
df_merged = pd.concat([df_merged, df], axis=1)
# 对缺失值进行线性插值(其他方法?多项插值?)
df_merged = df_merged.interpolate()
# 如果有剩余的NaN值删除这些行
df_merged.dropna(inplace=True)
return df_merged
def normalize_df(df: pd.DataFrame) -> pd.DataFrame:
"""
对 DataFrame 对象进行最小最大标准化。
Args:
df (DataFrame): 要进行标准化的 DataFrame 对象。
Returns:
df_normalized (DataFrame): 进行最小最大标准化后的 DataFrame 对象。
"""
# 如果列的数据类型是布尔值、有符号整型、无符号整型、浮点数或复数浮点数的话,就进行最大最小标准化,否则保留原列的数据
df_normalized = df.apply(lambda x: (x - x.min()) / (x.max() - x.min())
if x.dtype.kind in 'biufc' else x)
return df_normalized
def save_as_csv(df: pd.DataFrame,
file_path: str = 'data/normalized_df.csv') -> None:
"""
将 DataFrame 保存为 CSV 文件。
Args:
df (DataFrame): 要保存的 DataFrame 对象。
file_path (str): 保存文件的路径。
"""
try:
df.to_csv(file_path, index_label='date')
print(f"成功保存为 {file_path}")
except Exception as e:
print(f"保存文件时出错: {e}")