You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

81 lines
2.8 KiB

import os
import sys
sys.path.append(os.getcwd())
import numpy as np
import faiss
import pickle
from datetime import datetime
from config import logger
from config import embedding_dim
from config import test_user_id
from config import shixun_dssm_user_emb_dict
from config import shixun_dssm_user_embedding_data
from config import shixun_dssm_user_faiss_model_path
from config import shixun_dssm_user_embedding_index_dict
from config import shixun_dssm_item_emb_dict
from config import shixun_dssm_item_embedding_data
from config import shixun_dssm_item_faiss_model_path
from config import shixun_dssm_item_embedding_index_dict
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; these paths come from
    # config and are assumed to be trusted, locally produced artifacts.
    with open(path, 'rb') as f:
        return pickle.load(f)


# Load the DSSM user/item embedding dicts, raw embedding data and the
# Faiss-row-position -> id mapping dicts once, at import time.
logger.info('加载DSSM召回数据')
dssm_user_emb_dict = _load_pickle(shixun_dssm_user_emb_dict)
dssm_user_embedding_data = _load_pickle(shixun_dssm_user_embedding_data)
dssm_user_embedding_index_dict = _load_pickle(shixun_dssm_user_embedding_index_dict)
dssm_item_emb_dict = _load_pickle(shixun_dssm_item_emb_dict)
dssm_item_embedding_data = _load_pickle(shixun_dssm_item_embedding_data)
dssm_item_embedding_index_dict = _load_pickle(shixun_dssm_item_embedding_index_dict)

# Load the prebuilt Faiss indexes for user and item embeddings.
logger.info('加载dssm召回模型')
dssm_user_faiss_model = faiss.read_index(shixun_dssm_user_faiss_model_path)
dssm_item_faiss_model = faiss.read_index(shixun_dssm_item_faiss_model_path)
def dssm_recall(user_id, topk=20):
    """
    Recall up to ``topk`` items for a user by DSSM embedding similarity.

    The user's DSSM embedding is taken from ``dssm_user_emb_dict`` and searched
    against the item Faiss index ``dssm_item_faiss_model``. The returned Faiss
    row positions are mapped back to ids through ``dssm_item_embedding_index_dict``.

    Args:
        user_id: key into ``dssm_user_emb_dict``.
        topk: maximum number of results to return.

    Returns:
        A list of ``(item_id, score)`` tuples sorted by score in descending
        order, at most ``topk`` long. An empty list is returned when the user
        has no DSSM embedding.
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行dssm召回的用户ID: {user_id}")

    if user_id not in dssm_user_emb_dict:
        # Same (list) type as the normal path, so callers can treat results uniformly.
        return []

    # Faiss search expects a C-contiguous float32 matrix of shape (n_queries, embedding_dim).
    user_embs = np.ascontiguousarray(
        dssm_user_emb_dict[user_id].reshape(-1, embedding_dim), dtype=np.float32
    )

    # Request one extra neighbour because an id equal to user_id is dropped below.
    D, I = dssm_item_faiss_model.search(user_embs, topk + 1)

    top_k_item = {}
    for index, sim in zip(I.ravel(), D.ravel()):
        # Faiss pads the result with label -1 when fewer than k neighbours exist.
        if index < 0:
            continue
        cur_item_id = dssm_item_embedding_index_dict[index]
        # NOTE(review): this compares an item id with a user id (apparently carried
        # over from a user-to-user recall); it only matters if the id spaces overlap.
        if cur_item_id != user_id:
            top_k_item[cur_item_id] = sim

    # Higher score ranks first; presumably the index uses inner-product similarity
    # (for an L2 index, lower distances would be better) -- TODO confirm.
    top_k_item = sorted(top_k_item.items(), key=lambda x: x[1], reverse=True)[:topk]

    # Elapsed wall time in milliseconds; both timestamps are taken from the same
    # clock, and total_seconds() includes whole seconds, not just the sub-second part.
    end_time = datetime.now()
    cost_time_millisecond = round((end_time - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_item
if __name__ == '__main__':
    # Manual smoke run: recall 20 results for the configured test user and print them.
    print(dssm_recall(user_id=test_user_id, topk=20))