You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
6.5 KiB
189 lines
6.5 KiB
import numpy as np
|
|
import pandas as pd
|
|
import json
|
|
import faiss
|
|
import pickle
|
|
import config
|
|
import random
|
|
from datetime import datetime
|
|
from utils import random_dict_order
|
|
from config import logger, embedding_dim
|
|
from config import shixun_dssm_item_faiss_model_path
|
|
from config import shixun_dssm_item_embedding_index_dict
|
|
from config import shixun_dssm_item_emb_dict
|
|
from config import shixun_id_to_name_dict_data
|
|
from config import test_shixun_id, test_shixun_name
|
|
from matching.shixun.hnsw_faiss import HNSW
|
|
|
|
|
|
# Module-level resources, loaded once at import time and shared by the
# recommendation functions in this file.

logger.info('加载相关实训召回HNSW模型')
# HNSW searcher queried by name in relevant_shixun_recommend_by_faiss.
hnsw = HNSW(
    config.word2vec_model_path,
    config.shixun_faiss_w2v_path,
    config.ef_construction,
    config.M,
    config.shixuns_fassi_model_path,
    config.shixuns_data_path,
)

# Faiss index searched with DSSM item embeddings.
dssm_item_faiss_model = faiss.read_index(shixun_dssm_item_faiss_model_path)


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon
    as loading finishes instead of being left to the garbage collector.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. The paths come from config and are assumed to be trusted
    # artifacts -- confirm they are never user-supplied.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


logger.info('加载相关实训召回字典')
# Faiss row index -> item id (used to map search hits back to items).
dssm_item_embedding_index_dict = _load_pickle(shixun_dssm_item_embedding_index_dict)
# Item id -> DSSM embedding (reshaped to (-1, embedding_dim) before search).
dssm_item_emb_dict = _load_pickle(shixun_dssm_item_emb_dict)
# Item id -> item name.
shixun_id_to_name_dict = _load_pickle(shixun_id_to_name_dict_data)
|
|
|
|
|
|
def relevant_shixun_recommend(shixun_id, shixun_name, topk=10):
    """
    Recommend related shixuns by blending two recall channels.

    The DSSM item-embedding channel is queried first, then the name-embedding
    (HNSW) channel. If only one channel returns results, those results are
    returned unchanged. If both return results, ``(topk // 5) * 4`` entries
    are taken from the DSSM channel and the remaining ``topk - first_pick``
    entries from the name channel (randomly reordered, names de-duplicated).

    Args:
        shixun_id: id of the shixun to find neighbours for.
        shixun_name: name of that shixun (query for the name channel).
        topk: target number of recommendations.

    Returns:
        dict mapping recommended shixun id -> shixun name. May hold fewer
        than ``topk`` entries (e.g. when both channels return the same id).
    """
    start_time = datetime.now()
    logger.info(f"本次需要进行推荐的实训: {shixun_name}")

    # Channel 1: DSSM item-embedding similarity.
    recommend_results_dssm = relevant_shixun_recommend_by_dssm_item_embedding(
        shixun_id, shixun_name, topk, False)

    # Channel 2: item-name embedding similarity.
    recommend_results_faiss = relevant_shixun_recommend_by_faiss(
        shixun_id, shixun_name, topk, False)

    recommend_results = {}

    if not recommend_results_dssm or not recommend_results_faiss:
        # At most one channel produced anything: return a copy of it
        # (or an empty dict when both are empty).
        recommend_results = dict(recommend_results_dssm or recommend_results_faiss)
    else:
        # Four fifths of the quota come from the DSSM channel ...
        first_pick = (topk // 5) * 4
        # ... and the remainder from the name channel.
        second_pick = topk - first_pick

        # Reorder the name-channel results before picking from them.
        # random_dict_order comes from utils; presumably it returns the same
        # entries in random order -- confirm against its implementation.
        shuffled_faiss = random_dict_order(recommend_results_faiss)

        # The quota is checked *before* adding an entry so that a quota of 0
        # really yields 0 entries. The previous `count == quota` check after
        # the increment never fired for a zero quota (topk < 5) and let the
        # whole channel through.
        seen_names = []
        picked = 0
        for key, value in shuffled_faiss.items():
            if picked >= second_pick:
                break
            # Skip entries whose name was already selected.
            if value in seen_names:
                continue
            seen_names.append(value)
            recommend_results[key] = value
            picked += 1

        picked = 0
        for key, value in recommend_results_dssm.items():
            if picked >= first_pick:
                break
            recommend_results[key] = value
            picked += 1

    # Elapsed wall time in milliseconds. Both timestamps come from the same
    # clock and total_seconds() covers the whole interval; `.microseconds`
    # alone is only the sub-second component of the difference.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)

    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recommend_results
|
|
|
|
|
|
def relevant_shixun_recommend_by_faiss(shixun_id, shixun_name, topk=10, verbose=True):
    """
    Recommend related shixuns through the module-level HNSW searcher.

    Args:
        shixun_id: not used by this function; kept so the signature matches
            the other recall functions in this file.
        shixun_name: name passed to ``hnsw.search`` as the query.
        topk: number of neighbours requested from the search (``k``).
        verbose: when true, log the query name and the elapsed time.

    Returns:
        dict mapping shixun id -> shixun name, built from the (id, name)
        pairs returned by ``hnsw.search``.
    """
    start_time = datetime.now()
    if verbose:
        logger.info(f"本次需要进行推荐的实训: {shixun_name}")

    # The first element of the search result is not needed here.
    _, top_k_Item = hnsw.search(shixun_name, k=topk)
    recommend_results = {
        cur_shixun_id: cur_shixun_name
        for cur_shixun_id, cur_shixun_name in top_k_Item
    }

    # Elapsed wall time in milliseconds. Both timestamps come from the same
    # clock and total_seconds() covers the whole interval; `.microseconds`
    # alone is only the sub-second component of the difference.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)

    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recommend_results
|
|
|
|
|
|
def relevant_shixun_recommend_by_dssm_item_embedding(shixun_id, shixun_name, topk=10, verbose=True):
    """
    Recommend related shixuns by DSSM item-embedding similarity.

    The embedding of ``shixun_id`` is looked up, searched against the
    module-level Faiss index, and the hits are mapped back to item ids and
    names. The query item itself is excluded from the result.

    Args:
        shixun_id: id of the shixun whose embedding is used as the query.
        shixun_name: name of that shixun; only used for logging.
        topk: maximum number of recommendations to return.
        verbose: when true, log the query name and the elapsed time.

    Returns:
        dict mapping shixun id -> shixun name with at most ``topk`` entries.
        Empty when ``shixun_id`` has no embedding or no known name.

    Raises:
        KeyError: if a hit's index has no id mapping, or its id has no name.
    """
    start_time = datetime.now()
    if verbose:
        logger.info(f"本次需要进行推荐的实训: {shixun_name}")

    top_k_item = {}

    # Nothing to search with if the item is unknown to either dictionary.
    if (shixun_id not in dssm_item_emb_dict) or (shixun_id not in shixun_id_to_name_dict):
        return top_k_item

    # Fetch the item vector and reshape it to 2-D as the search call expects.
    item_embs = dssm_item_emb_dict[shixun_id].reshape(-1, embedding_dim)

    # Ask for one extra neighbour because the closest hit is normally the
    # query item itself, which is filtered out below.
    _, I = dssm_item_faiss_model.search(np.ascontiguousarray(item_embs), topk + 1)

    # Map the hit indices back to items.
    for index in I.ravel():
        # Faiss pads the label array with -1 when fewer neighbours than
        # requested are found; those are not real rows of the index.
        if index < 0:
            continue

        # Item id that corresponds to this embedding index.
        cur_item_id = dssm_item_embedding_index_dict[index]

        # Drop the query item itself.
        if cur_item_id == shixun_id:
            continue

        top_k_item[cur_item_id] = shixun_id_to_name_dict[cur_item_id]

        # topk + 1 hits were requested; if the query item was not among
        # them, stop here so that no more than topk entries are returned.
        if len(top_k_item) >= topk:
            break

    # Elapsed wall time in milliseconds. Both timestamps come from the same
    # clock and total_seconds() covers the whole interval; `.microseconds`
    # alone is only the sub-second component of the difference.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)

    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")

    return top_k_item
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke run: exercise each recall channel, then the blended
    # recommendation, and pretty-print every result as JSON.
    stars = '*' * 50

    def _show(results):
        """Print a recommendation dict as indented, non-ASCII-escaped JSON."""
        print(json.dumps(results, ensure_ascii=False, indent=4))

    # 1) Name-embedding channel, queried with the configured test name.
    shixun_name = test_shixun_name
    print(f"{stars}item embedding{stars}")
    recommend_results = relevant_shixun_recommend_by_faiss(0, shixun_name, topk=20)
    _show(recommend_results)

    # 2) DSSM channel, queried with the configured test id and its name.
    shixun_id = test_shixun_id
    shixun_name = shixun_id_to_name_dict[shixun_id]
    print(f"{stars}DSSM{stars}")
    recommend_results = relevant_shixun_recommend_by_dssm_item_embedding(
        shixun_id, shixun_name, topk=20)
    _show(recommend_results)

    # 3) Blended recommendation over both channels.
    print(stars + stars)
    recommend_results = relevant_shixun_recommend(shixun_id, shixun_name, topk=20)
    _show(recommend_results)