import pandas as pd
import numpy as np
import joblib
import shap
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend; figures are only written to files
import matplotlib.pyplot as plt
import os

# Add the project root directory to the Python path
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.data_generator import preprocess_data


def _load_model():
    """Load the trained model, preferring LightGBM and falling back to XGBoost.

    Returns:
        A ``(model, model_name)`` tuple where ``model_name`` is either
        ``"LightGBM"`` or ``"XGBoost"``.

    Raises:
        Whatever ``joblib.load`` raises if the XGBoost fallback cannot be
        loaded either.
    """
    try:
        return joblib.load('models/lightgbm_model.pkl'), "LightGBM"
    except Exception as exc:
        # Best-effort fallback: any ordinary load failure (missing file,
        # unpicklable content, ...) switches to the XGBoost model.
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        print(f"LightGBM模型加载失败({exc}),回退到XGBoost模型")
        return joblib.load('models/xgboost_model.pkl'), "XGBoost"


def _save_summary_plot(shap_values, X_sample, title, path, figsize, **plot_kwargs):
    """Render one ``shap.summary_plot`` and save it to ``path`` as a PNG."""
    plt.figure(figsize=figsize)
    # show=False keeps the figure open so the title can be added and the
    # figure saved afterwards.
    shap.summary_plot(shap_values, X_sample, show=False, **plot_kwargs)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()


def explain_with_shap(sample_size=100):
    """Explain the trained tree model with SHAP and save the resulting plots.

    Reads ``data/credit_data.csv``, preprocesses it with ``preprocess_data``,
    loads the trained model (LightGBM, or XGBoost as a fallback), computes
    SHAP values for the first ``sample_size`` rows, writes a bar-style
    feature-importance plot and a summary plot into ``visualization/``, and
    dumps the explainer to ``models/shap_explainer.pkl``.

    All paths are relative to the current working directory.

    Args:
        sample_size: Number of leading rows of the preprocessed features to
            explain. Kept small to bound computation time. Defaults to 100.

    Returns:
        A ``(shap_values, X_sample)`` tuple: the output of
        ``explainer.shap_values`` and the feature rows it was computed on.
    """
    # Load the data
    print("读取信贷数据...")
    df = pd.read_csv('data/credit_data.csv')

    # Preprocess; only the feature matrix is used here, the remaining
    # return values (target, scaler, label encoders) are not needed.
    print("数据预处理...")
    X, y, scaler, le_education, le_home, le_purpose = preprocess_data(df)

    # Load the trained model
    print("加载模型...")
    model, model_name = _load_model()
    print(f"使用{model_name}模型")

    # Explain only a small slice of the data to avoid long computation times.
    # NOTE(review): X is indexed with .iloc, so it is presumably a DataFrame.
    X_sample = X.iloc[:sample_size]

    # Build the SHAP explainer
    print("创建SHAP解释器...")
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_sample)

    # The output directory may not exist on a fresh checkout; savefig would
    # otherwise fail.
    os.makedirs('visualization', exist_ok=True)

    # Feature-importance bar chart
    print("绘制SHAP特征重要性图...")
    _save_summary_plot(
        shap_values,
        X_sample,
        title=f'{model_name}模型SHAP特征重要性',
        path='visualization/shap_feature_importance.png',
        figsize=(10, 6),
        plot_type="bar",
    )

    # Summary plot
    print("绘制SHAP摘要图...")
    _save_summary_plot(
        shap_values,
        X_sample,
        title=f'{model_name}模型SHAP摘要图',
        path='visualization/shap_summary.png',
        figsize=(10, 8),
    )

    # Persist the explainer so the API can reuse it
    joblib.dump(explainer, 'models/shap_explainer.pkl')

    print("SHAP解释完成,图表已保存到 visualization 目录")

    return shap_values, X_sample


if __name__ == "__main__":
    shap_values, X_sample = explain_with_shap()
    print("SHAP解释模块集成完成!")