|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
import joblib
|
|
|
import shap
|
|
|
import matplotlib
|
|
|
matplotlib.use('Agg') # 使用非交互式后端
|
|
|
import matplotlib.pyplot as plt
|
|
|
import os
|
|
|
|
|
|
# Add the project root (the parent of this file's directory) to the Python path
|
|
|
import sys
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
|
from data.data_generator import preprocess_data
|
|
|
|
|
|
def explain_with_shap():
    """Explain the trained tree model with SHAP and save the resulting plots.

    Steps performed:

    1. Read ``data/credit_data.csv`` and run it through ``preprocess_data``.
    2. Load ``models/lightgbm_model.pkl``; if that load fails for any
       ordinary reason, fall back to ``models/xgboost_model.pkl``.
    3. Build a ``shap.TreeExplainer`` and compute SHAP values for the first
       100 rows (fewer if the data set is smaller).
    4. Save a bar-style feature-importance plot and a summary plot under
       ``visualization/`` (the directory is created if missing).
    5. Dump the explainer to ``models/shap_explainer.pkl``.

    All paths are relative to the current working directory.

    Returns:
        tuple: ``(shap_values, X_sample)`` -- the object returned by
        ``explainer.shap_values`` and the rows that were explained.

    Raises:
        Exception: whatever ``joblib.load`` raises when the fallback XGBoost
            model cannot be loaded either; errors from reading the CSV,
            preprocessing, SHAP, or saving files also propagate unchanged.
    """
    # Load the raw credit data.
    print("读取信贷数据...")
    df = pd.read_csv('data/credit_data.csv')

    # Preprocess; only the feature matrix X is used below. The full
    # six-way unpack is kept so an unexpected return arity still fails loudly.
    print("数据预处理...")
    X, y, scaler, le_education, le_home, le_purpose = preprocess_data(df)

    # Load the trained model.
    print("加载模型...")
    try:
        model = joblib.load('models/lightgbm_model.pkl')
        model_name = "LightGBM"
    except Exception as exc:
        # Deliberate best-effort fallback to the XGBoost model. Catch
        # `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit are not swallowed, and report the
        # cause instead of hiding it.
        print(f"LightGBM模型加载失败({exc}),回退到XGBoost模型")
        model = joblib.load('models/xgboost_model.pkl')
        model_name = "XGBoost"

    print(f"使用{model_name}模型")

    # Explain only a small leading slice of the data to keep SHAP fast.
    X_sample = X.iloc[:100]

    # Build the SHAP explainer and compute SHAP values for the sample.
    print("创建SHAP解释器...")
    explainer = shap.TreeExplainer(model)
    # NOTE(review): depending on the shap version and model type this may be
    # a per-class list rather than a single array -- confirm the plots below
    # show what is intended.
    shap_values = explainer.shap_values(X_sample)

    # savefig does not create missing directories, so make sure the output
    # directory exists before writing the figures.
    os.makedirs('visualization', exist_ok=True)

    # Bar plot of SHAP feature importance.
    print("绘制SHAP特征重要性图...")
    plt.figure(figsize=(10, 6))
    shap.summary_plot(shap_values, X_sample, plot_type="bar", show=False)
    plt.title(f'{model_name}模型SHAP特征重要性')
    plt.tight_layout()
    plt.savefig('visualization/shap_feature_importance.png', dpi=300, bbox_inches='tight')
    plt.close()

    # SHAP summary plot.
    print("绘制SHAP摘要图...")
    plt.figure(figsize=(10, 8))
    shap.summary_plot(shap_values, X_sample, show=False)
    plt.title(f'{model_name}模型SHAP摘要图')
    plt.tight_layout()
    plt.savefig('visualization/shap_summary.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Persist the explainer so it can be reused elsewhere (e.g. by the API).
    joblib.dump(explainer, 'models/shap_explainer.pkl')

    print("SHAP解释完成,图表已保存到 visualization 目录")

    return shap_values, X_sample
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the full SHAP explanation pipeline and keep
    # its two results bound at module level, then report completion.
    (shap_values, X_sample) = explain_with_shap()
    print("SHAP解释模块集成完成!")