You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
3.1 KiB
76 lines
3.1 KiB
5 months ago
|
import json
|
||
|
from datetime import datetime
|
||
|
from config import logger
|
||
|
from config import test_user_id
|
||
|
from ranking.shixun.user_recall_rank_features import init_rank_features
|
||
|
from ranking.shixun.user_recall_rank_features import build_rank_features_online
|
||
|
from matching.shixun.cold_start_recall import cold_start_user_recall
|
||
|
from matching.shixun.multi_recall_predict import multi_recall_predict
|
||
|
|
||
|
# Run init_rank_features() once, at import time (module-level side effect).
# It is deliberately placed after the recall imports and before the ranker
# imports below; NOTE(review): presumably the ranker modules rely on this having
# run first — confirm before reordering these lines.
init_rank_features()
|
||
|
|
||
|
from ranking.shixun.xdeepfm_rank_predict import xdeepfm_ranker_predict
|
||
|
from ranking.shixun.bst_ranker_predict import bst_ranker_predict
|
||
|
from ranking.shixun.difm_ranker_predict import difm_ranker_predict
|
||
|
|
||
|
def _elapsed_ms(start_time):
    """Return milliseconds elapsed since *start_time*, rounded to 3 decimals.

    *start_time* must be a naive ``datetime.now()`` value so that both ends of
    the interval are read from the same clock. ``total_seconds()`` is used
    because ``timedelta.microseconds`` is only the sub-second component and
    would silently drop whole seconds.
    """
    return round((datetime.now() - start_time).total_seconds() * 1000.0, 3)


def _rank_candidates(user_item_feats_df, topk, rank_method):
    """Rank recalled candidates with the ranker selected by *rank_method*.

    Args:
        user_item_feats_df: feature frame produced by build_rank_features_online.
        topk: requested result count; capped at the number of candidate rows.
        rank_method: '0' -> xDeepFM, '1' -> BST, '2' -> DIFM.

    Returns:
        dict built from the ranker output's 'shixun_id' and 'shixun_name'
        columns (shixun_id -> shixun_name).

    Raises:
        ValueError: if rank_method is not one of '0', '1', '2'.
    """
    rankers = {
        '0': xdeepfm_ranker_predict,
        '1': bst_ranker_predict,
        '2': difm_ranker_predict,
    }
    ranker = rankers.get(rank_method)
    if ranker is None:
        raise ValueError(
            f"unsupported rank_method {rank_method!r}; expected '0', '1' or '2'"
        )

    # Never ask the ranker for more items than there are candidate rows.
    topk = min(topk, user_item_feats_df.shape[0])
    rank_results = ranker(user_item_feats_df, topk=topk)
    return dict(zip(rank_results['shixun_id'], rank_results['shixun_name']))


def shixun_recommend_online(user_id, disciplines_id_list=None, topk=10, rank_method='1'):
    """Return recommended shixun (实训) results for a user.

    Pipeline:
      1. Run multi-channel recall to obtain the user's candidate shixun IDs.
      2. Build ranking-model features for the recalled candidates.
      3. Re-rank the candidates with the selected ranking model.
      4. Return the ranked results.

    Args:
        user_id: ID of the user to recommend for.
        disciplines_id_list: discipline IDs passed to multi-channel recall and
            to the cold-start fallback (may be None).
        topk: maximum number of results requested.
        rank_method: ranker selector — '0' xDeepFM, '1' BST, '2' DIFM.

    Returns:
        When ranking runs, a dict mapping shixun_id -> shixun_name. Otherwise,
        whatever multi_recall_predict or cold_start_user_recall returned.

    Raises:
        ValueError: if ranking is required and rank_method is not '0', '1'
            or '2'.
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行推荐的用户ID: {user_id}")

    # 1. Multi-channel recall.
    recommend_results, only_cold_start_recall = multi_recall_predict(
        user_id, disciplines_id_list, topk=topk
    )

    # When only the cold-start channel fired and it produced candidates, skip
    # feature building and ranking and return the recall results as-is.
    if only_cold_start_recall and len(recommend_results) > 0:
        logger.info(f"本次推荐总耗时: {_elapsed_ms(start_time)} 毫秒")
        return recommend_results

    # 2. Build ranking-model features for the recalled candidates.
    user_item_feats_df = build_rank_features_online(user_id, recommend_results)

    # 3. Rank the recalled candidates. If there is no recall data (empty feature
    #    frame), fall back to cold-start recall based on interest tags.
    if user_item_feats_df.empty:
        recommend_results = cold_start_user_recall(disciplines_id_list, topk=topk)
    else:
        recommend_results = _rank_candidates(user_item_feats_df, topk, rank_method)

    logger.info(f"本次推荐总耗时: {_elapsed_ms(start_time)} 毫秒")
    return recommend_results
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Manual smoke run: recommend 10 items for the configured test user with
    # no discipline filter, ranked by the BST model (rank_method '1'), and
    # pretty-print the result as JSON (non-ASCII names kept readable).
    demo_results = shixun_recommend_online(
        user_id=test_user_id,
        disciplines_id_list=[],
        topk=10,
        rank_method='1',
    )
    rendered = json.dumps(demo_results, ensure_ascii=False, indent=4)
    print(rendered)
|