import os import sys sys.path.append(os.getcwd()) import pandas as pd import numpy as np from tqdm import tqdm import faiss import warnings import pickle import collections from config import logger from config import shixun_wordemb_i2i_sim_data from config import shixuns_embed_path from utils import get_file_size #只使用词向量进行相似度计算 tqdm.pandas() warnings.filterwarnings('ignore') def embedding_i2i_sim(item_emb_df, topk): """ 基于物品embedding计算相似性矩阵,返回topk个与其最相似的物品 """ # 加载之前保存的 if os.path.exists(shixun_wordemb_i2i_sim_data) and (get_file_size(shixun_wordemb_i2i_sim_data) > 1): item_sim_dict = pickle.load(open(shixun_wordemb_i2i_sim_data, 'rb')) return item_sim_dict # 物品索引与物品id的字典映射 item_idx_2_rawid_dict = dict(zip(item_emb_df.index, item_emb_df['shixun_id'])) # item_emb_cols = [x for x in item_emb_df.columns if 'emb' in x] item_emb_cols = [x for x in item_emb_df.columns if 'shixun_id' not in x] item_emb_np = np.ascontiguousarray(item_emb_df[item_emb_cols].values, dtype=np.float32) # 向量进行单位化 item_emb_np = item_emb_np / np.linalg.norm(item_emb_np, axis=1, keepdims=True) # 建立faiss索引 item_index = faiss.IndexFlatIP(item_emb_np.shape[1]) item_index.add(item_emb_np) # 相似度查询,给每个索引位置上的向量返回topk个item以及相似度 sim, idx = item_index.search(item_emb_np, topk) # 返回的是列表 # 将向量检索的结果保存成原始id的对应关系 item_sim_dict = collections.defaultdict(dict) for target_idx, sim_value_list, rele_idx_list in tqdm(zip(range(len(item_emb_np)), sim, idx)): target_raw_id = item_idx_2_rawid_dict[target_idx] # 从1开始是为了去掉物品本身, 所以最终获得的相似物品只有topk-1 for rele_idx, sim_value in zip(rele_idx_list[1:], sim_value_list[1:]): rele_raw_id = item_idx_2_rawid_dict[rele_idx] item_sim_dict[target_raw_id][rele_raw_id] = item_sim_dict.get(target_raw_id, {}).get(rele_raw_id, 0) + sim_value # 保存embedding i2i相似度矩阵 logger.info("保存word embedding i2i相似度矩阵") pickle.dump(item_sim_dict, open(shixun_wordemb_i2i_sim_data, 'wb')) return item_sim_dict if __name__ == '__main__': recall_item_num = 100 logger.info('生成物品embedding相似度矩阵') item_emb_df = pd.read_csv(shixuns_embed_path, sep='\t', encoding='utf-8') emb_i2i_sim = embedding_i2i_sim(item_emb_df, topk=recall_item_num)