You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

81 lines
2.8 KiB

import os
import sys
sys.path.append(os.getcwd())
import numpy as np
import faiss
import pickle
from datetime import datetime
from config import logger
from config import embedding_dim
from config import test_user_id
from config import subject_mind_user_emb0_dict
from config import subject_mind_user_embedding0_data
from config import subject_mind_user0_faiss_model_path
from config import subject_mind_user_embedding0_index_dict
from config import subject_mind_item_emb_dict
from config import subject_mind_item_embedding_data
from config import subject_mind_item_faiss_model_path
from config import subject_mind_item_embedding_index_dict
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at ``path``.

    The file handle is closed deterministically via ``with`` (the previous
    ``pickle.load(open(...))`` pattern never closed it).

    NOTE(review): ``pickle.load`` can execute arbitrary code. These paths come
    from ``config`` and are presumably trusted build artefacts — confirm they
    are never user-supplied.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


logger.info('加载MIND召回数据')
# Per-user embeddings, embedding matrices and faiss-row -> ID lookup tables.
mind_user_emb_dict = _load_pickle(subject_mind_user_emb0_dict)
mind_user_embedding_data = _load_pickle(subject_mind_user_embedding0_data)
mind_user_embedding_index_dict = _load_pickle(subject_mind_user_embedding0_index_dict)
mind_item_emb_dict = _load_pickle(subject_mind_item_emb_dict)
mind_item_embedding_data = _load_pickle(subject_mind_item_embedding_data)
mind_item_embedding_index_dict = _load_pickle(subject_mind_item_embedding_index_dict)
logger.info('加载MIND召回模型')
# Prebuilt faiss indices read from disk at import time.
mind_user_faiss_model = faiss.read_index(subject_mind_user0_faiss_model_path)
mind_item_faiss_model = faiss.read_index(subject_mind_item_faiss_model_path)
def mind0_recall(user_id, topk=20):
    """Recall the top-``topk`` candidates for ``user_id`` from the MIND item index.

    The stated intent (from the original docstring) is recommending study
    partners by MIND user0 embedding similarity. The user's embedding row(s)
    from ``mind_user_emb_dict`` are searched against ``mind_item_faiss_model``.
    The resulting faiss row indices are mapped back to IDs via
    ``mind_item_embedding_index_dict``, and ``user_id`` itself is excluded.

    Args:
        user_id: Key into ``mind_user_emb_dict``.
        topk: Maximum number of results to return.

    Returns:
        A list of ``(candidate_id, score)`` tuples sorted by score, highest
        first, with at most ``topk`` entries. Returns an empty list when
        ``user_id`` has no embedding. (It used to return an empty dict, which
        mismatched the normal return type.)
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行MIND召回的用户ID: {user_id}")
    if user_id not in mind_user_emb_dict:
        return []
    # Reshape to 2-D: one row per stored embedding. MIND presumably stores
    # several interest vectors per user — confirm. faiss expects a contiguous
    # float32 matrix.
    user_embs = np.ascontiguousarray(
        mind_user_emb_dict[user_id].reshape(-1, embedding_dim), dtype=np.float32
    )
    # Request one extra neighbour so that dropping user_id itself can still
    # leave topk results.
    sims, indices = mind_item_faiss_model.search(user_embs, topk + 1)
    best = {}
    for index, sim in zip(indices.ravel(), sims.ravel()):
        # faiss pads with label -1 when fewer than k neighbours exist.
        if index < 0:
            continue
        cur_id = mind_item_embedding_index_dict[index]
        if cur_id == user_id:
            continue
        # With several query rows the same ID may be hit more than once. Keep
        # its highest score, matching the descending sort below.
        if cur_id not in best or sim > best[cur_id]:
            best[cur_id] = sim
    top_k_item = sorted(best.items(), key=lambda x: x[1], reverse=True)[:topk]
    # Use the same clock at both ends and the whole timedelta (not just its
    # microseconds field) so the reported duration is correct.
    end_time = datetime.now()
    cost_time_millisecond = round((end_time - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_item
if __name__ == '__main__':
    # Smoke run: recall for the configured test user and print the ranked result.
    print(mind0_recall(user_id=test_user_id, topk=20))