You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

75 lines
2.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import pandas as pd
import numpy as np
import joblib
import shap
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端
import matplotlib.pyplot as plt
import os
# 添加项目根目录到Python路径
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.data_generator import preprocess_data
def explain_with_shap(n_samples=100):
    """
    Explain the trained credit model with SHAP and save summary charts.

    Reads ``data/credit_data.csv`` and preprocesses it with ``preprocess_data``.
    It then loads ``models/lightgbm_model.pkl``; if that file does not exist,
    it falls back to ``models/xgboost_model.pkl``. SHAP values are computed
    with ``shap.TreeExplainer`` for the first ``n_samples`` feature rows.

    Two charts are written to the ``visualization`` directory: a bar-style
    feature-importance chart and a dot-style summary chart. The explainer is
    persisted to ``models/shap_explainer.pkl``. All paths are relative to the
    current working directory.

    Args:
        n_samples: Number of leading rows of the preprocessed features to
            explain. It is kept small to bound computation time. Defaults to
            100, the previously hard-coded value.

    Returns:
        tuple: ``(shap_values, X_sample)`` — the raw output of
        ``explainer.shap_values`` and the feature rows that were explained.

    Raises:
        FileNotFoundError: If the data file is missing, or if neither model
            file exists.
    """
    # Read the raw credit data.
    print("读取信贷数据...")
    df = pd.read_csv('data/credit_data.csv')

    # Preprocess. Only the feature matrix is needed here; the target,
    # scaler and label encoders are not used by this function.
    print("数据预处理...")
    X, *_ = preprocess_data(df)

    # Load the trained model.
    print("加载模型...")
    try:
        model = joblib.load('models/lightgbm_model.pkl')
        model_name = "LightGBM"
    except FileNotFoundError:
        # Fall back to XGBoost only when the LightGBM file is missing.
        # Other failures (corrupt pickle, version mismatch, Ctrl-C) propagate
        # instead of being silently masked by a bare ``except``.
        model = joblib.load('models/xgboost_model.pkl')
        model_name = "XGBoost"
    print(f"使用{model_name}模型")

    # Explain only a prefix of the rows to keep computation time bounded.
    X_sample = X.iloc[:n_samples]

    # Build the SHAP explainer and compute per-feature attributions.
    print("创建SHAP解释器...")
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_sample)

    # savefig does not create parent directories.
    os.makedirs('visualization', exist_ok=True)

    # Feature-importance bar chart. Bar plots accept multi-output values,
    # so the raw SHAP output is passed through unchanged.
    print("绘制SHAP特征重要性图...")
    _save_summary_plot(
        shap_values,
        X_sample,
        title=f'{model_name}模型SHAP特征重要性',
        path='visualization/shap_feature_importance.png',
        figsize=(10, 6),
        plot_type="bar",
    )

    # Dot-style summary chart. It only accepts a single output matrix.
    print("绘制SHAP摘要图...")
    _save_summary_plot(
        _single_output(shap_values),
        X_sample,
        title=f'{model_name}模型SHAP摘要图',
        path='visualization/shap_summary.png',
        figsize=(10, 8),
    )

    # Persist the explainer so it can be reused (e.g. by an API).
    joblib.dump(explainer, 'models/shap_explainer.pkl')
    print("SHAP解释完成图表已保存到 visualization 目录")
    return shap_values, X_sample


def _save_summary_plot(shap_values, X_sample, title, path, figsize, plot_type=None):
    """Render one ``shap.summary_plot`` with a title and save it to ``path``.

    The current figure is always closed, even if plotting or saving raises,
    so repeated calls do not accumulate open figures.
    """
    plt.figure(figsize=figsize)
    try:
        shap.summary_plot(shap_values, X_sample, plot_type=plot_type, show=False)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        plt.close()


def _single_output(shap_values):
    """Reduce possibly multi-output SHAP values to one 2-D attribution matrix.

    Depending on the shap version and model type, ``shap_values`` may be:
    - a list of per-class arrays, or
    - a 3-D array of shape (n_samples, n_features, n_outputs).

    In either case the last output is kept; for a binary classifier that is
    the positive class. For models with more than two classes this shows
    only the last class. Any other input is returned unchanged.
    """
    if isinstance(shap_values, list):
        return shap_values[-1]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[..., -1]
    return shap_values
if __name__ == "__main__":
    # Script entry point: run the full SHAP pipeline, which writes the charts
    # and the pickled explainer; the returned values are unpacked but unused.
    results = explain_with_shap()
    shap_values, X_sample = results
    print("SHAP解释模块集成完成!")