import os
import sys
sys.path.append(os.getcwd())
import pandas as pd
from datetime import datetime
import warnings
from tqdm import tqdm
import pickle
from matching.shixun.recall_comm import get_item_info_df
from deepctr.layers import custom_objects
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.callbacks import *
from tensorflow.python.keras.models import load_model
import tensorflow as tf
from config import shixun_model_save_path
from config import logger
from config import shixun_all_user_item_feats
from config import shixun_xdeepfm_rank_dict
from config import shixun_rank_dense_fea
from config import shixun_rank_sparse_fea
from config import myshixuns_data_path
from config import shixun_features_save_path
from utils import get_user
from ranking.shixun.xdeepfm_ranker_train import get_xdeepfm_feats_columns
from ranking.shixun.user_recall_rank_features import init_rank_features
from ranking.shixun.user_recall_rank_features import build_rank_features_offline
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
K.set_learning_phase(True)

# Run in graph mode under TF 2.x so the TF1-style graph/session APIs below work
if tf.__version__ >= '2.0.0':
    tf.compat.v1.disable_eager_execution()

graph = tf.compat.v1.get_default_graph()
sess = tf.compat.v1.keras.backend.get_session()
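# The default graph and session captured above are re-entered inside
# xdeepfm_ranker_predict(): with eager execution disabled, model.predict()
# must run in the same graph/session in which the model was loaded.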
warnings.filterwarnings('ignore')
logger.info('Fetching item info')
item_info_df = get_item_info_df()
item_info_df = item_info_df[['shixun_id', 'shixun_name']].reset_index()
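# deepctr.layers.custom_objects maps DeepCTR's custom layer names (e.g. the CIN
# layer used by xDeepFM) to their classes so Keras can deserialize the saved model.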
logger.info('Loading the xDeepFM ranking model')
xdeepfm_model = load_model(shixun_model_save_path + 'xdeepfm_model.h5', custom_objects=custom_objects)
# xdeepfm_model.summary()
def xdeepfm_ranker_predict(user_item_feats_df, topk=10, verbose=True):
    """
    xDeepFM model prediction interface
    :param user_item_feats_df: ranking-model features built for a given user_id
    :param topk: number of top-ranked items to return
    :param verbose: whether to log the prediction time
    :return: DataFrame with user_id, shixun_id, shixun_name, pred_score and pred_rank
    """
    start_time = datetime.now()

    # Sparse (categorical) features
    sparse_fea = shixun_rank_sparse_fea

    # Dense (continuous) features
    dense_fea = shixun_rank_dense_fea

    # Fill missing values
    user_item_feats_df[dense_fea] = user_item_feats_df[dense_fea].fillna(0)

    # Min-max normalize the dense features with the scalers fitted at training time
    for feat in dense_fea:
        min_max_scaler = pickle.load(open(shixun_model_save_path + 'min_max_scaler_' + feat + '.model', 'rb'))
        user_item_feats_df[feat] = min_max_scaler.transform(user_item_feats_df[[feat]])

    # LabelEncode the sparse features with the encoders fitted at training time
    for feat in sparse_fea:
        label_encoder = pickle.load(open(shixun_model_save_path + feat + '_label_encoder.model', 'rb'))
        user_item_feats_df[feat] = label_encoder.transform(user_item_feats_df[feat])
        if feat == 'shixun_id':
            shixun_id_label_encoder = label_encoder
        if feat == 'user_id':
            user_id_label_encoder = label_encoder

    x, linear_feature_columns, dnn_feature_columns = get_xdeepfm_feats_columns(
        user_item_feats_df, dense_fea, sparse_fea)

    # Model prediction (run inside the graph/session the model was loaded in)
    with graph.as_default():
        with sess.as_default():
            user_item_feats_df['pred_score'] = xdeepfm_model.predict(
                x, verbose=1 if verbose else 0, batch_size=256)

    # Restore the original user_id and shixun_id values
    user_item_feats_df['user_id'] = user_id_label_encoder.inverse_transform(user_item_feats_df['user_id'])
    user_item_feats_df['shixun_id'] = shixun_id_label_encoder.inverse_transform(user_item_feats_df['shixun_id'])

    # Sort by predicted score in descending order
    rank_results = user_item_feats_df[['user_id', 'shixun_id', 'pred_score']].copy()
    rank_results['user_id'] = rank_results['user_id'].astype(int)
    rank_results['shixun_id'] = rank_results['shixun_id'].astype(int)
    rank_results = rank_results.merge(item_info_df, how='left', on='shixun_id')
    rank_results = rank_results[['user_id', 'shixun_id', 'shixun_name', 'pred_score']]
    rank_results.sort_values(by=['pred_score'], ascending=False, inplace=True)
    rank_results['pred_rank'] = rank_results['pred_score'].rank(ascending=False, method='first').astype(int)
    rank_results = rank_results[:topk]

    # Elapsed time in milliseconds
    end_time = datetime.now()
    cost_time_millisecond = round((end_time - start_time).total_seconds() * 1000.0, 3)

    if verbose:
        logger.info(f"xDeepFM prediction took {cost_time_millisecond} ms")

    return rank_results
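# Example (illustrative only; user id 12345 is hypothetical): rank one user's
# recalled items and keep the top 10, assuming build_rank_features_offline()
# returns the same dense/sparse feature columns the model was trained on:
#
#     feats_df = build_rank_features_offline(12345)
#     if feats_df.shape[0] > 0:
#         top10 = xdeepfm_ranker_predict(feats_df, topk=10)
#         print(top10[['shixun_id', 'shixun_name', 'pred_score', 'pred_rank']])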
def alluser_xdeepfm_ranker_predict():
    """
    Build the dict of offline-feature-ranked recall items for all users
    """
    init_rank_features()

    all_user_item_feats_df = pd.read_csv(shixun_all_user_item_feats, sep='\t', encoding='utf-8')

    recall_rank_dict = {}

    # all_user_ids = all_user_item_feats_df['user_id'].unique()
    all_user_ids = get_user(myshixuns_data_path)

    for user_id in tqdm(all_user_ids):
        user_id = int(user_id)
        recall_rank_dict.setdefault(user_id, [])

        user_item_feats_df = build_rank_features_offline(user_id)
        if user_item_feats_df.shape[0] == 0:
            continue

        rank_results = xdeepfm_ranker_predict(user_item_feats_df,
                                              topk=user_item_feats_df.shape[0],
                                              verbose=False)

        for shixun_id, shixun_name in zip(rank_results['shixun_id'], rank_results['shixun_name']):
            recall_rank_dict[user_id].append((shixun_id, shixun_name))

    with open(shixun_xdeepfm_rank_dict, 'wb') as f:
        pickle.dump(recall_rank_dict, f)
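# Example (illustrative only; user id 12345 is hypothetical): load the pickled
# ranking dict written above and look up one user's ranked (shixun_id, shixun_name) list:
#
#     with open(shixun_xdeepfm_rank_dict, 'rb') as f:
#         rank_dict = pickle.load(f)
#     print(rank_dict.get(12345, [])[:10])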
if __name__ == '__main__':
    user_item_feats_df = pd.read_csv(shixun_features_save_path + 'user_item_feats_df.csv', sep='\t')
    tmp_user_item_feats_df = user_item_feats_df.merge(item_info_df, how='left', on='shixun_id')

    logger.info('Data before xDeepFM ranking:')
    print(tmp_user_item_feats_df[['user_id', 'shixun_id', 'shixun_name', 'score', 'rank']][:20])

    rank_results = xdeepfm_ranker_predict(user_item_feats_df, topk=user_item_feats_df.shape[0])

    logger.info('Data after xDeepFM ranking:')
    print(rank_results[['user_id', 'shixun_id', 'shixun_name', 'pred_score', 'pred_rank']][:20])

    alluser_xdeepfm_ranker_predict()

    print("success!!!")