You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.8 KiB
81 lines
2.8 KiB
5 months ago
|
import os
|
||
|
import sys
|
||
|
sys.path.append(os.getcwd())
|
||
|
import numpy as np
|
||
|
import faiss
|
||
|
import pickle
|
||
|
from datetime import datetime
|
||
|
from config import logger
|
||
|
from config import embedding_dim
|
||
|
from config import test_user_id
|
||
|
from config import shixun_mind_user_emb0_dict
|
||
|
from config import shixun_mind_user_embedding0_data
|
||
|
from config import shixun_mind_user0_faiss_model_path
|
||
|
from config import shixun_mind_user_embedding0_index_dict
|
||
|
from config import shixun_mind_item_emb_dict
|
||
|
from config import shixun_mind_item_embedding_data
|
||
|
from config import shixun_mind_item_faiss_model_path
|
||
|
from config import shixun_mind_item_embedding_index_dict
|
||
|
|
||
|
def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE(review): pickle can execute arbitrary code while loading. This is
    # only safe as long as the configured paths point at trusted artifacts.
    with open(path, 'rb') as f:
        return pickle.load(f)


# Load the pickled MIND recall data once at import time so that every call to
# the recall function reuses the same in-memory objects.
logger.info('加载MIND召回数据')
mind_user_emb_dict = _load_pickle(shixun_mind_user_emb0_dict)
mind_user_embedding_data = _load_pickle(shixun_mind_user_embedding0_data)
mind_user_embedding_index_dict = _load_pickle(shixun_mind_user_embedding0_index_dict)

mind_item_emb_dict = _load_pickle(shixun_mind_item_emb_dict)
mind_item_embedding_data = _load_pickle(shixun_mind_item_embedding_data)
mind_item_embedding_index_dict = _load_pickle(shixun_mind_item_embedding_index_dict)

# Load the prebuilt faiss indexes from the paths given in config.
logger.info('加载MIND召回模型')
mind_user_faiss_model = faiss.read_index(shixun_mind_user0_faiss_model_path)
mind_item_faiss_model = faiss.read_index(shixun_mind_item_faiss_model_path)
|
||
|
|
||
|
|
||
|
def mind0_recall(user_id, topk=20):
    """
    Recall the top-k nearest entries of the MIND item faiss index for a user.

    The embedding stored for ``user_id`` in ``mind_user_emb_dict`` is reshaped
    to rows of ``embedding_dim`` and searched against
    ``mind_item_faiss_model``. Hit positions are mapped back to ids through
    ``mind_item_embedding_index_dict``.

    NOTE(review): the original docstring described this as "recommend study
    partners via the similarity of the mind user0 embedding", yet the code
    searches the *item* index and item index dict — confirm which is intended.

    Args:
        user_id: Key looked up in ``mind_user_emb_dict``.
        topk: Maximum number of results to return.

    Returns:
        A list of ``(id, score)`` tuples sorted by score in descending order
        and truncated to ``topk``. When ``user_id`` has no stored embedding an
        empty dict is returned instead (kept as-is for existing callers).
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行MIND召回的用户ID: {user_id}")

    if user_id not in mind_user_emb_dict:
        return {}

    # Fetch the user's embedding and force it into 2-D rows of embedding_dim,
    # which is the layout passed to the faiss search below.
    user_embs = mind_user_emb_dict[user_id].reshape(-1, embedding_dim)

    # Request one extra neighbour so that dropping an entry equal to user_id
    # can still leave topk results.
    D, I = mind_item_faiss_model.search(np.ascontiguousarray(user_embs), topk + 1)

    # Map each hit position back to its id, keeping the score alongside.
    # If the same id is hit more than once, the later score overwrites the
    # earlier one (unchanged from the original behaviour).
    scores = {}
    for index, sim in zip(I.ravel(), D.ravel()):
        # faiss pads the label array with -1 when it finds fewer than the
        # requested number of neighbours; those slots carry no result.
        if index < 0:
            continue

        cur_user_id = mind_item_embedding_index_dict[index]

        # Skip the entry whose id equals the querying user's id.
        if cur_user_id != user_id:
            scores[cur_user_id] = sim

    top_k_item = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:topk]

    # Elapsed time in milliseconds. Both timestamps come from the same clock,
    # and total_seconds() covers the whole duration; the timedelta attribute
    # `.microseconds` is only its sub-second component.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)

    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_item
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Manual smoke run: recall for the configured test user and print the result.
    print(mind0_recall(user_id=test_user_id, topk=20))
|