You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
82 lines
2.9 KiB
82 lines
2.9 KiB
import os
|
|
import sys
|
|
sys.path.append(os.getcwd())
|
|
import numpy as np
|
|
import faiss
|
|
import pickle
|
|
from datetime import datetime
|
|
from config import logger
|
|
from config import embedding_dim
|
|
from config import test_user_id
|
|
from config import shixun_youtube_user_emb_dict
|
|
from config import shixun_youtube_user_embedding_data
|
|
from config import shixun_youtube_user_faiss_model_path
|
|
from config import shixun_youtube_user_embedding_index_dict
|
|
from config import shixun_youtube_item_emb_dict
|
|
from config import shixun_youtube_item_embedding_data
|
|
from config import shixun_youtube_item_faiss_model_path
|
|
from config import shixun_youtube_item_embedding_index_dict
|
|
|
|
|
|
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at *path*.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes. The previous ``pickle.load(open(...))`` form left
    closing the handle to the garbage collector.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. The paths come from ``config`` and are presumably trusted
    # artifacts -- confirm they can never point at user-supplied data.
    with open(path, 'rb') as f:
        return pickle.load(f)


# Load the pickled YoutubeDNN user-side and item-side artifacts whose paths
# are configured in ``config``. ``youtubednn_recall`` uses
# ``youtube_user_emb_dict`` as a user_id -> vector mapping and
# ``youtube_item_embedding_index_dict`` as a faiss-row -> id mapping; the
# remaining objects are loaded but not used in the visible code (they may be
# imported elsewhere, so they are kept).
logger.info('加载youtubednn召回数据')

youtube_user_emb_dict = _load_pickle(shixun_youtube_user_emb_dict)
youtube_user_embedding_data = _load_pickle(shixun_youtube_user_embedding_data)
youtube_user_embedding_index_dict = _load_pickle(shixun_youtube_user_embedding_index_dict)

youtube_item_emb_dict = _load_pickle(shixun_youtube_item_emb_dict)
youtube_item_embedding_data = _load_pickle(shixun_youtube_item_embedding_data)
youtube_item_embedding_index_dict = _load_pickle(shixun_youtube_item_embedding_index_dict)

# Load the prebuilt faiss indexes for the user and item embeddings.
logger.info('加载youtubednn召回模型')

youtube_user_faiss_model = faiss.read_index(shixun_youtube_user_faiss_model_path)
youtube_item_faiss_model = faiss.read_index(shixun_youtube_item_faiss_model_path)
|
|
|
|
|
|
def youtubednn_recall(user_id, topk=20):
    """Recall the ``topk`` nearest neighbours of a user's YoutubeDNN embedding.

    Original intent (translated): recommend study partners via the similarity
    of YouTube user embeddings.

    Args:
        user_id: key looked up in ``youtube_user_emb_dict``.
        topk: maximum number of results to return.

    Returns:
        A list of ``(id, score)`` tuples sorted by score in descending order,
        at most ``topk`` long, never containing ``user_id`` itself. If
        ``user_id`` has no embedding, an empty dict ``{}`` is returned instead
        (kept as-is for backward compatibility with existing callers).
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行YoutubeDNN召回的用户ID: {user_id}")

    if user_id not in youtube_user_emb_dict:
        # Kept as ``{}`` (not ``[]``) so existing callers see the same value.
        return {}

    # Fetch the user vector and lay it out as the 2-D, float32, C-contiguous
    # (n_queries, embedding_dim) array that faiss ``Index.search`` expects.
    user_embs = np.ascontiguousarray(
        np.asarray(youtube_user_emb_dict[user_id], dtype=np.float32).reshape(-1, embedding_dim)
    )

    # NOTE(review): the docstring describes user-user similarity, but this
    # searches the *item* index and maps rows through the item index dict,
    # while ``youtube_user_faiss_model`` / ``youtube_user_embedding_index_dict``
    # go unused here. Confirm which index is intended.
    # Ask for one extra neighbour because ``user_id`` itself may be a hit.
    scores, indices = youtube_item_faiss_model.search(user_embs, topk + 1)

    hits = {}
    for index, score in zip(indices.ravel(), scores.ravel()):
        # faiss pads the label array with -1 when fewer than k vectors match;
        # such rows have no entry in the index dict.
        if index < 0:
            continue
        cur_id = youtube_item_embedding_index_dict[index]
        # Drop the querying user itself from the results.
        if cur_id != user_id:
            hits[cur_id] = score

    # NOTE(review): sorting descending treats a larger score as more similar,
    # which holds for inner-product indexes but not for L2 ones -- confirm
    # the index metric.
    top_k_item = sorted(hits.items(), key=lambda x: x[1], reverse=True)[:topk]

    # Elapsed time in milliseconds. Both timestamps come from the same clock
    # and the whole timedelta is converted. The old code mixed local time
    # with UTC and read only the ``.microseconds`` component, which dropped
    # whole seconds.
    cost_time_millisecond = round((datetime.now() - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")

    return top_k_item
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke run: recall for the configured test user and print
    # whatever the recall function returns.
    recall_results = youtubednn_recall(topk=20, user_id=test_user_id)
    print(recall_results)