"""FM recall: score items for a user by querying the item faiss index with
the user's FM embedding."""
import os
import sys

# Make the project-local ``config`` module importable when run from the repo root.
sys.path.append(os.getcwd())

import pickle
import time
from datetime import datetime

import faiss
import numpy as np

from config import logger
from config import embedding_dim
from config import test_user_id
from config import shixun_fm_user_emb_dict
from config import shixun_fm_user_embedding_data
from config import shixun_fm_user_faiss_model_path
from config import shixun_fm_user_embedding_index_dict
from config import shixun_fm_item_emb_dict
from config import shixun_fm_item_embedding_data
from config import shixun_fm_item_faiss_model_path
from config import shixun_fm_item_embedding_index_dict


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE: pickle can execute arbitrary code on load; these paths come from the
    # project config and are assumed to point at trusted, locally produced artifacts.
    with open(path, 'rb') as f:
        return pickle.load(f)


logger.info('加载FM召回数据')
fm_user_emb_dict = _load_pickle(shixun_fm_user_emb_dict)
fm_user_embedding_data = _load_pickle(shixun_fm_user_embedding_data)
fm_user_embedding_index_dict = _load_pickle(shixun_fm_user_embedding_index_dict)
fm_item_emb_dict = _load_pickle(shixun_fm_item_emb_dict)
fm_item_embedding_data = _load_pickle(shixun_fm_item_embedding_data)
fm_item_embedding_index_dict = _load_pickle(shixun_fm_item_embedding_index_dict)

logger.info('加载FM召回模型')
fm_user_faiss_model = faiss.read_index(shixun_fm_user_faiss_model_path)
fm_item_faiss_model = faiss.read_index(shixun_fm_item_faiss_model_path)


def fm_recall(user_id, topk=20):
    """Recall up to ``topk`` items for a user via FM embedding similarity.

    The user's embedding from ``fm_user_emb_dict`` is used as the query against
    ``fm_item_faiss_model``. The row positions faiss returns are mapped back to
    ids through ``fm_item_embedding_index_dict``.

    Args:
        user_id: Key into ``fm_user_emb_dict``.
        topk: Maximum number of results to return.

    Returns:
        A list of ``(item_id, score)`` tuples sorted by score, highest first,
        with at most ``topk`` entries. An empty list is returned when
        ``user_id`` has no FM embedding.
    """
    start_time = time.perf_counter()
    logger.info(f"本次需要进行fm召回的用户ID: {user_id}")

    if user_id not in fm_user_emb_dict:
        return []

    # faiss expects a C-contiguous float32 matrix of shape (n_queries, d).
    # Take d from the index itself rather than hard-coding it.
    user_embs = np.ascontiguousarray(
        np.asarray(fm_user_emb_dict[user_id]).reshape(-1, fm_item_faiss_model.d),
        dtype=np.float32,
    )

    # Request one extra neighbour because one candidate may be filtered out below.
    distances, indices = fm_item_faiss_model.search(user_embs, topk + 1)

    top_k_item = {}
    for index, score in zip(indices.ravel(), distances.ravel()):
        # faiss pads results with label -1 when the index holds fewer than k vectors.
        if index < 0:
            continue
        item_id = fm_item_embedding_index_dict[index]
        # NOTE(review): this self-filter compares an item id with a user id. It looks
        # carried over from a user-to-user recall and only has an effect if the two
        # id spaces overlap. Confirm the intent.
        if item_id != user_id:
            top_k_item[item_id] = score

    # NOTE(review): sorting descending assumes larger score = more similar, i.e. an
    # inner-product index. For an L2 index this would rank the farthest items first.
    top_k_item = sorted(top_k_item.items(), key=lambda x: x[1], reverse=True)[:topk]

    # Elapsed time in milliseconds. perf_counter is monotonic and unaffected by
    # timezone or wall-clock adjustments.
    cost_time_millisecond = round((time.perf_counter() - start_time) * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return top_k_item


if __name__ == '__main__':
    recall_results = fm_recall(user_id=test_user_id, topk=20)
    print(recall_results)