import numpy as np
import pickle
from tqdm import tqdm
import faiss
from gensim.models import Word2Vec
from gensim.models import KeyedVectors
import warnings
import os
import sys

sys.path.append(os.getcwd())

from config import logger, RANDOM_SEED
from config import embedding_dim, shixun_item_w2v_emb_dict
from config import shixun_user_w2v_emb_dict
from config import cpu_count, offline_mode
from config import shixuns_item_word2vec_model_path
from config import shixuns_user_word2vec_faiss_model_path
from config import shixun_user_embedding_index_dict
from matching.shixun.recall_comm import get_all_select_df


tqdm.pandas()
warnings.filterwarnings('ignore')

def user_embedding(item_id_list, w2v_model):
    '''
    Generate a user vector by averaging the vectors of the items the user selected

    item_id_list: list of selected item ids
    w2v_model: trained word2vec model
    return: mean of all item vectors
    '''
    embedding = []
    for item_id in item_id_list:
        if item_id not in w2v_model.wv.key_to_index:
            # Unseen items fall back to a random vector with the same shape
            # as the trained vectors, so the mean below stays well defined
            embedding.append(np.random.randn(embedding_dim))
        else:
            embedding.append(w2v_model.wv.get_vector(item_id))

    # The mean of all item vectors is used as the user (sentence) vector
    return np.mean(np.array(embedding), axis=0).reshape(1, -1)

def build_user_embedding(select_df, w2v_model):
    """
    Build user Word2Vec vectors from the item Word2Vec vectors

    Each item id a user selected is treated as a word, and the user vector
    is the mean of the vectors of the selected items
    """
    select_df = select_df.sort_values(['created_timestamp'])

    # Ids must be strings before they can be used for training / lookup
    select_df['shixun_id'] = select_df['shixun_id'].astype(str)

    # Collect each user's selected shixuns into a list
    df_grouped = select_df.groupby(['user_id'])['shixun_id'].progress_apply(
        lambda x: list(x)).reset_index()

    user_word2vec_dict = dict(zip(df_grouped['user_id'], df_grouped['shixun_id']))
    embedding_list = []

    index = 0
    user_embedding_index_dict = {}

    for user_id, shixun_id_list in tqdm(user_word2vec_dict.items()):
        embedding = user_embedding(shixun_id_list, w2v_model)
        user_word2vec_dict[user_id] = embedding
        user_embedding_index_dict.setdefault(index, [])
        embedding_list.append(embedding)
        user_embedding_index_dict[index].append((user_id, shixun_id_list))
        index += 1

    pickle.dump(user_word2vec_dict, open(shixun_user_w2v_emb_dict, 'wb'))
    pickle.dump(user_embedding_index_dict, open(shixun_user_embedding_index_dict, 'wb'))

    vecs = np.stack(np.array(embedding_list)).reshape(-1, embedding_dim)
    vecs = vecs.astype('float32')

    # Build the faiss index over the user vectors
    index = faiss.IndexFlatIP(embedding_dim)
    index.add(vecs)

    # Save the user embedding faiss index
    faiss.write_index(index, shixuns_user_word2vec_faiss_model_path)

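# --- Illustrative sketch (not part of the original pipeline) ----------------
# The flat inner-product index written above can later be loaded and queried
# to recall the most similar users for a given user vector. This helper is a
# usage sketch only: its name `search_similar_users` and the `topk` parameter
# are assumptions, while faiss.read_index / index.search and the pickled
# index dict come from the code above. Note that IndexFlatIP scores by raw
# inner product; vectors would need L2 normalisation for cosine similarity.
def search_similar_users(query_embedding, topk=10):
    index = faiss.read_index(shixuns_user_word2vec_faiss_model_path)
    embedding_index_dict = pickle.load(open(shixun_user_embedding_index_dict, 'rb'))

    # faiss expects float32 row vectors of shape (n_queries, embedding_dim)
    query = np.asarray(query_embedding, dtype='float32').reshape(1, -1)
    sims, positions = index.search(query, topk)

    # Map faiss row positions back to the (user_id, shixun_id_list) entries
    # recorded while the index was being built
    return [(embedding_index_dict[int(pos)], float(sim))
            for pos, sim in zip(positions[0], sims[0]) if pos != -1]
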
def train_item_word2vec(select_df, embed_size=embedding_dim):
    """
    Train item Word2Vec vectors, treating each selected shixun id as a word
    and each user's selection sequence as a sentence
    """
    select_df = select_df.sort_values(['created_timestamp'])

    # Ids must be strings before they can be used for training
    select_df['shixun_id'] = select_df['shixun_id'].astype(str)

    # Collect each user's selected shixuns into a list
    docs = select_df.groupby(['user_id'])['shixun_id'].progress_apply(lambda x: list(x)).reset_index()
    docs = docs['shixun_id'].values.tolist()

    # Train Word2Vec on the lists of selected shixuns
    w2v_model = Word2Vec(sentences=docs,
                         vector_size=embed_size,
                         sg=1,
                         window=5,
                         seed=RANDOM_SEED,
                         workers=cpu_count,
                         min_count=1,
                         epochs=20,
                         hs=1)  # hs=1 enables hierarchical softmax

    item_w2v_emb_dict = {k: w2v_model.wv[k] for k in tqdm(select_df['shixun_id'])}

    w2v_model.save(shixuns_item_word2vec_model_path)
    pickle.dump(item_w2v_emb_dict, open(shixun_item_w2v_emb_dict, 'wb'))
    return w2v_model, item_w2v_emb_dict

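# --- Illustrative sketch (not part of the original pipeline) ----------------
# With the trained item Word2Vec model, item-to-item similarity can be read
# directly from gensim, e.g. to recall shixuns similar to one a user just
# selected. The helper name `get_similar_shixuns` is an assumption for
# demonstration only; wv.key_to_index and wv.most_similar are standard
# gensim KeyedVectors APIs.
def get_similar_shixuns(w2v_model, shixun_id, topk=10):
    # Shixun ids were trained as strings, so look up the string form
    shixun_id = str(shixun_id)
    if shixun_id not in w2v_model.wv.key_to_index:
        return []
    return w2v_model.wv.most_similar(shixun_id, topn=topk)
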
if __name__ == '__main__':
    logger.info('Loading item behavior data')
    all_select_df = get_all_select_df(offline=offline_mode)

    if not os.path.exists(shixuns_item_word2vec_model_path):
        logger.info('Generating item Word2Vec vectors')
        w2v_model, item_w2v_emb_dict = train_item_word2vec(all_select_df, embed_size=embedding_dim)
    else:
        logger.info('Loading item Word2Vec vectors')
        # The model is saved with Word2Vec.save above, so load the full model
        # (not just the KeyedVectors) to keep the .wv accesses working
        w2v_model = Word2Vec.load(shixuns_item_word2vec_model_path)
        item_w2v_emb_dict = pickle.load(open(shixun_item_w2v_emb_dict, 'rb'))

    logger.info('Generating user Word2Vec vectors')
    build_user_embedding(all_select_df, w2v_model)