You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

160 lines
5.5 KiB

5 months ago
import numpy as np
import faiss
import pickle
from datetime import datetime
from utils import merge_dict
from config import logger
from config import embedding_dim
from config import shixun_user_w2v_emb_dict
from config import shixuns_user_word2vec_faiss_model_path
from config import shixun_dssm_user_emb_dict
from config import shixun_dssm_user_faiss_model_path
from config import shixun_dssm_user_embedding_index_dict
from config import shixun_user_embedding_index_dict
from config import shixun_user_item_time_dict_data
from config import test_user_id
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE(review): pickle can execute arbitrary code while loading; these paths
    must only ever point to trusted, locally produced artifacts.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# Load the two faiss indexes used for partner recall (word2vec and DSSM user embeddings).
logger.info('加载学习伙伴召回模型')
user_word2vec_faiss_model = faiss.read_index(shixuns_user_word2vec_faiss_model_path)
dssm_user_faiss_model = faiss.read_index(shixun_dssm_user_faiss_model_path)
# Load the lookup dictionaries: embeddings per user, faiss-row -> user mappings,
# and each user's interaction history.
logger.info('加载学习伙伴召回字典')
user_w2v_emb_dict = _load_pickle(shixun_user_w2v_emb_dict)
user_embedding_index_dict = _load_pickle(shixun_user_embedding_index_dict)
dssm_user_emb_dict = _load_pickle(shixun_dssm_user_emb_dict)
dssm_user_embedding_index_dict = _load_pickle(shixun_dssm_user_embedding_index_dict)
user_item_time_dict = _load_pickle(shixun_user_item_time_dict_data)
def shixun_partner_recommend(user_id, topk=20, method='1'):
    """
    Recommend learning partners for ``user_id``.

    Candidates come from the DSSM user-embedding index first. If that yields
    fewer than ``topk`` users, the shortfall is requested from the word2vec
    user-embedding index and the two result dicts are combined via
    ``merge_dict``.

    :param user_id: ID of the user to recommend partners for.
    :param topk: maximum number of partners to return.
    :param method: '1' returns only a list of partner user IDs; any other value
        returns the dict {partner_user_id: [item_id, ...]}.
    :return: list of user IDs or dict, depending on ``method``; at most ``topk``
        entries.
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行学习伙伴推荐的用户ID: {user_id}")
    recommend_results = partner_recommend_by_dssm_user_embedding(user_id, topk, False)
    shortfall = topk - len(recommend_results)
    if shortfall > 0:
        recommend_results_w2v_embedding = partner_recommend_by_user_w2v_embedding(
            user_id, shortfall, False)
        recommend_results = merge_dict(recommend_results, recommend_results_w2v_embedding)
    # The DSSM lookup searches topk + 1 neighbours and can return all of them when
    # the query user is not among its own matches, so enforce the topk cap here.
    recommend_results = dict(list(recommend_results.items())[:topk])
    if method == '1':
        recommend_results = list(recommend_results)
    # Measure with one clock (start was datetime.now()) and use total_seconds() so
    # whole seconds are counted, not only the sub-second microsecond component.
    cost_time_millisecond = round((datetime.now() - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recommend_results
def partner_recommend_by_user_w2v_embedding(user_id, topk=20, verbose=True):
    """
    Recommend learning partners by word2vec user-embedding similarity.

    :param user_id: ID of the query user.
    :param topk: maximum number of partners to return.
    :param verbose: when True, log the user ID and the elapsed time.
    :return: dict {partner_user_id: [int item_id, ...]} with at most ``topk``
        entries, never containing ``user_id`` itself. Empty when the user has no
        word2vec embedding or no entry in ``user_item_time_dict``.
    """
    start_time = datetime.now()
    if verbose:
        logger.info(f"本次需要进行学习伙伴推荐的用户ID: {user_id}")
    if (user_id not in user_w2v_emb_dict) or (user_id not in user_item_time_dict):
        return {}
    # faiss search takes a 2-D (n_queries, dim) contiguous array.
    user_embs = user_w2v_emb_dict[user_id].reshape(-1, embedding_dim)
    # Request one extra neighbour: the query user usually matches itself first.
    _, indices = user_word2vec_faiss_model.search(np.ascontiguousarray(user_embs), topk + 1)
    top_k_user = {}
    for index in indices.ravel():
        if len(top_k_user) >= topk:
            break
        # faiss pads the result with -1 when the index holds fewer than k vectors.
        if index < 0:
            continue
        # Each index entry's first element is a (user_id, item_id_list) pair.
        cand_user_id, item_id_list = user_embedding_index_dict[index][0]
        # A user is never their own learning partner.
        if cand_user_id == user_id:
            continue
        top_k_user[cand_user_id] = [int(item_id) for item_id in item_id_list]
    # Elapsed time from a single clock, including whole seconds.
    cost_time_millisecond = round((datetime.now() - start_time).total_seconds() * 1000.0, 3)
    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_user
def partner_recommend_by_dssm_user_embedding(user_id, topk=20, verbose=True):
    """
    Recommend learning partners by DSSM user-embedding similarity.

    :param user_id: ID of the query user.
    :param topk: maximum number of partners to return.
    :param verbose: when True, log the user ID and the elapsed time.
    :return: dict {partner_user_id: [int item_id, ...]} with at most ``topk``
        entries, never containing ``user_id`` itself. Item IDs are taken from the
        partner's ``user_item_time_dict`` history, whose entries are unpacked as
        2-tuples (presumably (item_id, time)). Empty when the user has no DSSM
        embedding or no history.
    """
    start_time = datetime.now()
    if verbose:
        logger.info(f"本次需要进行学习伙伴推荐的用户ID: {user_id}")
    if (user_id not in dssm_user_emb_dict) or (user_id not in user_item_time_dict):
        return {}
    # faiss search takes a 2-D (n_queries, dim) contiguous array.
    user_embs = dssm_user_emb_dict[user_id].reshape(-1, embedding_dim)
    # Request one extra neighbour: the most similar user is normally the query user.
    _, indices = dssm_user_faiss_model.search(np.ascontiguousarray(user_embs), topk + 1)
    top_k_user = {}
    for index in indices.ravel():
        # Cap at topk even when the query user did not appear among its neighbours.
        if len(top_k_user) >= topk:
            break
        # faiss pads the result with -1 when the index holds fewer than k vectors.
        if index < 0:
            continue
        cur_user_id = dssm_user_embedding_index_dict[index]
        # Skip the query user and neighbours without any recorded history
        # (the latter previously raised KeyError).
        if cur_user_id == user_id or cur_user_id not in user_item_time_dict:
            continue
        top_k_user[cur_user_id] = [int(item_id) for item_id, _ in user_item_time_dict[cur_user_id]]
    # Elapsed time from a single clock, including whole seconds.
    cost_time_millisecond = round((datetime.now() - start_time).total_seconds() * 1000.0, 3)
    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_user
if __name__ == '__main__':
    # Manual smoke run: print the configured test user's item history, a separator,
    # then every recommended partner followed by that partner's item list.
    demo_user = test_user_id
    history = [int(item_id) for item_id, _ in user_item_time_dict[demo_user]]
    print(f'{demo_user!s}:')
    print(history)
    print('*' * 100)
    # method='0' asks for the full {partner_id: [item_id, ...]} mapping.
    partners = shixun_partner_recommend(demo_user, topk=20, method='0')
    for partner_id, partner_items in partners.items():
        print(f'{partner_id!s}:')
        print(partner_items)