You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

210 lines
10 KiB

5 months ago
import os
import sys
sys.path.append(os.getcwd())
from datetime import datetime
from config import logger
from matching.subject.itemcf_recall import init_itemcf_recall,itemcf_recall
from matching.subject.item_embedding_recall import init_item_embedding_recall,item_embedding_recall
from matching.subject.dssm_recall_predict import dssm_recall
from matching.subject.dssm_usercf_recall import init_dssm_usercf_recall,dssm_usercf_recall
from matching.subject.fm_recall_predict import fm_recall
from matching.subject.mind_user0_recall_predict import mind0_recall
from matching.subject.mind_user1_recall_predict import mind1_recall
from matching.subject.pinsage_recall_predict import pinsage_recall
from matching.subject.youtubednn_recall_predict import youtubednn_recall
from matching.subject.youtube_usercf_recall import init_youtube_usercf_recall,youtube_usercf_recall
from matching.subject.multi_recall_combine import combine_recall_results
from matching.subject.cold_start_recall import cold_start_user_recall
# Run the one-time init_* hooks of the recall channels at import time,
# in this fixed order, before any recall function is called.
for _init_recall_channel in (init_itemcf_recall,
                             init_item_embedding_recall,
                             init_dssm_usercf_recall,
                             init_youtube_usercf_recall):
    _init_recall_channel()
del _init_recall_channel
def _first_user_item_ids(recall_dict):
    """Return the item ids stored under the first key of ``recall_dict``.

    Recall dicts are expected to look like ``{user_id: entries}`` for a single
    user, where the first element of every entry is the item id (for example an
    ``(item_id, score)`` pair). An empty dict yields ``[]`` rather than raising.
    """
    entries = next(iter(recall_dict.values()), [])
    return [entry[0] for entry in entries]


def calc_multi_recall_ration(
        youtubednn_recall_dict,
        itemcf_recall_dict,
        item_embedding_recall_dict,
        youtube_usercf_recall_dict,
        cold_start_recall_dict,
        dssm_recall_dict,
        dssm_usercf_recall_dict,
        fm_recall_dict,
        mind0_recall_dict,
        mind1_recall_dict,
        pinsage_recall_dict,
        multi_recall_results):
    """Count how many items each recall channel contributed to the merged result.

    For every channel, this counts the distinct item ids that also appear in
    ``multi_recall_results`` (a set intersection), prints a one-line summary
    and returns the counts.

    Every argument is a recall dict of the form ``{user_id: entries}``, where
    the first element of each entry is an item id. Only the first user entry of
    each dict is inspected; empty dicts count as zero items.

    Returns:
        dict mapping channel name -> overlap count. ``'multi_recall'`` holds the
        total number of merged entries (duplicates included).
    """
    merged_item_ids = _first_user_item_ids(multi_recall_results)
    # Build the merged set once instead of once per channel.
    merged_item_set = set(merged_item_ids)

    def overlap(recall_dict):
        # Intersection, not union: distinct channel items that survived the merge.
        return len(set(_first_user_item_ids(recall_dict)) & merged_item_set)

    counts = {
        'multi_recall': len(merged_item_ids),
        'youtubednn_recall': overlap(youtubednn_recall_dict),
        'item_embedding_recall': overlap(item_embedding_recall_dict),
        'youtube_usercf_recall': overlap(youtube_usercf_recall_dict),
        'itemcf_recall': overlap(itemcf_recall_dict),
        'cold_start_recall': overlap(cold_start_recall_dict),
        'dssm_recall': overlap(dssm_recall_dict),
        'dssm_usercf_recall': overlap(dssm_usercf_recall_dict),
        'fm_recall': overlap(fm_recall_dict),
        'mind0_recall': overlap(mind0_recall_dict),
        'mind1_recall': overlap(mind1_recall_dict),
        'pinsage_recall': overlap(pinsage_recall_dict),
    }
    print("多路召回数量:", counts['multi_recall'],
          "YoutubeDNN召回:", counts['youtubednn_recall'],
          "item embedding召回:", counts['item_embedding_recall'],
          "youtube usercf召回:", counts['youtube_usercf_recall'],
          "itemcf召回:", counts['itemcf_recall'],
          "冷启动召回:", counts['cold_start_recall'],
          "dssm召回:", counts['dssm_recall'],
          "dssm_usercf召回:", counts['dssm_usercf_recall'],
          "fm召回:", counts['fm_recall'],
          "mind0召回:", counts['mind0_recall'],
          "mind1召回:", counts['mind1_recall'],
          "pinsage召回:", counts['pinsage_recall'],
          )
    return counts
def _elapsed_ms(start_time):
    """Return the milliseconds elapsed since ``start_time`` (naive local datetime), rounded to 3 places.

    Uses ``total_seconds()`` so whole seconds are included, and the same clock
    (``datetime.now()``) as the start timestamp.
    """
    return round((datetime.now() - start_time).total_seconds() * 1000.0, 3)


def multi_recall_predict(user_id, disciplines_id_list=None, topk=200):
    """Multi-channel recall entry point for a single user.

    Every personalised recall channel is queried for ``topk // 4`` items.
    Then one of two things happens:

    * If the youtubednn, itemcf, item_embedding and youtube_usercf channels all
      return nothing, only the cold-start recall (``topk`` items) is returned.
    * Otherwise a small cold-start recall (``topk // 10`` items, fixed score
      0.5) is added. All channels are merged with per-channel weights via
      ``combine_recall_results``, and per-channel overlap stats are printed.

    Args:
        user_id: user to recall items for.
        disciplines_id_list: forwarded unchanged to ``cold_start_user_recall``.
        topk: size of the final merged recall list.

    Returns:
        ``(multi_recall_results, only_cold_start_recall)``. In the
        cold-start-only case the first element is the raw return value of
        ``cold_start_user_recall``. Otherwise it is the return value of
        ``combine_recall_results``.
    """
    start_time = datetime.now()
    per_channel_topk = topk // 4

    # (key used by combine_recall_results / weights, log message, recall function)
    recall_channels = (
        ('itemcf_recall', '开始itemcf召回...', itemcf_recall),
        ('item_embedding_recall', '开始item embedding召回...', item_embedding_recall),
        ('dssm_recall', '开始dssm召回...', dssm_recall),
        ('dssm_usercf_recall', '开始dssm user embedding召回...', dssm_usercf_recall),
        ('fm_recall', '开始fm召回...', fm_recall),
        ('mind_recall0', '开始mind0_recall召回...', mind0_recall),
        ('mind_recall1', '开始mind1_recall召回...', mind1_recall),
        ('pinsage_recall', '开始pinsage_recall召回...', pinsage_recall),
        ('youtubednn_recall', '开始YoutubeDNN召回...', youtubednn_recall),
        ('youtubednn_usercf_recall', '开始youtube usercf召回...', youtube_usercf_recall),
    )
    # Each channel's result is stored as {user_id: recall_list}.
    channel_results = {}
    for key, message, recall_fn in recall_channels:
        logger.info(message)
        channel_results[key] = {user_id: recall_fn(user_id, per_channel_topk)}

    logger.info('开始冷启动召回...')
    # The original condition tested itemcf_recall four times (copy-paste slip).
    # NOTE(review): the four channels below are the ones that lead
    # calc_multi_recall_ration's signature; confirm this is the intended set.
    cold_start_gate = ('youtubednn_recall', 'itemcf_recall',
                       'item_embedding_recall', 'youtubednn_usercf_recall')
    if all(len(channel_results[key][user_id]) == 0 for key in cold_start_gate):
        multi_recall_results = cold_start_user_recall(user_id, disciplines_id_list, topk)
        logger.info(f"多路召回总计耗时: {_elapsed_ms(start_time)} 毫秒")
        return multi_recall_results, True

    cold_start_items = cold_start_user_recall(user_id, disciplines_id_list, topk // 10)
    # The result's keys are treated as item ids; each gets a fixed score of 0.5.
    cold_start_recall_dict = {user_id: [(item_id, 0.5) for item_id in cold_start_items.keys()]}

    # Same key order as before: all personalised channels, then cold start.
    user_multi_recall_dict = dict(channel_results)
    user_multi_recall_dict['cold_start_recall'] = cold_start_recall_dict

    # Per-channel weights, tuned from earlier recall results.
    weight_dict = {'itemcf_recall': 0.7,
                   'item_embedding_recall': 1.5,
                   'dssm_recall': 2.8,
                   'dssm_usercf_recall': 3.9,
                   'fm_recall': 2.1,
                   'mind_recall0': 1.9,
                   'mind_recall1': 1.9,
                   'pinsage_recall': 8,
                   'youtubednn_recall': 2.6,
                   'youtubednn_usercf_recall': 3.9,
                   'cold_start_recall': 1}
    # Merge all channels into the final topk list for ranking.
    multi_recall_results = combine_recall_results(user_multi_recall_dict, weight_dict, topk=topk, save_results=False)
    logger.info(f"多路召回总计耗时: {_elapsed_ms(start_time)} 毫秒")

    calc_multi_recall_ration(channel_results['youtubednn_recall'],
                             channel_results['itemcf_recall'],
                             channel_results['item_embedding_recall'],
                             channel_results['youtubednn_usercf_recall'],
                             cold_start_recall_dict,
                             channel_results['dssm_recall'],
                             channel_results['dssm_usercf_recall'],
                             channel_results['fm_recall'],
                             channel_results['mind_recall0'],
                             channel_results['mind_recall1'],
                             channel_results['pinsage_recall'],
                             multi_recall_results)
    return multi_recall_results, False
if __name__ == '__main__':
    # Manual smoke run for a single sample user; prints the raw
    # (results, only_cold_start_recall) tuple returned by multi_recall_predict.
    recall_output = multi_recall_predict(user_id=192262, topk=200)
    print(recall_output)