You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

93 lines
2.7 KiB

import sys
import os
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import lightgbm as lgb
import joblib
from data.data_generator import preprocess_data
def _report_metrics(model, X_test, y_test):
    """Print held-out evaluation metrics for a fitted binary classifier.

    Prints accuracy, ROC AUC (from predicted probabilities), the sklearn
    classification report and the confusion matrix for ``X_test``/``y_test``.
    """
    # Function-scope import keeps this change self-contained; sklearn is
    # already a dependency of this module.
    from sklearn.metrics import roc_auc_score

    print("模型预测...")
    y_pred = model.predict(X_test)
    # Probability for the second entry of model.classes_ — presumably the
    # positive ("default") class for 0/1 labels; TODO confirm label encoding.
    y_pred_proba = model.predict_proba(X_test)[:, 1]

    accuracy = accuracy_score(y_test, y_pred)
    print(f"模型准确率: {accuracy:.4f}")
    # Threshold-free ranking quality, which accuracy alone does not capture.
    print(f"模型AUC: {roc_auc_score(y_test, y_pred_proba):.4f}")
    print("\n分类报告:")
    print(classification_report(y_test, y_pred))
    print("\n混淆矩阵:")
    print(confusion_matrix(y_test, y_pred))


def _save_artifacts(model_dir, model, scaler, le_education, le_home, le_purpose):
    """Persist the model and preprocessing objects into ``model_dir`` with joblib."""
    print("保存模型和预处理器...")
    # Create the target directory up front so the dumps don't fail with
    # FileNotFoundError on a fresh checkout.
    os.makedirs(model_dir, exist_ok=True)
    artifacts = {
        'lightgbm_model.pkl': model,
        'scaler.pkl': scaler,
        'le_education.pkl': le_education,
        'le_home.pkl': le_home,
        'le_purpose.pkl': le_purpose,
    }
    for filename, obj in artifacts.items():
        joblib.dump(obj, os.path.join(model_dir, filename))


def train_lightgbm_model(data_path='data/credit_data.csv', model_dir='models'):
    """Train a LightGBM classifier for credit-risk assessment.

    Steps:
        1. Load the CSV and run it through ``preprocess_data``.
        2. Make a stratified 80/20 train/test split.
        3. Split a stratified validation set off the training portion.
        4. Fit with early stopping monitored on the validation set.
        5. Report metrics on the untouched test set.
        6. Save the model and preprocessing objects.
        7. Print feature importances.

    Args:
        data_path: CSV file to read. Relative paths resolve against the
            current working directory.
        model_dir: Directory that receives the ``.pkl`` artifacts. It is
            created if missing.

    Returns:
        ``(model, scaler, le_education, le_home, le_purpose)``: the fitted
        ``LGBMClassifier`` followed by the four preprocessing objects returned
        by ``preprocess_data``, in the same order.
    """
    # Load data
    print("读取信贷数据...")
    df = pd.read_csv(data_path)
    print(f"数据形状: {df.shape}")

    # Preprocess. The extra return values are saved alongside the model
    # below, presumably so the same transforms can be reapplied at inference.
    print("数据预处理...")
    X, y, scaler, le_education, le_home, le_purpose = preprocess_data(df)

    # The test split is used ONLY for the final report.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # Early stopping gets its own validation data. Stopping on the test set
    # would tune the number of boosting rounds to it and inflate the metrics.
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )
    print(f"训练集大小: {X_fit.shape}")
    print(f"验证集大小: {X_val.shape}")
    print(f"测试集大小: {X_test.shape}")

    # Build the LightGBM classifier
    print("创建LightGBM模型...")
    model = lgb.LGBMClassifier(
        n_estimators=200,
        max_depth=8,
        learning_rate=0.05,
        num_leaves=64,
        subsample=0.8,
        # LightGBM only applies `subsample` when bagging is enabled by a
        # positive `subsample_freq` (default 0 = disabled). 1 = every iteration.
        subsample_freq=1,
        colsample_bytree=0.8,
        random_state=42,
        verbose=-1,
    )

    # Train, stopping after 10 rounds without validation-loss improvement
    print("训练模型...")
    model.fit(
        X_fit,
        y_fit,
        eval_set=[(X_val, y_val)],
        eval_metric='binary_logloss',
        callbacks=[lgb.early_stopping(10), lgb.log_evaluation(10)],
    )

    _report_metrics(model, X_test, y_test)
    _save_artifacts(model_dir, model, scaler, le_education, le_home, le_purpose)

    # Feature importance. Assumes preprocess_data returns X as a DataFrame
    # (X.columns) — TODO confirm.
    feature_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': model.feature_importances_,
    }).sort_values('importance', ascending=False)
    print("\n特征重要性:")
    print(feature_importance)

    return model, scaler, le_education, le_home, le_purpose
if __name__ == "__main__":
    # Script entry point: run the full training pipeline, keep the returned
    # model and preprocessing objects bound, then report completion.
    (
        model,
        scaler,
        le_education,
        le_home,
        le_purpose,
    ) = train_lightgbm_model()
    print("\n模型训练完成!")