You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

189 lines
6.5 KiB

5 months ago
import numpy as np
import pandas as pd
import json
import faiss
import pickle
import config
import random
from datetime import datetime
from utils import random_dict_order
from config import logger, embedding_dim
from config import shixun_dssm_item_faiss_model_path
from config import shixun_dssm_item_embedding_index_dict
from config import shixun_dssm_item_emb_dict
from config import shixun_id_to_name_dict_data
from config import test_shixun_id, test_shixun_name
from matching.shixun.hnsw_faiss import HNSW
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at *path*.

    The file is opened inside a ``with`` block so the handle is closed as
    soon as loading finishes instead of lingering until garbage collection.
    """
    # NOTE(review): pickle can execute arbitrary code while loading. These
    # paths come from config and are presumably trusted artifacts - confirm.
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


logger.info('加载相关实训召回HNSW模型')
# Recall model queried by shixun name (see relevant_shixun_recommend_by_faiss).
hnsw = HNSW(config.word2vec_model_path,
            config.shixun_faiss_w2v_path,
            config.ef_construction,
            config.M,
            config.shixuns_fassi_model_path,
            config.shixuns_data_path)
# faiss index searched with DSSM item embeddings.
dssm_item_faiss_model = faiss.read_index(shixun_dssm_item_faiss_model_path)

logger.info('加载相关实训召回字典')
# faiss result index -> shixun id
dssm_item_embedding_index_dict = _load_pickle(shixun_dssm_item_embedding_index_dict)
# shixun id -> DSSM item embedding
dssm_item_emb_dict = _load_pickle(shixun_dssm_item_emb_dict)
# shixun id -> shixun name
shixun_id_to_name_dict = _load_pickle(shixun_id_to_name_dict_data)
def relevant_shixun_recommend(shixun_id, shixun_name, topk=10):
    """
    Recommend shixuns related to the given one by blending two recall sources.

    The DSSM item-embedding recall and the shixun-name embedding recall are
    both queried. If only one of them returns anything, its results are
    returned as they are. If both return results, roughly four fifths of
    ``topk`` come from the DSSM recall and the remainder from the name recall.

    Args:
        shixun_id: id of the shixun to find related items for.
        shixun_name: name of that shixun (used by the name recall and logs).
        topk: requested number of recommendations.

    Returns:
        dict mapping recommended shixun id -> shixun name; empty when neither
        recall source produced anything.
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行推荐的实训: {shixun_name}")

    # Recall by DSSM item embedding first, then by name embedding.
    recommend_results_dssm = relevant_shixun_recommend_by_dssm_item_embedding(
        shixun_id, shixun_name, topk, False)
    recommend_results_faiss = relevant_shixun_recommend_by_faiss(
        shixun_id, shixun_name, topk, False)

    recommend_results = {}
    if not recommend_results_dssm:
        # Covers "only name recall" as well as "nothing at all".
        recommend_results = dict(recommend_results_faiss)
    elif not recommend_results_faiss:
        recommend_results = dict(recommend_results_dssm)
    else:
        # DSSM item embedding contributes four fifths ...
        first_pick = (topk // 5) * 4
        # ... and the name embedding fills the rest.
        second_pick = topk - first_pick

        # Shuffle the name-based candidates before sampling from them.
        recommend_results_faiss = random_dict_order(recommend_results_faiss)
        seen_names = set()
        for key, value in recommend_results_faiss.items():
            # Test the quota before inserting so a quota of 0 adds nothing.
            if len(seen_names) >= second_pick:
                break
            # Skip candidates whose name was already picked.
            if value in seen_names:
                continue
            seen_names.add(value)
            recommend_results[key] = value

        picked = 0
        for key, value in recommend_results_dssm.items():
            # Same pre-check: with topk < 5 the DSSM quota is 0 and the
            # loop must not run away past topk.
            if picked >= first_pick:
                break
            # An id already taken from the name recall does not use up a
            # DSSM slot; its name is still refreshed from the DSSM result.
            if key not in recommend_results:
                picked += 1
            recommend_results[key] = value

    # Elapsed milliseconds: both timestamps come from the same clock and the
    # whole duration is used, not just the sub-second microsecond field.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recommend_results
def relevant_shixun_recommend_by_faiss(shixun_id, shixun_name, topk=10, verbose=True):
    """
    Recommend related shixuns by searching the HNSW model with the name.

    Args:
        shixun_id: accepted for signature parity with the other recall
            functions; not used by this lookup.
        shixun_name: query text passed to ``hnsw.search``.
        topk: number of neighbours requested from the search.
        verbose: when true, log the query and the elapsed time.

    Returns:
        dict mapping shixun id -> shixun name, built from the (id, name)
        pairs in the second element of the ``hnsw.search`` result.
    """
    start_time = datetime.now()
    if verbose:
        logger.info(f"本次需要进行推荐的实训: {shixun_name}")

    _, top_k_Item = hnsw.search(shixun_name, k=topk)
    recommend_results = {
        cur_shixun_id: cur_shixun_name
        for cur_shixun_id, cur_shixun_name in top_k_Item
    }

    # Elapsed milliseconds: both timestamps come from the same clock and the
    # whole duration is used, not just the sub-second microsecond field.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)
    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recommend_results
def relevant_shixun_recommend_by_dssm_item_embedding(shixun_id, shixun_name, topk=10, verbose=True):
    """
    Recommend related shixuns by DSSM item-embedding similarity.

    Args:
        shixun_id: id of the shixun whose embedding is the query vector.
        shixun_name: name of that shixun; only used for logging.
        topk: maximum number of recommendations to return.
        verbose: when true, log the query and the elapsed time.

    Returns:
        dict mapping shixun id -> shixun name with at most ``topk`` entries,
        excluding ``shixun_id`` itself. Empty when the id has no embedding
        or no known name.
    """
    start_time = datetime.now()
    if verbose:
        logger.info(f"本次需要进行推荐的实训: {shixun_name}")

    if (shixun_id not in dssm_item_emb_dict) or (shixun_id not in shixun_id_to_name_dict):
        return {}

    # Fetch the item vector and reshape it to 2-D for the faiss search.
    item_embs = dssm_item_emb_dict[shixun_id].reshape(-1, embedding_dim)
    # Ask for one extra neighbour because the query item itself is expected
    # to be among the hits and is filtered out below.
    D, I = dssm_item_faiss_model.search(np.ascontiguousarray(item_embs), topk + 1)

    top_k_item = {}
    # Map each returned embedding index back to its item.
    for index in I.ravel():
        # The extra neighbour is only a spare for the self-match; stop at
        # topk so callers never receive more than they asked for.
        if len(top_k_item) >= topk:
            break
        # faiss pads the result with -1 when it finds fewer neighbours than
        # requested; such slots have no entry in the index dictionary.
        if index < 0:
            continue
        cur_item_id = dssm_item_embedding_index_dict[index]
        # Drop the query item itself.
        if cur_item_id != shixun_id:
            top_k_item[cur_item_id] = shixun_id_to_name_dict[cur_item_id]

    # Elapsed milliseconds: both timestamps come from the same clock and the
    # whole duration is used, not just the sub-second microsecond field.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)
    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_item
if __name__ == '__main__':
    # Manual smoke run: print the output of each recall path for the
    # configured test shixun.
    stars = '*' * 50

    # 1) Name-embedding recall; the id argument is a placeholder here.
    print(stars + "item embedding" + stars)
    by_name = relevant_shixun_recommend_by_faiss(0, test_shixun_name, topk=20)
    print(json.dumps(by_name, ensure_ascii=False, indent=4))

    # 2) DSSM item-embedding recall for the configured test id.
    demo_id = test_shixun_id
    demo_name = shixun_id_to_name_dict[demo_id]
    print(stars + "DSSM" + stars)
    by_dssm = relevant_shixun_recommend_by_dssm_item_embedding(demo_id, demo_name, topk=20)
    print(json.dumps(by_dssm, ensure_ascii=False, indent=4))

    # 3) Blended recommendation.
    print('*' * 100)
    blended = relevant_shixun_recommend(demo_id, demo_name, topk=20)
    print(json.dumps(blended, ensure_ascii=False, indent=4))