|
|
import os
|
|
|
import sys
|
|
|
sys.path.append(os.getcwd())
|
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
from tqdm import tqdm
|
|
|
import faiss
|
|
|
import warnings
|
|
|
import pickle
|
|
|
import collections
|
|
|
from config import logger
|
|
|
from config import shixun_wordemb_i2i_sim_data
|
|
|
from config import shixuns_embed_path
|
|
|
from utils import get_file_size
|
|
|
|
|
|
#只使用词向量进行相似度计算
|
|
|
|
|
|
tqdm.pandas()
|
|
|
warnings.filterwarnings('ignore')
|
|
|
|
|
|
def embedding_i2i_sim(item_emb_df, topk):
|
|
|
"""
|
|
|
基于物品embedding计算相似性矩阵,返回topk个与其最相似的物品
|
|
|
"""
|
|
|
# 加载之前保存的
|
|
|
if os.path.exists(shixun_wordemb_i2i_sim_data) and (get_file_size(shixun_wordemb_i2i_sim_data) > 1):
|
|
|
item_sim_dict = pickle.load(open(shixun_wordemb_i2i_sim_data, 'rb'))
|
|
|
return item_sim_dict
|
|
|
|
|
|
# 物品索引与物品id的字典映射
|
|
|
item_idx_2_rawid_dict = dict(zip(item_emb_df.index, item_emb_df['shixun_id']))
|
|
|
|
|
|
# item_emb_cols = [x for x in item_emb_df.columns if 'emb' in x]
|
|
|
item_emb_cols = [x for x in item_emb_df.columns if 'shixun_id' not in x]
|
|
|
item_emb_np = np.ascontiguousarray(item_emb_df[item_emb_cols].values, dtype=np.float32)
|
|
|
|
|
|
# 向量进行单位化
|
|
|
item_emb_np = item_emb_np / np.linalg.norm(item_emb_np, axis=1, keepdims=True)
|
|
|
|
|
|
# 建立faiss索引
|
|
|
item_index = faiss.IndexFlatIP(item_emb_np.shape[1])
|
|
|
item_index.add(item_emb_np)
|
|
|
|
|
|
# 相似度查询,给每个索引位置上的向量返回topk个item以及相似度
|
|
|
sim, idx = item_index.search(item_emb_np, topk) # 返回的是列表
|
|
|
|
|
|
# 将向量检索的结果保存成原始id的对应关系
|
|
|
item_sim_dict = collections.defaultdict(dict)
|
|
|
for target_idx, sim_value_list, rele_idx_list in tqdm(zip(range(len(item_emb_np)), sim, idx)):
|
|
|
target_raw_id = item_idx_2_rawid_dict[target_idx]
|
|
|
|
|
|
# 从1开始是为了去掉物品本身, 所以最终获得的相似物品只有topk-1
|
|
|
for rele_idx, sim_value in zip(rele_idx_list[1:], sim_value_list[1:]):
|
|
|
rele_raw_id = item_idx_2_rawid_dict[rele_idx]
|
|
|
item_sim_dict[target_raw_id][rele_raw_id] = item_sim_dict.get(target_raw_id, {}).get(rele_raw_id, 0) + sim_value
|
|
|
|
|
|
# 保存embedding i2i相似度矩阵
|
|
|
logger.info("保存word embedding i2i相似度矩阵")
|
|
|
pickle.dump(item_sim_dict, open(shixun_wordemb_i2i_sim_data, 'wb'))
|
|
|
return item_sim_dict
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
recall_item_num = 100
|
|
|
logger.info('生成物品embedding相似度矩阵')
|
|
|
item_emb_df = pd.read_csv(shixuns_embed_path, sep='\t', encoding='utf-8')
|
|
|
emb_i2i_sim = embedding_i2i_sim(item_emb_df, topk=recall_item_num) |