You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

156 lines
4.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
可解释的信贷风险评估系统
"""
import importlib
import os
import subprocess
import sys
import time
import traceback
import webbrowser
def print_system_overview():
    """Print a banner describing the system and the features it offers."""
    separator = "=" * 60
    features = (
        "1. 信贷风险预测",
        "2. 模型决策解释",
        "3. 数据可视化分析",
        "4. Web API接口",
    )
    banner = [
        separator,
        "可解释的信贷风险评估系统",
        separator,
        "本系统基于LightGBM和对抗自编码器技术提供以下功能",
        *features,
        separator,
    ]
    for line in banner:
        print(line)
def check_dependencies(required_packages=None):
    """Check that every package the pipeline needs can be imported.

    Args:
        required_packages: Optional iterable of importable module names to
            check. Defaults to the packages this system uses.

    Returns:
        True if every module imported successfully. Otherwise prints the
        missing modules plus a ``pip install`` hint and returns False.
    """
    if required_packages is None:
        required_packages = ['numpy', 'pandas', 'sklearn', 'xgboost', 'lightgbm',
                             'torch', 'flask', 'shap', 'matplotlib', 'seaborn']
    # Import names that differ from their PyPI distribution names. The
    # ``sklearn`` PyPI package is a deprecated stub that refuses to install,
    # so the hint must name ``scikit-learn`` instead.
    pip_names = {'sklearn': 'scikit-learn'}
    missing_packages = []
    for package in required_packages:
        try:
            importlib.import_module(package)
        except ImportError:
            missing_packages.append(package)
    if missing_packages:
        print(f"缺少以下依赖包: {', '.join(missing_packages)}")
        print("请运行: pip install "
              + " ".join(pip_names.get(p, p) for p in missing_packages))
        return False
    return True
def generate_sample_data(data_path="data/credit_data.csv", n_samples=10000):
    """Generate the sample credit dataset if it does not exist yet.

    Args:
        data_path: CSV file to create. Defaults to ``data/credit_data.csv``.
        n_samples: Number of rows requested from ``generate_credit_data``.
    """
    if os.path.exists(data_path):
        print("示例数据已存在")
        return
    print("生成示例信贷数据...")
    # The generator lives in the local ``data`` package, so make the current
    # working directory importable. Only add it once, so repeated calls do
    # not keep growing sys.path.
    if '.' not in sys.path:
        sys.path.append('.')
    from data.data_generator import generate_credit_data
    df = generate_credit_data(n_samples)
    # Create the parent directory, so this also works when called without
    # main() having created it first.
    parent = os.path.dirname(data_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    df.to_csv(data_path, index=False)
    print("示例数据已生成")
def train_models():
    """Train the LightGBM model and the adversarial autoencoder if missing.

    Each model is trained by running its training script with the current
    Python interpreter. Training is skipped for any model whose artifact
    file already exists. The models are handled in order, and a failing
    script stops the sequence.
    """
    jobs = (
        ("models/lightgbm_model.pkl", "models/train_lightgbm.py",
         "训练LightGBM模型...", "LightGBM模型训练完成", "LightGBM模型已存在"),
        ("models/adversarial_autoencoder.pth", "models/train_aae.py",
         "训练对抗自编码器...", "对抗自编码器训练完成", "对抗自编码器已存在"),
    )
    for artifact, script, start_msg, done_msg, skip_msg in jobs:
        if os.path.exists(artifact):
            print(skip_msg)
            continue
        print(start_msg)
        # check=True: a non-zero exit of the script raises CalledProcessError.
        subprocess.run([sys.executable, script], check=True)
        print(done_msg)
def generate_explanations():
    """Generate the SHAP model explanations unless they already exist.

    The file ``visualization/shap_summary.png`` is treated as the marker
    that explanations were generated before. Otherwise
    ``utils/shap_explainer.py`` is run with the current interpreter.
    """
    summary_plot = "visualization/shap_summary.png"
    if os.path.exists(summary_plot):
        print("模型解释已存在")
        return
    print("生成模型解释...")
    subprocess.run([sys.executable, "utils/shap_explainer.py"], check=True)
    print("模型解释生成完成")
def create_visualizations():
    """Build the visualization dashboard unless it already exists.

    The file ``visualization/dashboard.html`` is treated as the marker that
    the dashboard was already built. Otherwise
    ``visualization/create_dashboard.py`` is run with the current interpreter.
    """
    dashboard_file = "visualization/dashboard.html"
    if os.path.exists(dashboard_file):
        print("可视化图表已存在")
        return
    print("创建可视化图表...")
    subprocess.run([sys.executable, "visualization/create_dashboard.py"], check=True)
    print("可视化图表创建完成")
def start_api_server():
    """Run the Flask application ``api/app.py`` in the foreground.

    The app runs as a child Python process whose working directory is
    ``api``. Passing ``cwd`` to ``subprocess.run`` avoids ``os.chdir``, which
    would permanently change the working directory of this whole process.
    That would break every relative path used elsewhere in the script.

    Raises:
        subprocess.CalledProcessError: if the app process exits non-zero.
        OSError: presumably (e.g. FileNotFoundError) when the ``api``
            directory does not exist; raised while launching the child.
    """
    print("启动API服务器...")
    print("服务器将在 http://127.0.0.1:5000 上运行")
    print("按 Ctrl+C 停止服务器")
    # Start the Flask app with "api" as its working directory.
    subprocess.run([sys.executable, "app.py"], cwd="api", check=True)
def main():
    """Run the whole pipeline: data, models, explanations, dashboard, API.

    Each stage is delegated to its own function. The API server is started
    last and runs in the foreground.

    Returns:
        1 if dependencies are missing or any stage raises an exception.
        0 on success or when stopped with Ctrl+C, the documented way to
        stop the server.
    """
    print_system_overview()
    if not check_dependencies():
        return 1
    # Create the working directories; exist_ok avoids the check-then-create race.
    for directory in ('data', 'models', 'visualization', 'api'):
        os.makedirs(directory, exist_ok=True)
    try:
        generate_sample_data()
        train_models()
        generate_explanations()
        create_visualizations()
        start_api_server()
    except KeyboardInterrupt:
        print("\n系统已停止")
    except Exception as e:
        # Top-level boundary: report the error with its traceback instead of
        # hiding where it came from, and signal failure to the caller.
        print(f"系统运行出错: {e}")
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status, so failures are
    # visible to shells and CI. sys.exit(None) exits with status 0.
    sys.exit(main())