You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
142 lines
6.0 KiB
142 lines
6.0 KiB
from tqdm import tqdm
|
|
import os
|
|
import sys
|
|
sys.path.append(os.getcwd())
|
|
import warnings
|
|
import pickle
|
|
from config import logger
|
|
from config import shixun_item_embedding_recall_dict
|
|
from config import shixun_itemcf_recall_dict
|
|
from config import shixun_dssm_recall_dict
|
|
from config import shixun_dssm_usercf_recall_dict
|
|
from config import shixun_fm_recall_dict
|
|
from config import shixun_mind_recall0_dict
|
|
from config import shixun_mind_recall1_dict
|
|
from config import shixun_pinsage_recall_dict
|
|
from config import shixun_youtubednn_recall_dict
|
|
from config import shixun_youtubednn_usercf_recall_dict
|
|
from config import shixun_final_recall_items_dict
|
|
|
|
# Enable tqdm's pandas integration; nothing in the visible part of this file calls it directly.
tqdm.pandas()
|
|
# Suppress every warning category for the rest of the process.
warnings.filterwarnings('ignore')
|
|
|
|
|
|
def _norm_user_recall_items_sim(sorted_item_list):
    """Min-max normalise the scores of one user's recall list.

    The first entry is used as the maximum score and the last entry as the
    minimum, so the list is expected to be ordered by score, best first.

    Args:
        sorted_item_list: list of ``(item, score)`` pairs.

    Returns:
        A new list of ``(item, normalised_score)`` pairs in the same order.
        Lists with fewer than two entries are returned unchanged (same
        object, scores not rescaled).
    """
    # Nothing to scale against with zero or one candidate.
    if len(sorted_item_list) < 2:
        return sorted_item_list

    max_sim = sorted_item_list[0][1]
    min_sim = sorted_item_list[-1][1]

    if max_sim > 0:
        if max_sim > min_sim:
            span = max_sim - min_sim
            return [(item, 1.0 * (score - min_sim) / span)
                    for item, score in sorted_item_list]
        # All scores tie (or the ends are not in descending order):
        # every item gets full credit.
        return [(item, 1.0) for item, _ in sorted_item_list]

    # Best score is not positive: this list contributes nothing.
    return [(item, 0.0) for item, _ in sorted_item_list]


def combine_recall_results(user_multi_recall_dict, weight_dict=None, topk=100, save_results=True):
    """Merge several recall channels into one ranked candidate list per user.

    For every channel, each user's scores are min-max normalised, multiplied
    by the channel weight and added up per item across channels. Each user's
    items are then sorted by the combined score (highest first) and cut to
    ``topk`` entries.

    Args:
        user_multi_recall_dict: mapping ``method -> {user_id: [(item, score), ...]}``.
            Each list is expected to be ordered by score, best first.
            The mapping and its lists are read only; they are not modified.
        weight_dict: optional mapping ``method -> weight``. ``None`` means
            every channel has weight 1. When given, it must contain every
            method present in ``user_multi_recall_dict``.
        topk: maximum number of items kept per user.
        save_results: when true, the merged result is pickled to the path
            held in ``shixun_final_recall_items_dict``.

    Returns:
        dict mapping ``user_id`` to a list of ``(item, combined_score)``
        pairs, highest score first, at most ``topk`` long.

    Raises:
        KeyError: if ``weight_dict`` is given but lacks one of the methods.
    """
    final_recall_items_dict = {}

    logger.info('开始多路召回合并...')

    for method, user_recall_items in tqdm(user_multi_recall_dict.items()):
        # Each channel may carry its own weight in the final score.
        recall_method_weight = 1 if weight_dict is None else weight_dict[method]

        for user_id, sorted_item_list in user_recall_items.items():
            user_scores = final_recall_items_dict.setdefault(user_id, {})

            # Normalise per user and per channel so that scores coming from
            # different channels are comparable before they are summed.
            # The normalised list is used locally only; the caller's data
            # is left untouched.
            for item, score in _norm_user_recall_items_sim(sorted_item_list):
                user_scores[item] = user_scores.get(item, 0) + recall_method_weight * score

    # Rank every user's merged candidates and cap the list length.
    final_recall_items_dict_rank = {}
    for user, recall_item_dict in final_recall_items_dict.items():
        ranked = sorted(recall_item_dict.items(), key=lambda pair: pair[1], reverse=True)
        final_recall_items_dict_rank[user] = ranked[:topk]

    if save_results:
        logger.info('保存多路召回合并的结果')
        # Context manager guarantees the file is flushed and closed.
        with open(shixun_final_recall_items_dict, 'wb') as f:
            pickle.dump(final_recall_items_dict_rank, f)

    return final_recall_items_dict_rank
|
|
|
|
|
|
if __name__ == '__main__':

    # One entry per recall channel: (key in the merged dict, log message,
    # path of the pickled recall result). The order here is the order in
    # which channels are loaded and later merged.
    recall_sources = [
        ('itemcf_recall', '加载itemcf召回结果', shixun_itemcf_recall_dict),
        ('item_embedding_recall', '加载item embedding召回结果', shixun_item_embedding_recall_dict),
        ('dssm_recall', '加载dssm召回结果', shixun_dssm_recall_dict),
        ('dssm_usercf_recall', '加载dssm usercf召回结果', shixun_dssm_usercf_recall_dict),
        ('fm_recall', '加载fm召回结果', shixun_fm_recall_dict),
        # The original logged the same "interest 1" message for both MIND
        # channels; the first one now says 0 to match its key.
        ('mind_recall0', '加载mind 用户兴趣0召回结果', shixun_mind_recall0_dict),
        ('mind_recall1', '加载mind 用户兴趣1召回结果', shixun_mind_recall1_dict),
        # Previously this result was stored under 'youtubednn_usercf_recall'
        # and then overwritten, so the PinSage channel was always empty.
        ('pinsage_recall', '加载pinsage召回结果', shixun_pinsage_recall_dict),
        ('youtubednn_recall', '加载YoutubeDNN召回结果', shixun_youtubednn_recall_dict),
        ('youtubednn_usercf_recall', '加载youtube usercf召回结果', shixun_youtubednn_usercf_recall_dict),
    ]

    # Holds the result of every recall channel, keyed by channel name.
    user_multi_recall_dict = {}
    for method, message, path in recall_sources:
        logger.info(message)
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # these paths are presumably produced by this project - confirm.
        with open(path, 'rb') as f:
            user_multi_recall_dict[method] = pickle.load(f)

    # Per-channel weights; tune them according to how each channel performed.
    weight_dict = {'itemcf_recall': 0.7,
                   'item_embedding_recall': 10,
                   'dssm_recall': 1.2,
                   'dssm_usercf_recall': 12,
                   'fm_recall': 1,
                   'mind_recall0': 0.8,
                   'mind_recall1': 0.7,
                   'pinsage_recall': 60,
                   'youtubednn_recall': 1,
                   'youtubednn_usercf_recall': 12}

    # After merging, keep 300 candidates per user for the ranking stage.
    final_recall_items_dict_rank = combine_recall_results(user_multi_recall_dict, weight_dict, topk=300)