You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
2.7 KiB
93 lines
2.7 KiB
import sys
|
|
import os
|
|
|
|
# 添加项目根目录到Python路径
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
|
import lightgbm as lgb
|
|
import joblib
|
|
|
|
from data.data_generator import preprocess_data
|
|
|
|
def train_lightgbm_model():
    """Train a LightGBM classifier for credit-risk assessment and persist it.

    Pipeline:
      1. Read ``data/credit_data.csv``.
      2. Run ``preprocess_data`` to obtain features, target and the fitted
         preprocessing objects.
      3. Split into stratified train / validation / test partitions.
      4. Fit an ``LGBMClassifier`` with early stopping on the validation
         partition.
      5. Print accuracy, the classification report and the confusion matrix
         for the held-out test partition.
      6. Dump the model and the preprocessing objects under ``models/``.
      7. Print the feature-importance table.

    Both the input path and the output directory are relative, so they are
    resolved against the current working directory.

    Returns:
        tuple: ``(model, scaler, le_education, le_home, le_purpose)`` -- the
        fitted classifier followed by the four preprocessing objects that
        ``preprocess_data`` returns after ``X`` and ``y`` (presumably a
        fitted scaler and three label encoders; confirm against
        ``data.data_generator``).
    """
    # Load the raw data.
    print("读取信贷数据...")
    df = pd.read_csv('data/credit_data.csv')
    print(f"数据形状: {df.shape}")

    # Preprocess into a feature table, a target and the fitted transformers.
    print("数据预处理...")
    X, y, scaler, le_education, le_home, le_purpose = preprocess_data(df)

    # Hold out 20% of the rows as the final test set.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # Carve a validation set out of the training portion. Early stopping must
    # not monitor the test set: doing so selects the number of boosting rounds
    # on the very rows used for the reported metrics and inflates them.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )

    print(f"训练集大小: {X_train.shape}")
    print(f"验证集大小: {X_val.shape}")
    print(f"测试集大小: {X_test.shape}")

    # Build the classifier.
    print("创建LightGBM模型...")
    model = lgb.LGBMClassifier(
        n_estimators=200,
        max_depth=8,
        learning_rate=0.05,
        num_leaves=64,
        subsample=0.8,
        # LightGBM only applies row subsampling when the bagging frequency is
        # non-zero; with the default of 0 the `subsample` setting is ignored.
        subsample_freq=1,
        colsample_bytree=0.8,
        random_state=42,
        verbose=-1,
    )

    # Fit; stop once validation log-loss has not improved for 10 rounds and
    # log the metric every 10 rounds.
    print("训练模型...")
    model.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric='binary_logloss',
        callbacks=[lgb.early_stopping(10), lgb.log_evaluation(10)],
    )

    # Predict on the untouched test partition.
    print("模型预测...")
    y_pred = model.predict(X_test)

    # Report test-set metrics.
    accuracy = accuracy_score(y_test, y_pred)
    print(f"模型准确率: {accuracy:.4f}")

    print("\n分类报告:")
    print(classification_report(y_test, y_pred))

    print("\n混淆矩阵:")
    print(confusion_matrix(y_test, y_pred))

    # Persist the model and the preprocessing objects. Create the output
    # directory first so a fresh checkout does not fail on the first dump.
    print("保存模型和预处理器...")
    os.makedirs('models', exist_ok=True)
    joblib.dump(model, 'models/lightgbm_model.pkl')
    joblib.dump(scaler, 'models/scaler.pkl')
    joblib.dump(le_education, 'models/le_education.pkl')
    joblib.dump(le_home, 'models/le_home.pkl')
    joblib.dump(le_purpose, 'models/le_purpose.pkl')

    # Rank features by the importance scores exposed by the fitted model.
    feature_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': model.feature_importances_,
    }).sort_values('importance', ascending=False)

    print("\n特征重要性:")
    print(feature_importance)

    return model, scaler, le_education, le_home, le_purpose
|
|
|
|
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    (
        model,
        scaler,
        le_education,
        le_home,
        le_purpose,
    ) = train_lightgbm_model()
    print("\n模型训练完成!")