import os
import pickle
import sys
import warnings

from tqdm import tqdm

# The project-local `config` module is imported below; the current working
# directory is added to the import path first so that import can resolve
# when the script is launched from the directory that contains it.
sys.path.append(os.getcwd())

from config import logger
from config import subject_item_embedding_recall_dict
from config import subject_itemcf_recall_dict
from config import subject_dssm_recall_dict
from config import subject_dssm_usercf_recall_dict
from config import subject_fm_recall_dict
from config import subject_mind_recall0_dict
from config import subject_mind_recall1_dict
from config import subject_pinsage_recall_dict
from config import subject_youtubednn_recall_dict
from config import subject_youtubednn_usercf_recall_dict
from config import subject_final_recall_items_dict

tqdm.pandas()
warnings.filterwarnings('ignore')


def _norm_user_recall_items_sim(sorted_item_list):
    """Min-max normalize the scores of one user's recall list.

    Normalizing each recall method per user puts the scores of different
    methods on a comparable scale before they are weighted and summed.

    Args:
        sorted_item_list: list of ``(item, score)`` pairs. The first element
            is taken as the maximum score and the last as the minimum, so the
            list is expected to be sorted by score in descending order.

    Returns:
        A new list of ``(item, normalized_score)`` pairs. Lists with fewer
        than two entries are returned unchanged (scores are NOT rescaled).
        If the top score is not positive, every score becomes ``0.0``; if all
        scores are equal (and positive), every score becomes ``1.0``.
    """
    if len(sorted_item_list) < 2:
        return sorted_item_list

    min_sim = sorted_item_list[-1][1]
    max_sim = sorted_item_list[0][1]

    norm_sorted_item_list = []
    for item, score in sorted_item_list:
        if max_sim > 0:
            # Guard against division by zero when all scores are identical.
            if max_sim > min_sim:
                norm_score = 1.0 * (score - min_sim) / (max_sim - min_sim)
            else:
                norm_score = 1.0
        else:
            norm_score = 0.0
        norm_sorted_item_list.append((item, norm_score))
    return norm_sorted_item_list


def combine_recall_results(user_multi_recall_dict, weight_dict=None, topk=100, save_results=True):
    """Merge the item lists produced by several recall strategies.

    For every recall method, each user's scores are min-max normalized,
    multiplied by the method's weight and accumulated per ``(user, item)``.
    The accumulated scores are then sorted in descending order and truncated
    to ``topk`` items per user.

    The input dictionaries are not modified.

    Args:
        user_multi_recall_dict: mapping ``method -> {user_id: [(item, score), ...]}``.
        weight_dict: optional mapping ``method -> weight``. When ``None``,
            every method gets weight 1. When given, it must contain a key for
            every method in ``user_multi_recall_dict`` (``KeyError`` otherwise).
        topk: maximum number of items kept per user in the merged result.
        save_results: when true, the merged result is pickled to the path
            ``subject_final_recall_items_dict`` from ``config``.

    Returns:
        Mapping ``user_id -> [(item, merged_score), ...]`` sorted by merged
        score in descending order and limited to ``topk`` entries.
    """
    final_recall_items_dict = {}

    logger.info('开始多路召回合并...')
    for method, user_recall_items in tqdm(user_multi_recall_dict.items()):
        # Each recall method may carry its own weight in the final score.
        recall_method_weight = 1 if weight_dict is None else weight_dict[method]

        for user_id, sorted_item_list in user_recall_items.items():
            # Register the user even if this method recalled nothing for them.
            user_scores = final_recall_items_dict.setdefault(user_id, {})
            # Normalize on the fly instead of writing the normalized list back
            # into the caller's dictionary.
            for item, score in _norm_user_recall_items_sim(sorted_item_list):
                user_scores[item] = user_scores.get(item, 0) + recall_method_weight * score

    # Cap the number of merged candidates per user.
    final_recall_items_dict_rank = {
        user: sorted(recall_item_dict.items(), key=lambda x: x[1], reverse=True)[:topk]
        for user, recall_item_dict in final_recall_items_dict.items()
    }

    if save_results:
        logger.info('保存多路召回合并的结果')
        # Context manager ensures the file is flushed and closed.
        with open(subject_final_recall_items_dict, 'wb') as f:
            pickle.dump(final_recall_items_dict_rank, f)

    return final_recall_items_dict_rank


def _load_pickle(path):
    """Load and return the pickled object stored at ``path``.

    NOTE: unpickling can execute arbitrary code; only use this on files
    produced by trusted steps of this pipeline.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    # (method name, log message, pickle path) for every recall strategy.
    # Keeping the method name and its path side by side prevents a result
    # from being stored under the wrong key (previously the PinSage result
    # was written to 'youtubednn_usercf_recall' and then overwritten, so
    # 'pinsage_recall' stayed empty in the merge).
    # NOTE(review): the log text for 'mind_recall0' and 'mind_recall1' is
    # identical; left unchanged here, confirm the intended wording.
    recall_sources = [
        ('itemcf_recall', '加载itemcf召回结果', subject_itemcf_recall_dict),
        ('item_embedding_recall', '加载item embedding召回结果', subject_item_embedding_recall_dict),
        ('dssm_recall', '加载dssm召回结果', subject_dssm_recall_dict),
        ('dssm_usercf_recall', '加载dssm usercf召回结果', subject_dssm_usercf_recall_dict),
        ('fm_recall', '加载fm召回结果', subject_fm_recall_dict),
        ('mind_recall0', '加载mind 用户兴趣1召回结果', subject_mind_recall0_dict),
        ('mind_recall1', '加载mind 用户兴趣1召回结果', subject_mind_recall1_dict),
        ('pinsage_recall', '加载pinsage召回结果', subject_pinsage_recall_dict),
        ('youtubednn_recall', '加载YoutubeDNN召回结果', subject_youtubednn_recall_dict),
        ('youtubednn_usercf_recall', '加载youtube usercf召回结果', subject_youtubednn_usercf_recall_dict),
    ]

    # All recall results are collected in this dictionary, keyed by method.
    user_multi_recall_dict = {}
    for method, message, path in recall_sources:
        logger.info(message)
        user_multi_recall_dict[method] = _load_pickle(path)

    # Per-method weights; tune these according to how each recall performed.
    weight_dict = {
        'itemcf_recall': 0.7,
        'item_embedding_recall': 1.5,
        'dssm_recall': 2.8,
        'dssm_usercf_recall': 3.9,
        'fm_recall': 2.1,
        'mind_recall0': 1.9,
        'mind_recall1': 1.9,
        'pinsage_recall': 8,
        'youtubednn_recall': 2.6,
        'youtubednn_usercf_recall': 3.9,
    }

    # After merging, keep 300 recalled items per user for the ranking stage.
    final_recall_items_dict_rank = combine_recall_results(user_multi_recall_dict, weight_dict, topk=300)