You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

81 lines
2.8 KiB

5 months ago
import os
import sys
sys.path.append(os.getcwd())
import numpy as np
import faiss
import pickle
from datetime import datetime
from config import logger
from config import embedding_dim
from config import test_user_id
from config import subject_dssm_user_emb_dict
from config import subject_dssm_user_embedding_data
from config import subject_dssm_user_faiss_model_path
from config import subject_dssm_user_embedding_index_dict
from config import subject_dssm_item_emb_dict
from config import subject_dssm_item_embedding_data
from config import subject_dssm_item_faiss_model_path
from config import subject_dssm_item_embedding_index_dict
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE(review): ``pickle.load`` can execute arbitrary code; these paths come
    from ``config`` and are presumably trusted, locally produced artifacts —
    confirm they are never user-supplied.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# Load DSSM recall data. dssm_recall() below uses dssm_user_emb_dict as
# user_id -> embedding and dssm_item_embedding_index_dict as faiss row -> id;
# the remaining objects are not used in this file (presumably kept for other callers).
logger.info('加载DSSM召回数据')
dssm_user_emb_dict = _load_pickle(subject_dssm_user_emb_dict)
dssm_user_embedding_data = _load_pickle(subject_dssm_user_embedding_data)
dssm_user_embedding_index_dict = _load_pickle(subject_dssm_user_embedding_index_dict)
dssm_item_emb_dict = _load_pickle(subject_dssm_item_emb_dict)
dssm_item_embedding_data = _load_pickle(subject_dssm_item_embedding_data)
dssm_item_embedding_index_dict = _load_pickle(subject_dssm_item_embedding_index_dict)
# Load the prebuilt faiss indexes; dssm_recall() queries the item index.
logger.info('加载dssm召回模型')
dssm_user_faiss_model = faiss.read_index(subject_dssm_user_faiss_model_path)
dssm_item_faiss_model = faiss.read_index(subject_dssm_item_faiss_model_path)
def dssm_recall(user_id, topk=20):
    """Recommend study partners for ``user_id`` by DSSM embedding similarity.

    The user's DSSM embedding is used as a query against the item faiss index.
    Each hit is mapped back to an id through ``dssm_item_embedding_index_dict``,
    and the query user itself is filtered out.

    Args:
        user_id: Key into ``dssm_user_emb_dict``.
        topk: Maximum number of recommendations to return.

    Returns:
        A list of ``(id, score)`` tuples sorted by descending score, at most
        ``topk`` long. An empty list is returned when ``user_id`` has no
        embedding.
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行dssm召回的用户ID: {user_id}")
    if user_id not in dssm_user_emb_dict:
        # No embedding for this user: nothing to recall. Return an empty list so
        # the return type matches the normal path.
        return []

    # faiss expects a 2-D, C-contiguous float32 query matrix of shape (n, embedding_dim).
    query = np.ascontiguousarray(
        dssm_user_emb_dict[user_id].reshape(-1, embedding_dim), dtype=np.float32
    )
    # Request one extra neighbour because the query user may appear among the
    # hits and is filtered out below.
    sims, indices = dssm_item_faiss_model.search(query, topk + 1)

    top_k_item = {}
    for index, sim in zip(indices.ravel(), sims.ravel()):
        # faiss pads the label array with -1 when fewer results exist than requested.
        if index < 0:
            continue
        cur_user_id = dssm_item_embedding_index_dict[index]
        # Never recommend the user to themself.
        if cur_user_id != user_id:
            top_k_item[cur_user_id] = sim

    # NOTE(review): descending order assumes a larger score means "more similar"
    # (inner-product index); for an L2 index this would be inverted — confirm metric.
    ranked = sorted(top_k_item.items(), key=lambda kv: kv[1], reverse=True)[:topk]

    # Elapsed wall time in milliseconds. Both timestamps come from the same clock,
    # and total_seconds() covers the whole duration, not just the sub-second part.
    cost_time_millisecond = round((datetime.now() - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return ranked
if __name__ == '__main__':
    # Manual smoke run: recall 20 partners for the configured test user and print them.
    partner_recs = dssm_recall(user_id=test_user_id, topk=20)
    print(partner_recs)