import os
import sys
import time

# Make project-local packages importable when the script is run from the
# repository root; this must happen before the project imports below.
sys.path.append(os.getcwd())

from datetime import datetime

from config import logger
from matching.subject.itemcf_recall import init_itemcf_recall, itemcf_recall
from matching.subject.item_embedding_recall import init_item_embedding_recall, item_embedding_recall
from matching.subject.dssm_recall_predict import dssm_recall
from matching.subject.dssm_usercf_recall import init_dssm_usercf_recall, dssm_usercf_recall
from matching.subject.fm_recall_predict import fm_recall
from matching.subject.mind_user0_recall_predict import mind0_recall
from matching.subject.mind_user1_recall_predict import mind1_recall
from matching.subject.pinsage_recall_predict import pinsage_recall
from matching.subject.youtubednn_recall_predict import youtubednn_recall
from matching.subject.youtube_usercf_recall import init_youtube_usercf_recall, youtube_usercf_recall
from matching.subject.multi_recall_combine import combine_recall_results
from matching.subject.cold_start_recall import cold_start_user_recall

# One-time initialisation of the recall channels that expose an init hook.
init_itemcf_recall()
init_item_embedding_recall()
init_dssm_usercf_recall()
init_youtube_usercf_recall()


def _first_user_item_ids(recall_dict):
    """Return the item ids stored under the first value of *recall_dict*.

    Each entry of that value is indexed with ``[0]`` to obtain the item id
    (presumably the entries are ``(item_id, score)`` pairs -- confirm against
    the recall implementations). An empty dict yields an empty list instead
    of raising ``IndexError``.
    """
    entries = next(iter(recall_dict.values()), None)
    if entries is None:
        return []
    return [entry[0] for entry in entries]


def _log_elapsed_ms(started_at):
    """Log the milliseconds elapsed since *started_at* (a ``time.perf_counter()`` reading)."""
    cost_time_millisecond = round((time.perf_counter() - started_at) * 1000.0, 3)
    logger.info(f"多路召回总计耗时: {cost_time_millisecond} 毫秒")


def calc_multi_recall_ration(
        youtubednn_recall_dict,
        itemcf_recall_dict,
        item_embedding_recall_dict,
        youtube_usercf_recall_dict,
        cold_start_recall_dict,
        dssm_recall_dict,
        dssm_usercf_recall_dict,
        fm_recall_dict,
        mind0_recall_dict,
        mind1_recall_dict,
        pinsage_recall_dict,
        multi_recall_results):
    """Print how many of the combined recall items each channel contributed.

    Every ``*_recall_dict`` argument and ``multi_recall_results`` is a dict
    whose first value is a sequence of entries with the item id at index 0.
    For each channel the function counts the distinct items that also appear
    in the combined result and prints all counts on a single line.

    Returns:
        None. The counts are written to stdout only.
    """
    combined_item_ids = _first_user_item_ids(multi_recall_results)
    # Built once so every per-channel intersection reuses the same set.
    combined_item_set = set(combined_item_ids)

    # (label, channel dict) in the order the counts are printed.
    labelled_channels = (
        ("YoutubeDNN召回:", youtubednn_recall_dict),
        ("item embedding召回:", item_embedding_recall_dict),
        ("youtube usercf召回:", youtube_usercf_recall_dict),
        ("itemcf召回:", itemcf_recall_dict),
        ("冷启动召回:", cold_start_recall_dict),
        ("dssm召回:", dssm_recall_dict),
        ("dssm_usercf召回:", dssm_usercf_recall_dict),
        ("fm召回:", fm_recall_dict),
        ("mind0召回:", mind0_recall_dict),
        ("mind1召回:", mind1_recall_dict),
        ("pinsage召回:", pinsage_recall_dict),
    )

    # The total keeps duplicates (length of the list); the per-channel numbers
    # are sizes of the intersection with the combined item set.
    report = ["多路召回数量:", len(combined_item_ids)]
    for label, recall_dict in labelled_channels:
        overlap = combined_item_set.intersection(_first_user_item_ids(recall_dict))
        report.extend((label, len(overlap)))
    print(*report)


def multi_recall_predict(user_id, disciplines_id_list=None, topk=200):
    """Run every recall channel for one user and merge the results.

    Args:
        user_id: Id of the user to recall items for.
        disciplines_id_list: Passed through to ``cold_start_user_recall``.
        topk: Size of the final merged result. Each regular channel is asked
            for ``topk // 4`` items; the cold-start channel for ``topk // 10``
            when merged, or ``topk`` when it is the only source.

    Returns:
        A tuple ``(multi_recall_results, only_cold_start_recall)``. When the
        flag is True the first element is the raw return value of
        ``cold_start_user_recall``; otherwise it is the return value of
        ``combine_recall_results``.
    """
    # Monotonic clock: immune to timezone differences and wall-clock jumps.
    started_at = time.perf_counter()
    per_channel_topk = topk // 4

    # (result key, label used in the log line, recall function), in call order.
    # Built inside the function so the recall functions are looked up at call time.
    channels = (
        ('itemcf_recall', 'itemcf', itemcf_recall),
        ('item_embedding_recall', 'item embedding', item_embedding_recall),
        ('dssm_recall', 'dssm', dssm_recall),
        ('dssm_usercf_recall', 'dssm user embedding', dssm_usercf_recall),
        ('fm_recall', 'fm', fm_recall),
        ('mind_recall0', 'mind0_recall', mind0_recall),
        ('mind_recall1', 'mind1_recall', mind1_recall),
        ('pinsage_recall', 'pinsage_recall', pinsage_recall),
        ('youtubednn_recall', 'YoutubeDNN', youtubednn_recall),
        ('youtubednn_usercf_recall', 'youtube usercf', youtube_usercf_recall),
    )

    # Per-channel results, each shaped {user_id: recall result}.
    recalled = {}
    for key, label, recall_fn in channels:
        logger.info(f'开始{label}召回...')
        recalled[key] = {user_id: recall_fn(user_id, per_channel_topk)}

    logger.info('开始冷启动召回...')

    # NOTE(review): the previous condition tested the itemcf result four times
    # (copy-paste slip). The four channels below are the presumed intent,
    # inferred from the parameter order of calc_multi_recall_ration -- confirm.
    cold_start_trigger_channels = (
        'itemcf_recall',
        'item_embedding_recall',
        'youtubednn_recall',
        'youtubednn_usercf_recall',
    )
    if all(len(recalled[key][user_id]) == 0 for key in cold_start_trigger_channels):
        # None of these channels produced anything: answer with cold start only.
        cold_start_recall_dict = cold_start_user_recall(user_id, disciplines_id_list, topk)
        _log_elapsed_ms(started_at)
        return cold_start_recall_dict, True

    # Otherwise blend a small cold-start slice in, giving every such item a
    # fixed score of 0.5 so it has the same entry shape as the other channels.
    cold_start_items = cold_start_user_recall(user_id, disciplines_id_list, topk // 10)
    cold_start_recall_dict = {
        user_id: [(item_id, 0.5) for item_id in cold_start_items.keys()]
    }

    # Separate top-level dict for the combiner so that `recalled` keeps
    # pointing at the per-channel dicts used for the statistics below.
    user_multi_recall_dict = dict(recalled)
    user_multi_recall_dict['cold_start_recall'] = cold_start_recall_dict

    # Per-channel weights; tune according to how each channel performs.
    weight_dict = {'itemcf_recall': 0.7,
                   'item_embedding_recall': 1.5,
                   'dssm_recall': 2.8,
                   'dssm_usercf_recall': 3.9,
                   'fm_recall': 2.1,
                   'mind_recall0': 1.9,
                   'mind_recall1': 1.9,
                   'pinsage_recall': 8,
                   'youtubednn_recall': 2.6,
                   'youtubednn_usercf_recall': 3.9,
                   'cold_start_recall': 1}

    # Merge all channels and keep the topk items for the user.
    multi_recall_results = combine_recall_results(
        user_multi_recall_dict, weight_dict, topk=topk, save_results=False)

    _log_elapsed_ms(started_at)

    calc_multi_recall_ration(
        recalled['youtubednn_recall'],
        recalled['itemcf_recall'],
        recalled['item_embedding_recall'],
        recalled['youtubednn_usercf_recall'],
        cold_start_recall_dict,
        recalled['dssm_recall'],
        recalled['dssm_usercf_recall'],
        recalled['fm_recall'],
        recalled['mind_recall0'],
        recalled['mind_recall1'],
        recalled['pinsage_recall'],
        multi_recall_results)

    return multi_recall_results, False


if __name__ == '__main__':
    multi_recall_results = multi_recall_predict(user_id=192262, topk=200)
    print(multi_recall_results)