You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

142 lines
6.0 KiB

5 months ago
from tqdm import tqdm
import os
import sys
sys.path.append(os.getcwd())
import warnings
import pickle
from config import logger
from config import shixun_item_embedding_recall_dict
from config import shixun_itemcf_recall_dict
from config import shixun_dssm_recall_dict
from config import shixun_dssm_usercf_recall_dict
from config import shixun_fm_recall_dict
from config import shixun_mind_recall0_dict
from config import shixun_mind_recall1_dict
from config import shixun_pinsage_recall_dict
from config import shixun_youtubednn_recall_dict
from config import shixun_youtubednn_usercf_recall_dict
from config import shixun_final_recall_items_dict
# Register tqdm's pandas integration (progress-bar variants of pandas methods).
# NOTE(review): nothing visible in this file uses pandas; this may be a leftover.
tqdm.pandas()
# Install a process-wide "ignore" filter: every Python warning is suppressed from here on.
warnings.filterwarnings('ignore')
def combine_recall_results(user_multi_recall_dict, weight_dict=None, topk=100, save_results=True):
    """
    Merge the item lists produced by several recall channels into one ranked list per user.

    Each channel's per-user scores are min-max normalized first, so channels whose raw
    scores live on different scales can be added together. The normalized score is then
    multiplied by the channel weight and accumulated per (user, item).

    Args:
        user_multi_recall_dict: mapping ``method name -> {user_id: [(item, score), ...]}``.
            The input is not modified.
        weight_dict: optional mapping ``method name -> weight``. When None, every channel
            has weight 1. When given, it must contain every method present in
            ``user_multi_recall_dict``; a missing method raises KeyError.
        topk: maximum number of items kept per user after merging.
        save_results: when True, pickle the merged result to
            ``shixun_final_recall_items_dict``.

    Returns:
        dict ``user_id -> [(item, merged_score), ...]`` sorted by merged score in
        descending order and truncated to ``topk`` items.
    """

    def norm_user_recall_items_sim(item_score_list):
        # Min-max normalize one user's scores into [0, 1].
        # min/max are computed explicitly instead of being read from the list ends,
        # so the input does not need to be sorted by score.
        if not item_score_list:
            return []
        scores = [score for _, score in item_score_list]
        min_sim = min(scores)
        max_sim = max(scores)
        norm_item_score_list = []
        for item, score in item_score_list:
            if max_sim <= 0:
                # A list whose best score is non-positive contributes nothing.
                norm_score = 0.0
            elif max_sim > min_sim:
                norm_score = 1.0 * (score - min_sim) / (max_sim - min_sim)
            else:
                # All scores equal (this includes a single-item list): treat as top score,
                # so single-item users stay on the same [0, 1] scale as everyone else.
                norm_score = 1.0
            norm_item_score_list.append((item, norm_score))
        return norm_item_score_list

    logger.info('开始多路召回合并...')
    final_recall_items_dict = {}
    for method, user_recall_items in tqdm(user_multi_recall_dict.items()):
        # Each recall channel may be given its own weight in the final sum.
        recall_method_weight = 1 if weight_dict is None else weight_dict[method]
        for user_id, item_score_list in user_recall_items.items():
            # Normalize into a fresh list rather than writing back into the caller's dict.
            user_scores = final_recall_items_dict.setdefault(user_id, {})
            for item, score in norm_user_recall_items_sim(item_score_list):
                user_scores[item] = user_scores.get(item, 0) + recall_method_weight * score

    # Rank each user's merged items and keep only the top `topk`.
    final_recall_items_dict_rank = {
        user: sorted(recall_item_dict.items(), key=lambda x: x[1], reverse=True)[:topk]
        for user, recall_item_dict in final_recall_items_dict.items()
    }

    if save_results:
        logger.info('保存多路召回合并的结果')
        with open(shixun_final_recall_items_dict, 'wb') as f:
            pickle.dump(final_recall_items_dict_rank, f)
    return final_recall_items_dict_rank
if __name__ == '__main__':
    # One entry per recall channel: (channel key, pickle file written by that channel,
    # log message). The list order is also the order in which channels are merged.
    # NOTE: pickle.load executes arbitrary code from the file; only load artifacts
    # produced by this pipeline.
    recall_sources = [
        ('itemcf_recall', shixun_itemcf_recall_dict, '加载itemcf召回结果'),
        ('item_embedding_recall', shixun_item_embedding_recall_dict, '加载item embedding召回结果'),
        ('dssm_recall', shixun_dssm_recall_dict, '加载dssm召回结果'),
        ('dssm_usercf_recall', shixun_dssm_usercf_recall_dict, '加载dssm usercf召回结果'),
        ('fm_recall', shixun_fm_recall_dict, '加载fm召回结果'),
        ('mind_recall0', shixun_mind_recall0_dict, '加载mind 用户兴趣0召回结果'),
        ('mind_recall1', shixun_mind_recall1_dict, '加载mind 用户兴趣1召回结果'),
        # Previously loaded into 'youtubednn_usercf_recall' and then overwritten,
        # so pinsage results were silently dropped from the merge.
        ('pinsage_recall', shixun_pinsage_recall_dict, '加载pinsage召回结果'),
        ('youtubednn_recall', shixun_youtubednn_recall_dict, '加载YoutubeDNN召回结果'),
        ('youtubednn_usercf_recall', shixun_youtubednn_usercf_recall_dict, '加载youtube usercf召回结果'),
    ]

    # Collect every channel's recall result: channel key -> {user_id: [(item, score), ...]}.
    user_multi_recall_dict = {}
    for method, path, message in recall_sources:
        logger.info(message)
        with open(path, 'rb') as f:
            user_multi_recall_dict[method] = pickle.load(f)

    # Per-channel weights, tuned from how well each recall channel performed earlier.
    weight_dict = {'itemcf_recall': 0.7,
                   'item_embedding_recall': 10,
                   'dssm_recall': 1.2,
                   'dssm_usercf_recall': 12,
                   'fm_recall': 1,
                   'mind_recall0': 0.8,
                   'mind_recall1': 0.7,
                   'pinsage_recall': 60,
                   'youtubednn_recall': 1,
                   'youtubednn_usercf_recall': 12}

    # After merging, keep 300 recalled items per user for the ranking stage.
    final_recall_items_dict_rank = combine_recall_results(user_multi_recall_dict, weight_dict, topk=300)