You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

142 lines
6.0 KiB

from tqdm import tqdm
import os
import sys
sys.path.append(os.getcwd())
import warnings
import pickle
from config import logger
from config import shixun_item_embedding_recall_dict
from config import shixun_itemcf_recall_dict
from config import shixun_dssm_recall_dict
from config import shixun_dssm_usercf_recall_dict
from config import shixun_fm_recall_dict
from config import shixun_mind_recall0_dict
from config import shixun_mind_recall1_dict
from config import shixun_pinsage_recall_dict
from config import shixun_youtubednn_recall_dict
from config import shixun_youtubednn_usercf_recall_dict
from config import shixun_final_recall_items_dict
# Register tqdm's pandas integration (per tqdm's documentation this adds the
# `progress_apply` family of methods to pandas objects).
# NOTE(review): nothing in the visible code calls those methods — confirm whether
# this call is still needed here.
tqdm.pandas()
# Suppress every warning raised in this process.
warnings.filterwarnings('ignore')
def combine_recall_results(user_multi_recall_dict, weight_dict=None, topk=100, save_results=True):
    """Merge the item lists of several recall strategies into one ranked list per user.

    For every recall method, each user's scores are min-max normalised, multiplied
    by the method's weight and summed per item across methods. The summed scores
    are then sorted in descending order and truncated to ``topk`` entries.

    Args:
        user_multi_recall_dict: Mapping ``method name -> {user_id: [(item, score), ...]}``.
            Each per-user list is expected to be sorted by score in descending
            order, because the normaliser reads the first entry as the maximum
            and the last entry as the minimum. The mapping is not modified.
        weight_dict: Optional mapping ``method name -> weight``. When ``None``
            every method gets weight 1.
        topk: Maximum number of items kept per user in the result.
        save_results: When true, the merged result is pickled to the path held
            in ``shixun_final_recall_items_dict``.

    Returns:
        dict: ``{user_id: [(item, combined_score), ...]}`` sorted by combined
        score in descending order, at most ``topk`` entries per user.

    Raises:
        KeyError: If ``weight_dict`` is given but lacks an entry for one of the
            methods in ``user_multi_recall_dict``.
    """

    def norm_user_recall_items_sim(sorted_item_list):
        """Min-max normalise the scores of one user's recall list for one method.

        Lists with fewer than two entries are returned unchanged (their raw
        score is kept). If the top score is not positive every item gets 0.0;
        if all scores are equal every item gets 1.0.
        """
        if len(sorted_item_list) < 2:
            return sorted_item_list

        min_sim = sorted_item_list[-1][1]
        max_sim = sorted_item_list[0][1]

        if not max_sim > 0:
            return [(item, 0.0) for item, _ in sorted_item_list]
        if not max_sim > min_sim:
            # Degenerate range: avoid dividing by zero.
            return [(item, 1.0) for item, _ in sorted_item_list]

        span = max_sim - min_sim
        return [(item, 1.0 * (score - min_sim) / span) for item, score in sorted_item_list]

    final_recall_items_dict = {}

    logger.info('开始多路召回合并...')
    for method, user_recall_items in tqdm(user_multi_recall_dict.items()):
        # Each recall method may carry its own weight in the final score.
        recall_method_weight = 1 if weight_dict is None else weight_dict[method]

        for user_id, sorted_item_list in user_recall_items.items():
            # Normalise per user so that scores from different methods are
            # comparable before they are added together. The normalised list is
            # used locally only; the caller's dictionary is left untouched.
            user_scores = final_recall_items_dict.setdefault(user_id, {})
            for item, score in norm_user_recall_items_sim(sorted_item_list):
                user_scores[item] = user_scores.get(item, 0) + recall_method_weight * score

    # Rank every user's merged items and cap the number of recalled items.
    final_recall_items_dict_rank = {}
    for user, recall_item_dict in final_recall_items_dict.items():
        final_recall_items_dict_rank[user] = sorted(
            recall_item_dict.items(), key=lambda x: x[1], reverse=True
        )[:topk]

    if save_results:
        logger.info('保存多路召回合并的结果')
        # Use a context manager so the file is flushed and closed deterministically.
        with open(shixun_final_recall_items_dict, 'wb') as f:
            pickle.dump(final_recall_items_dict_rank, f)

    return final_recall_items_dict_rank
if __name__ == '__main__':
    def _load_recall(path):
        """Load one pickled recall result and close the file afterwards."""
        # NOTE: pickle executes arbitrary code on load; only use it on trusted,
        # locally generated recall files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    # (result key, log message, pickle path) for every recall route, in merge order.
    # NOTE(review): both MIND routes log the same "用户兴趣1" message — confirm the
    # intended wording for mind_recall0 / mind_recall1.
    recall_sources = [
        ('itemcf_recall', '加载itemcf召回结果', shixun_itemcf_recall_dict),
        ('item_embedding_recall', '加载item embedding召回结果', shixun_item_embedding_recall_dict),
        ('dssm_recall', '加载dssm召回结果', shixun_dssm_recall_dict),
        ('dssm_usercf_recall', '加载dssm usercf召回结果', shixun_dssm_usercf_recall_dict),
        ('fm_recall', '加载fm召回结果', shixun_fm_recall_dict),
        ('mind_recall0', '加载mind 用户兴趣1召回结果', shixun_mind_recall0_dict),
        ('mind_recall1', '加载mind 用户兴趣1召回结果', shixun_mind_recall1_dict),
        # The PinSage result is stored under its own key. It was previously
        # written to 'youtubednn_usercf_recall' and then overwritten, so PinSage
        # never took part in the merge.
        ('pinsage_recall', '加载pinsage召回结果', shixun_pinsage_recall_dict),
        ('youtubednn_recall', '加载YoutubeDNN召回结果', shixun_youtubednn_recall_dict),
        ('youtubednn_usercf_recall', '加载youtube usercf召回结果', shixun_youtubednn_usercf_recall_dict),
    ]

    # Dictionary holding the result of every recall route, keyed by route name.
    user_multi_recall_dict = {}
    for recall_name, log_message, recall_path in recall_sources:
        logger.info(log_message)
        user_multi_recall_dict[recall_name] = _load_recall(recall_path)

    # Per-route weights; tune them according to how each recall route performs.
    weight_dict = {
        'itemcf_recall': 0.7,
        'item_embedding_recall': 10,
        'dssm_recall': 1.2,
        'dssm_usercf_recall': 12,
        'fm_recall': 1,
        'mind_recall0': 0.8,
        'mind_recall1': 0.7,
        'pinsage_recall': 60,
        'youtubednn_recall': 1,
        'youtubednn_usercf_recall': 12,
    }

    # After merging, keep 300 recalled items per user for the ranking stage.
    final_recall_items_dict_rank = combine_recall_results(user_multi_recall_dict, weight_dict, topk=300)