You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

31 lines
803 B

from sklearn.ensemble import RandomForestRegressor
import joblib
from sklearn.metrics import mean_absolute_error
import ProcessData
# 训练并保存模型
def getModel(a="Model.pkl"):
"""
:param a: 模型文件名
:return:
[socre: MAE评估结果,
X_test: 预测数据集]
"""
# 获取测试集、训练集、验证集
[X_train, X_valid, y_train, y_valid, X_test] = ProcessData.ProcessData()
# 随机树森林模型
model = RandomForestRegressor(random_state=0, n_estimators=1001)
# 训练模型
model.fit(X_train, y_train)
# 预测模型
preds = model.predict(X_valid)
# 用MAE评估
score = mean_absolute_error(y_valid, preds)
# 保存模型到本地
joblib.dump(model, a)
# 返回MAE
return [score, X_test]