You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

189 lines
6.5 KiB

import json
import pickle
import random
import time
from datetime import datetime

import faiss
import numpy as np
import pandas as pd

import config
from config import logger, embedding_dim
from config import shixun_dssm_item_faiss_model_path
from config import shixun_dssm_item_embedding_index_dict
from config import shixun_dssm_item_emb_dict
from config import shixun_id_to_name_dict_data
from config import test_shixun_id, test_shixun_name
from utils import random_dict_order
from matching.shixun.hnsw_faiss import HNSW
# Load the name-embedding HNSW index queried by relevant_shixun_recommend_by_faiss.
# All constructor arguments come from config.
logger.info('加载相关实训召回HNSW模型')
hnsw = HNSW(config.word2vec_model_path,
            config.shixun_faiss_w2v_path,
            config.ef_construction,
            config.M,
            config.shixuns_fassi_model_path,
            config.shixuns_data_path)
# faiss index over DSSM item embeddings, queried by
# relevant_shixun_recommend_by_dssm_item_embedding.
dssm_item_faiss_model = faiss.read_index(shixun_dssm_item_faiss_model_path)
logger.info('加载相关实训召回字典')


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code; these paths must only
    ever point at trusted, locally produced artifacts.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# faiss row index -> item (shixun) id
dssm_item_embedding_index_dict = _load_pickle(shixun_dssm_item_embedding_index_dict)
# item (shixun) id -> DSSM item embedding vector
dssm_item_emb_dict = _load_pickle(shixun_dssm_item_emb_dict)
# shixun id -> shixun name
shixun_id_to_name_dict = _load_pickle(shixun_id_to_name_dict_data)
def relevant_shixun_recommend(shixun_id, shixun_name, topk=10):
    """
    Recommend shixuns related to the given one by blending two recall sources.

    First DSSM item-embedding similarity is used, then shixun-name embedding
    similarity. Both sources are queried with ``topk``. If only one of them
    returns anything, its results are returned as-is. If both return results,
    ``(topk // 5) * 4`` slots go to DSSM and the remaining slots go to the
    name-based results. The name-based results are shuffled first and
    de-duplicated by name. DSSM picks skip ids that were already chosen, so
    an overlap between the sources does not shrink the blend.

    Args:
        shixun_id: id of the query shixun (used for the DSSM lookup).
        shixun_name: name of the query shixun (used for the name lookup and logging).
        topk: number of results requested from each source, and the slot
            budget of the blended result.

    Returns:
        dict mapping recommended shixun id -> shixun name. When both sources
        contribute, name-based picks come first, followed by DSSM picks.
    """
    start_time = time.perf_counter()
    logger.info(f"本次需要进行推荐的实训: {shixun_name}")
    # DSSM item-embedding recall
    recommend_results_dssm = relevant_shixun_recommend_by_dssm_item_embedding(
        shixun_id, shixun_name, topk, False)
    # Shixun-name embedding recall
    recommend_results_faiss = relevant_shixun_recommend_by_faiss(shixun_id, shixun_name, topk, False)

    if recommend_results_dssm and recommend_results_faiss:
        # DSSM gets four fifths of the slots, name similarity the remainder.
        first_pick = (topk // 5) * 4
        second_pick = topk - first_pick
        recommend_results = {}
        seen_names = set()
        # Shuffle the name-based hits so the small share they get varies between calls.
        for key, value in random_dict_order(recommend_results_faiss).items():
            # Check the limit before adding so a limit of 0 really takes nothing.
            if len(recommend_results) >= second_pick:
                break
            # Skip shixuns whose name was already picked.
            if value in seen_names:
                continue
            seen_names.add(value)
            recommend_results[key] = value
        dssm_added = 0
        for key, value in recommend_results_dssm.items():
            if dssm_added >= first_pick:
                break
            # Already chosen by the name path; don't waste a DSSM slot on it.
            if key in recommend_results:
                continue
            recommend_results[key] = value
            dssm_added += 1
    elif recommend_results_dssm:
        recommend_results = recommend_results_dssm.copy()
    else:
        # Either only the name source returned results, or both are empty
        # (the copy of an empty dict is a fresh {}).
        recommend_results = recommend_results_faiss.copy()

    # Elapsed milliseconds, measured with a monotonic clock.
    cost_time_millisecond = round((time.perf_counter() - start_time) * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recommend_results
def relevant_shixun_recommend_by_faiss(shixun_id, shixun_name, topk=10, verbose=True):
    """
    Recommend related shixuns by similarity of their name embeddings.

    Queries the module-level HNSW index with ``shixun_name``.

    Args:
        shixun_id: unused by this function; kept so the signature matches the
            other recall functions.
        shixun_name: shixun name used as the HNSW query.
        topk: number of neighbours requested from the index.
        verbose: when True, log the query name and the elapsed time.

    Returns:
        dict mapping shixun id -> shixun name, built from the (id, name) pairs
        returned by ``hnsw.search``.
        NOTE(review): the query shixun itself is not filtered out here; confirm
        whether the HNSW search can return it.
    """
    start_time = time.perf_counter()
    if verbose:
        logger.info(f"本次需要进行推荐的实训: {shixun_name}")
    _, top_k_items = hnsw.search(shixun_name, k=topk)
    recommend_results = {item_id: item_name for item_id, item_name in top_k_items}
    # Elapsed milliseconds, measured with a monotonic clock.
    cost_time_millisecond = round((time.perf_counter() - start_time) * 1000.0, 3)
    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recommend_results
def relevant_shixun_recommend_by_dssm_item_embedding(shixun_id, shixun_name, topk=10, verbose=True):
    """
    Recommend related shixuns by similarity of their DSSM item embeddings.

    Looks up the query item's embedding, searches the module-level faiss index
    for its nearest neighbours and maps the hit row indices back to shixun ids
    and names.

    Args:
        shixun_id: id of the query shixun. If it has no DSSM embedding or no
            known name, an empty dict is returned.
        shixun_name: name of the query shixun (used only for logging).
        topk: maximum number of recommendations returned.
        verbose: when True, log the query name and the elapsed time.

    Returns:
        dict mapping shixun id -> shixun name, holding at most ``topk`` entries
        and never containing ``shixun_id`` itself.
    """
    start_time = time.perf_counter()
    if verbose:
        logger.info(f"本次需要进行推荐的实训: {shixun_name}")
    top_k_item = {}
    if (shixun_id not in dssm_item_emb_dict) or (shixun_id not in shixun_id_to_name_dict):
        return top_k_item
    # Fetch the item vector and reshape it to 2-D rows of embedding_dim, the
    # layout faiss Index.search takes.
    item_embs = dssm_item_emb_dict[shixun_id].reshape(-1, embedding_dim)
    # Ask for one extra neighbour: the closest hit is normally the item itself.
    _, neighbour_indices = dssm_item_faiss_model.search(np.ascontiguousarray(item_embs), topk + 1)
    # Map neighbour row indices back to items.
    for index in neighbour_indices.ravel():
        # Cap the result size even when the query item is not among the hits.
        if len(top_k_item) >= topk:
            break
        # faiss pads the result with -1 when fewer than k neighbours exist.
        if index < 0:
            continue
        cur_item_id = dssm_item_embedding_index_dict[int(index)]
        # Skip the query item itself.
        if cur_item_id == shixun_id:
            continue
        # Skip neighbours without a known name rather than failing the whole request.
        if cur_item_id not in shixun_id_to_name_dict:
            continue
        top_k_item[cur_item_id] = shixun_id_to_name_dict[cur_item_id]
    # Elapsed milliseconds, measured with a monotonic clock.
    cost_time_millisecond = round((time.perf_counter() - start_time) * 1000.0, 3)
    if verbose:
        logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_item
if __name__ == '__main__':
    def _dump(results):
        """Pretty-print a recommendation dict as indented JSON, keeping non-ASCII names readable."""
        print(json.dumps(results, ensure_ascii=False, indent=4))

    stars = '*' * 50

    # Demo 1: name-embedding (HNSW) recall only; it ignores the id argument.
    shixun_name = test_shixun_name
    print(f"{stars}item embedding{stars}")
    _dump(relevant_shixun_recommend_by_faiss(0, shixun_name, topk=20))

    # Demo 2: DSSM item-embedding recall for the configured test shixun.
    shixun_id = test_shixun_id
    shixun_name = shixun_id_to_name_dict[shixun_id]
    print(f"{stars}DSSM{stars}")
    _dump(relevant_shixun_recommend_by_dssm_item_embedding(shixun_id, shixun_name, topk=20))

    # Demo 3: the blended recommendation for the same shixun.
    print('*' * 100)
    _dump(relevant_shixun_recommend(shixun_id, shixun_name, topk=20))