You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
7.3 KiB

5 months ago
import functools
import os
import sys
sys.path.append(os.getcwd())
import pandas as pd
from datetime import datetime
import warnings
import pickle
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
from deepctr.layers import custom_objects
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.callbacks import *
from tensorflow.python.keras.models import load_model
import tensorflow as tf
from config import shixun_model_save_path
from config import logger
from config import shixun_bst_rank_dict
from config import shixun_rank_dense_fea
from config import shixun_rank_sparse_fea
from config import shixun_max_seq_len
from config import myshixuns_data_path
from config import shixun_features_save_path
from utils import get_user
from matching.shixun.recall_comm import get_item_info_df
from matching.shixun.recall_comm import get_all_select_df
from ranking.shixun.bst_ranker_train import get_bst_feats_columns
from ranking.shixun.user_recall_rank_features import init_rank_features
from ranking.shixun.user_recall_rank_features import build_rank_features_offline
# graph / sess are module-level globals. They are only needed if prediction is
# wrapped in `with graph.as_default(): with sess.as_default():` (TF1-style);
# bst_ranker_predict currently calls bst_model.predict directly.

# Expose only GPU 1, with devices enumerated in PCI bus order.
# NOTE(review): tensorflow is already imported above; CUDA_VISIBLE_DEVICES only
# takes effect if no CUDA context has been created yet — confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# NOTE(review): learning phase True means *training* mode, yet this module only
# runs inference; layers such as Dropout/BatchNorm would behave as in training.
# Possibly a workaround for loading the DeepCTR BST model in graph mode — confirm.
K.set_learning_phase(True)

# Run in TF1 graph mode on TF 2.x. Compare the major version numerically:
# a plain string comparison ('10.0' < '2.0.0') is lexicographic.
if int(tf.__version__.split('.')[0]) >= 2:
    tf.compat.v1.disable_eager_execution()
graph = tf.compat.v1.get_default_graph()
sess = tf.compat.v1.keras.backend.get_session()

warnings.filterwarnings('ignore')

# Item metadata used to attach shixun_name to ranked results
# (reset_index keeps the former index as an extra column).
logger.info('获取物品信息')
item_info_df = get_item_info_df()
item_info_df = item_info_df[['shixun_id', 'shixun_name']].reset_index()

logger.info('加载物品行为数据')
all_data = get_all_select_df()

# One row per user: the list of shixun_ids the user selected, in the order the
# rows appear in all_data. Columns: index, user_id, hist_shixun_id.
logger.info('生成用户历史选择物品数据')
hist_select = (all_data[['user_id', 'shixun_id']]
               .groupby('user_id')['shixun_id']
               .agg(list)
               .reset_index())
his_behavior_df = pd.DataFrame({
    'user_id': hist_select['user_id'],
    'hist_shixun_id': hist_select['shixun_id'],
}).reset_index()

logger.info('加载BST排序模型')
bst_model = load_model(shixun_model_save_path + 'bst_model.h5', custom_objects)
@functools.lru_cache(maxsize=None)
def _load_pickled(path):
    """Unpickle and return the object stored at ``path``.

    Used for the scalers / label encoders saved at training time. Results are
    cached per path so that calling ``bst_ranker_predict`` once per user does
    not re-read the same files, and the file handle is always closed.
    NOTE: pickle must only be used on trusted files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def bst_ranker_predict(user_item_feats_df, topk=10, verbose=True):
    """Score candidate items with the BST model and return the top-k ranking.

    :param user_item_feats_df: ranking features generated per user_id; must
        contain ``user_id``, ``shixun_id``, ``rank`` and every column listed in
        ``shixun_rank_dense_fea`` / ``shixun_rank_sparse_fea``.
    :param topk: number of highest-scoring rows to return.
    :param verbose: when True, log the elapsed time and show Keras' predict
        progress bar.
    :return: DataFrame with columns ``user_id, shixun_id, shixun_name, rank,
        pred_score``, sorted by ``pred_score`` descending, ``rank`` renumbered
        1..n in that order, truncated to ``topk`` rows.
    :raises KeyError: if ``user_id`` or ``shixun_id`` is not among the
        configured sparse features (their encoders are needed to decode ids).
    """
    start_time = datetime.now()

    # Attach each user's historical item sequence. This is an inner merge, so
    # rows for users with no selection history are dropped.
    user_item_feats_df = user_item_feats_df.merge(his_behavior_df, on='user_id')

    sparse_fea = shixun_rank_sparse_fea        # sparse (categorical) features
    behavior_fea = ['shixun_id']               # target-item feature
    hist_behavior_fea = ['hist_shixun_id']     # historical behaviour sequence
    dense_fea = shixun_rank_dense_fea          # dense (continuous) features

    # Fill missing dense values, then scale each with its training-time scaler.
    user_item_feats_df[dense_fea] = user_item_feats_df[dense_fea].fillna(0)
    for feat in dense_fea:
        min_max_scaler = _load_pickled(
            shixun_model_save_path + 'min_max_scaler_' + feat + '.model')
        # The scaler expects 2-D input, hence the one-column frame.
        user_item_feats_df[feat] = min_max_scaler.transform(user_item_feats_df[[feat]])

    # Label-encode sparse features; keep the encoders to decode ids afterwards.
    label_encoders = {}
    for feat in sparse_fea:
        label_encoder = _load_pickled(
            shixun_model_save_path + feat + '_label_encoder.model')
        # LabelEncoder expects 1-D input, so pass the Series itself.
        user_item_feats_df[feat] = label_encoder.transform(user_item_feats_df[feat])
        label_encoders[feat] = label_encoder

    x, _ = get_bst_feats_columns(user_item_feats_df,
                                 dense_fea,
                                 sparse_fea,
                                 behavior_fea,
                                 hist_behavior_fea,
                                 max_len=shixun_max_seq_len)

    # Predict directly (TF2 style) rather than inside graph/sess contexts,
    # which were incompatible with the BST runtime environment.
    user_item_feats_df['pred_score'] = bst_model.predict(
        x, verbose=1 if verbose else 0, batch_size=256)

    # Decode user_id and shixun_id back to their original values.
    user_item_feats_df['user_id'] = label_encoders['user_id'].inverse_transform(
        user_item_feats_df['user_id'])
    user_item_feats_df['shixun_id'] = label_encoders['shixun_id'].inverse_transform(
        user_item_feats_df['shixun_id'])

    # Explicit copy so the following column assignment does not hit a view.
    rank_results = user_item_feats_df[['user_id', 'shixun_id', 'rank', 'pred_score']].copy()
    rank_results['shixun_id'] = rank_results['shixun_id'].astype(int)
    rank_results = (
        rank_results.merge(item_info_df, how='left', on='shixun_id')
        [['user_id', 'shixun_id', 'shixun_name', 'rank', 'pred_score']]
        .sort_values(by='pred_score', ascending=False)
    )
    # Renumber rank by position in descending-score order (1 = best).
    rank_results['rank'] = range(1, len(rank_results) + 1)
    rank_results = rank_results.head(topk)

    if verbose:
        # Same clock at both ends, and the full duration (not just the
        # timedelta's microseconds component), in milliseconds.
        cost_time_millisecond = round(
            (datetime.now() - start_time).total_seconds() * 1000.0, 3)
        logger.info(f"BST 预测耗时: {cost_time_millisecond} 毫秒")
    return rank_results
def get_user_his_behavior_df(user_id):
    """Return the rows of ``his_behavior_df`` belonging to ``user_id``.

    The result is a (possibly empty) DataFrame with the same columns as
    ``his_behavior_df``.
    """
    is_target_user = his_behavior_df['user_id'] == user_id
    return his_behavior_df.loc[is_target_user]
def alluser_bst_ranker_predict():
    """Rank every user's recalled items with BST and pickle the result.

    Builds offline ranking features for each user returned by
    ``get_user(myshixuns_data_path)``, scores them with ``bst_ranker_predict``
    and writes ``{user_id: [(shixun_id, shixun_name), ...]}`` (best first) to
    ``shixun_bst_rank_dict``. A user maps to an empty list in two cases: the
    user has no selection history, or ranking failed for that user (the
    failure is logged).
    """
    init_rank_features()
    recall_rank_list_dict = {}
    # Users with a selection history, computed once rather than filtering
    # his_behavior_df for every user.
    users_with_history = set(his_behavior_df['user_id'])

    for user_id in tqdm(get_user(myshixuns_data_path)):
        user_id = int(user_id)
        recall_rank_list_dict.setdefault(user_id, [])
        # bst_ranker_predict inner-joins on history, so users without one
        # would get nothing; skip building their features.
        if user_id not in users_with_history:
            continue
        try:
            user_item_feats_df = build_rank_features_offline(user_id)
            rank_results = bst_ranker_predict(user_item_feats_df,
                                              topk=user_item_feats_df.shape[0],
                                              verbose=False)
        except Exception:
            # Best effort: continue with the other users, but don't hide errors.
            logger.exception(f'BST ranking failed for user_id={user_id}')
            continue
        recall_rank_list_dict[user_id] = list(
            zip(rank_results['shixun_id'], rank_results['shixun_name']))

    with open(shixun_bst_rank_dict, 'wb') as f:
        pickle.dump(recall_rank_list_dict, f)
if __name__ == '__main__':
    # Smoke test on the saved offline feature file: print the first 20 rows
    # before and after BST re-ranking, then rank all users.
    feats_file = shixun_features_save_path + 'user_item_feats_df.csv'
    sample_feats_df = pd.read_csv(feats_file, sep='\t')
    print(sample_feats_df.columns)

    shown_cols = ['user_id', 'shixun_id', 'shixun_name', 'rank']

    before_df = sample_feats_df.merge(item_info_df, how='left', on='shixun_id')
    logger.info('BST排序之前的数据:')
    print(before_df[shown_cols].head(20))

    after_df = bst_ranker_predict(sample_feats_df, topk=len(sample_feats_df))
    logger.info('BST排序之后的数据:')
    print(after_df[shown_cols].head(20))

    alluser_bst_ranker_predict()
    print("success!!!")