You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
2.6 KiB

5 months ago
import os
import sys
sys.path.append(os.getcwd())
import pandas as pd
import numpy as np
from tqdm import tqdm
import faiss
import warnings
import pickle
import collections
from config import logger
from config import shixun_wordemb_i2i_sim_data
from config import shixuns_embed_path
from utils import get_file_size
# Similarity here is computed from the word-embedding vectors only.

# Register `progress_apply` & friends on pandas objects.
tqdm.pandas()
# Silence every warning process-wide (equivalent to filterwarnings('ignore')).
warnings.simplefilter('ignore')
def embedding_i2i_sim(item_emb_df, topk):
    """
    Build an item-to-item similarity table from item embeddings with faiss.

    If the pickle at ``shixun_wordemb_i2i_sim_data`` already exists and
    ``get_file_size`` reports more than 1 for it, that cached result is loaded
    and returned unchanged. The cache is reused regardless of ``topk``, so
    delete it after changing ``topk`` or the embeddings.

    Args:
        item_emb_df: DataFrame with a ``shixun_id`` column. Every column whose
            name does not contain ``'shixun_id'`` is used as an embedding
            dimension.
        topk: Number of neighbours requested from the faiss search per item.
            The item itself is excluded, so at most ``topk - 1`` neighbours
            are kept for each item.

    Returns:
        ``collections.defaultdict(dict)`` mapping
        ``shixun_id -> {similar_shixun_id: inner_product_of_unit_vectors}``.
        The freshly computed table is also pickled to
        ``shixun_wordemb_i2i_sim_data``.
    """
    # Reuse a previously saved result. NOTE: pickle.load executes arbitrary
    # code from the file; this is only safe because this script writes it.
    if os.path.exists(shixun_wordemb_i2i_sim_data) and (get_file_size(shixun_wordemb_i2i_sim_data) > 1):
        with open(shixun_wordemb_i2i_sim_data, 'rb') as f:
            return pickle.load(f)

    # faiss reports neighbours by row *position*, so map position -> raw id.
    # Using the DataFrame index labels here would break on any frame whose
    # index is not 0..n-1 (e.g. after filtering).
    raw_ids = item_emb_df['shixun_id'].tolist()

    item_emb_cols = [x for x in item_emb_df.columns if 'shixun_id' not in x]
    item_emb_np = np.ascontiguousarray(item_emb_df[item_emb_cols].values, dtype=np.float32)

    # L2-normalise so that the inner product equals cosine similarity.
    # All-zero rows keep a norm of 1 to avoid division by zero (NaN vectors).
    norms = np.linalg.norm(item_emb_np, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    item_emb_np = item_emb_np / norms

    # Exact inner-product index over all items.
    item_index = faiss.IndexFlatIP(item_emb_np.shape[1])
    item_index.add(item_emb_np)
    # For every row: the topk best scores and the positions of those rows.
    sim, idx = item_index.search(item_emb_np, topk)

    item_sim_dict = collections.defaultdict(dict)
    max_neighbours = topk - 1
    for target_idx, (sim_row, idx_row) in enumerate(tqdm(zip(sim, idx), total=len(raw_ids))):
        target_raw_id = raw_ids[target_idx]
        kept = 0
        for rele_idx, sim_value in zip(idx_row, sim_row):
            if kept >= max_neighbours:
                break
            # faiss pads with -1 when fewer than topk vectors exist. The item
            # itself is excluded by identity rather than by assuming it is
            # ranked first, which fails when duplicate vectors are present.
            if rele_idx < 0 or rele_idx == target_idx:
                continue
            rele_raw_id = raw_ids[rele_idx]
            neighbours = item_sim_dict[target_raw_id]
            # Accumulate so that rows sharing a raw id sum their scores.
            neighbours[rele_raw_id] = neighbours.get(rele_raw_id, 0) + sim_value
            kept += 1

    # Persist the similarity table for later runs.
    logger.info("保存word embedding i2i相似度矩阵")
    with open(shixun_wordemb_i2i_sim_data, 'wb') as f:
        pickle.dump(item_sim_dict, f)
    return item_sim_dict
if __name__ == '__main__':
    # Neighbours requested per item from the faiss search; one of these slots
    # is taken by the item itself and dropped by embedding_i2i_sim.
    recall_item_num = 100

    logger.info('生成物品embedding相似度矩阵')
    # Tab-separated embedding file: a shixun_id column plus embedding columns.
    shixun_emb_df = pd.read_csv(
        shixuns_embed_path,
        sep='\t',
        encoding='utf-8',
    )
    emb_i2i_sim = embedding_i2i_sim(shixun_emb_df, topk=recall_item_num)