You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
10 KiB
210 lines
10 KiB
5 months ago
|
import os
|
||
|
import sys
|
||
|
sys.path.append(os.getcwd())
|
||
|
from datetime import datetime
|
||
|
from config import logger
|
||
|
from matching.subject.itemcf_recall import init_itemcf_recall,itemcf_recall
|
||
|
from matching.subject.item_embedding_recall import init_item_embedding_recall,item_embedding_recall
|
||
|
from matching.subject.dssm_recall_predict import dssm_recall
|
||
|
from matching.subject.dssm_usercf_recall import init_dssm_usercf_recall,dssm_usercf_recall
|
||
|
from matching.subject.fm_recall_predict import fm_recall
|
||
|
from matching.subject.mind_user0_recall_predict import mind0_recall
|
||
|
from matching.subject.mind_user1_recall_predict import mind1_recall
|
||
|
from matching.subject.pinsage_recall_predict import pinsage_recall
|
||
|
from matching.subject.youtubednn_recall_predict import youtubednn_recall
|
||
|
from matching.subject.youtube_usercf_recall import init_youtube_usercf_recall,youtube_usercf_recall
|
||
|
from matching.subject.multi_recall_combine import combine_recall_results
|
||
|
from matching.subject.cold_start_recall import cold_start_user_recall
|
||
|
|
||
|
# Run the one-off initialisation hook of each recall channel that exposes one.
# This happens once, at import time, in the order listed below.
for _init_recall in (init_itemcf_recall,
                     init_item_embedding_recall,
                     init_dssm_usercf_recall,
                     init_youtube_usercf_recall):
    _init_recall()
del _init_recall
|
||
|
|
||
|
def _first_user_item_ids(recall_dict):
    """Return the item ids stored under the first key of *recall_dict*.

    Each entry of the stored sequence is indexed with ``[0]`` to get the item
    id, e.g. ``(item_id, score)`` pairs. An empty dict yields ``[]``.
    """
    # The callers build these dicts as {user_id: [...]}, i.e. a single key.
    entries = next(iter(recall_dict.values()), [])
    return [entry[0] for entry in entries]


def calc_multi_recall_ration(
        youtubednn_recall_dict,
        itemcf_recall_dict,
        item_embedding_recall_dict,
        youtube_usercf_recall_dict,
        cold_start_recall_dict,
        dssm_recall_dict,
        dssm_usercf_recall_dict,
        fm_recall_dict,
        mind0_recall_dict,
        mind1_recall_dict,
        pinsage_recall_dict,
        multi_recall_results):
    """Print how many items each recall channel contributed to the merged result.

    For every channel, the count is the number of *distinct* item ids the
    channel recalled that also appear in the merged list, i.e. the size of the
    intersection. The first printed number is the raw length of the merged
    list (duplicates included).

    Every argument is a dict whose first value is a sequence of entries with
    the item id at index 0. Empty dicts are treated as "no items".

    Returns:
        None. The statistics are written to stdout with ``print``.
    """
    multi_recall_item_list = _first_user_item_ids(multi_recall_results)
    # Build the lookup set once instead of once per channel.
    multi_recall_item_set = set(multi_recall_item_list)

    def _hits(recall_dict):
        # Distinct items of this channel that survived into the merged list.
        return len(set(_first_user_item_ids(recall_dict)) & multi_recall_item_set)

    print("多路召回数量:", len(multi_recall_item_list),
          "YoutubeDNN召回:", _hits(youtubednn_recall_dict),
          "item embedding召回:", _hits(item_embedding_recall_dict),
          "youtube usercf召回:", _hits(youtube_usercf_recall_dict),
          "itemcf召回:", _hits(itemcf_recall_dict),
          "冷启动召回:", _hits(cold_start_recall_dict),
          "dssm召回:", _hits(dssm_recall_dict),
          "dssm_usercf召回:", _hits(dssm_usercf_recall_dict),
          "fm召回:", _hits(fm_recall_dict),
          "mind0召回:", _hits(mind0_recall_dict),
          "mind1召回:", _hits(mind1_recall_dict),
          "pinsage召回:", _hits(pinsage_recall_dict),
          )
|
||
|
|
||
|
|
||
|
def multi_recall_predict(user_id, disciplines_id_list=None, topk=200):
    """Run every recall channel for one user and merge the results.

    Each personalized channel is asked for ``topk // 4`` candidates.

    If every personalized channel returns an empty result, the output of
    ``cold_start_user_recall(user_id, disciplines_id_list, topk)`` is returned
    unchanged.

    Otherwise, ``topk // 10`` cold-start candidates are added as one more
    channel, with a flat score of 0.5 each. All channels are then merged by
    ``combine_recall_results`` using the per-channel weights below, and
    per-channel statistics are printed via ``calc_multi_recall_ration``.

    Args:
        user_id: id of the user to recall for. It is used as the key of every
            per-channel result dict.
        disciplines_id_list: forwarded to ``cold_start_user_recall``.
        topk: passed as ``topk`` to ``combine_recall_results``; also sets the
            per-channel budgets described above.

    Returns:
        ``(multi_recall_results, only_cold_start_recall)``. The flag is True
        when only the cold-start recall was used.
    """
    start_time = datetime.now()

    def log_elapsed():
        # Both timestamps use datetime.now(), so no UTC offset leaks into the
        # difference. total_seconds() is used because timedelta.microseconds
        # is only the sub-second component.
        elapsed = datetime.now() - start_time
        cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)
        logger.info(f"多路召回总计耗时: {cost_time_millisecond} 毫秒")

    # (result key, log message, recall function) for each personalized channel.
    # The order here fixes both the call order and the insertion order of
    # user_multi_recall_dict.
    recall_channels = (
        ('itemcf_recall', '开始itemcf召回...', itemcf_recall),
        ('item_embedding_recall', '开始item embedding召回...', item_embedding_recall),
        ('dssm_recall', '开始dssm召回...', dssm_recall),
        ('dssm_usercf_recall', '开始dssm user embedding召回...', dssm_usercf_recall),
        ('fm_recall', '开始fm召回...', fm_recall),
        ('mind_recall0', '开始mind0_recall召回...', mind0_recall),
        ('mind_recall1', '开始mind1_recall召回...', mind1_recall),
        ('pinsage_recall', '开始pinsage_recall召回...', pinsage_recall),
        ('youtubednn_recall', '开始YoutubeDNN召回...', youtubednn_recall),
        ('youtubednn_usercf_recall', '开始youtube usercf召回...', youtube_usercf_recall),
    )

    # All per-channel results, each stored as {user_id: recall_result}.
    user_multi_recall_dict = {}
    for channel_name, log_message, recall_fn in recall_channels:
        logger.info(log_message)
        user_multi_recall_dict[channel_name] = {user_id: recall_fn(user_id, topk // 4)}

    logger.info('开始冷启动召回...')
    # Fall back to cold-start-only solely when *every* personalized channel
    # came back empty, so no channel's non-empty results are ever discarded.
    all_personalized_empty = all(
        len(channel_result[user_id]) == 0
        for channel_result in user_multi_recall_dict.values()
    )
    if all_personalized_empty:
        cold_start_recall_dict = cold_start_user_recall(user_id, disciplines_id_list, topk)
        log_elapsed()
        return cold_start_recall_dict, True

    cold_start_items = cold_start_user_recall(user_id, disciplines_id_list, topk // 10)
    # Only the keys (item ids) of the cold-start result are used; each gets a
    # flat score of 0.5. `.keys()` is kept deliberately in case the result is
    # a mapping type whose plain iteration does not yield keys.
    cold_start_recall_dict = {
        user_id: [(item_id, 0.5) for item_id in cold_start_items.keys()]
    }
    user_multi_recall_dict['cold_start_recall'] = cold_start_recall_dict

    # Each channel gets its own weight; the values are hand-tuned from how the
    # channels performed in earlier recall runs.
    weight_dict = {'itemcf_recall': 0.7,
                   'item_embedding_recall': 1.5,
                   'dssm_recall': 2.8,
                   'dssm_usercf_recall': 3.9,
                   'fm_recall': 2.1,
                   'mind_recall0': 1.9,
                   'mind_recall1': 1.9,
                   'pinsage_recall': 8,
                   'youtubednn_recall': 2.6,
                   'youtubednn_usercf_recall': 3.9,
                   'cold_start_recall': 1}

    # Merge all channels, keeping topk items per user for the ranking stage.
    multi_recall_results = combine_recall_results(user_multi_recall_dict, weight_dict,
                                                  topk=topk, save_results=False)
    log_elapsed()

    calc_multi_recall_ration(user_multi_recall_dict['youtubednn_recall'],
                             user_multi_recall_dict['itemcf_recall'],
                             user_multi_recall_dict['item_embedding_recall'],
                             user_multi_recall_dict['youtubednn_usercf_recall'],
                             cold_start_recall_dict,
                             user_multi_recall_dict['dssm_recall'],
                             user_multi_recall_dict['dssm_usercf_recall'],
                             user_multi_recall_dict['fm_recall'],
                             user_multi_recall_dict['mind_recall0'],
                             user_multi_recall_dict['mind_recall1'],
                             user_multi_recall_dict['pinsage_recall'],
                             multi_recall_results)

    return multi_recall_results, False
|
||
|
|
||
|
if __name__ == '__main__':
    # Manual smoke run for one hard-coded user. It prints the tuple returned
    # by multi_recall_predict: (results, only_cold_start_recall).
    prediction = multi_recall_predict(user_id=192262, topk=200)
    print(prediction)
|