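"""
Generate item-level Word2Vec embeddings for subjects, build user embeddings by
averaging the vectors of the subjects each user selected, and index the user
embeddings with Faiss for the subject recall stage.
"""
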
import numpy as np
import pickle
from tqdm import tqdm
import os
import sys
sys.path.append(os.getcwd())
import faiss
from gensim.models import Word2Vec
import warnings
from config import logger, RANDOM_SEED
from config import embedding_dim, subject_item_w2v_emb_dict
from config import subject_user_w2v_emb_dict
from config import cpu_count, offline_mode
from config import subjects_item_word2vec_model_path
from config import subjects_user_word2vec_faiss_model_path
from config import subject_user_embedding_index_dict
from matching.subject.recall_comm import get_all_select_df
tqdm.pandas()
warnings.filterwarnings('ignore')

def user_embedding(item_id_list, w2v_model):
    '''
    Generate a user vector by averaging the vectors of the selected items.
    item_id_list: list of the item ids the user selected
    w2v_model: trained Word2Vec model
    return: mean of all item vectors, shaped (1, embedding_dim)
    '''
    embedding = []
    for item_id in item_id_list:
        if item_id not in w2v_model.wv.key_to_index:
            # Item is not in the Word2Vec vocabulary, fall back to a random vector
            embedding.append(np.random.randn(embedding_dim))
        else:
            embedding.append(w2v_model.wv.get_vector(item_id))
    # The mean of all item vectors is used as the user ("sentence") vector
    return np.mean(np.array(embedding), axis=0).reshape(1, -1)

def build_user_embedding(select_df, w2v_model):
    """
    Generate user Word2Vec vectors from the item Word2Vec vectors,
    treating every item id a user selected as a word.
    """
    select_df = select_df.sort_values(by=['created_timestamp'], axis=0, ascending=[False])
    # Ids must be converted to strings to match the Word2Vec vocabulary
    select_df['subject_id'] = select_df['subject_id'].astype(str)
    # Collect each user's selected subjects into a list
    df_grouped = select_df.groupby(['user_id'])['subject_id'].progress_apply(
        lambda x: list(x)).reset_index()
    user_word2vec_dict = dict(zip(df_grouped['user_id'], df_grouped['subject_id']))
    embedding_list = []
    index = 0
    user_embedding_index_dict = {}
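    # user_embedding_index_dict maps each Faiss row index back to the
    # (user_id, selected subject ids) pair it was built from, so search
    # results from the index can later be resolved to concrete users.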
    for user_id, subject_id_list in tqdm(user_word2vec_dict.items()):
        embedding = user_embedding(subject_id_list, w2v_model)
        user_word2vec_dict[user_id] = embedding
        user_embedding_index_dict.setdefault(index, [])
        embedding_list.append(embedding)
        user_embedding_index_dict[index].append((user_id, subject_id_list))
        index += 1
    pickle.dump(user_word2vec_dict, open(subject_user_w2v_emb_dict, 'wb'))
    pickle.dump(user_embedding_index_dict, open(subject_user_embedding_index_dict, 'wb'))
    vecs = np.vstack(embedding_list).astype('float32')
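    # Note: IndexFlatIP ranks by raw inner product. If cosine similarity is the
    # intended metric, the vectors would need to be L2-normalized first
    # (e.g. with faiss.normalize_L2(vecs)) before being added to the index.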
    # Build the Faiss index over the user embeddings
    index = faiss.IndexFlatIP(embedding_dim)
    index.add(vecs)
    # Save the user embedding Faiss index
    faiss.write_index(index, subjects_user_word2vec_faiss_model_path)

def train_item_word2vec(select_df, embed_size=embedding_dim):
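    """
    Train item-level Word2Vec vectors, treating each user's chronologically
    ordered list of selected subject ids as one sentence (sg=1 uses skip-gram).
    return: the trained model and a {subject_id: vector} dict
    """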
    select_df = select_df.sort_values(['created_timestamp'])
    # Ids must be converted to strings before they can be used for training
    select_df['subject_id'] = select_df['subject_id'].astype(str)
    # Turn each user's selected practical courses into a list (one sentence per user)
    docs = select_df.groupby(['user_id'])['subject_id'].progress_apply(lambda x: list(x)).reset_index()
    docs = docs['subject_id'].values.tolist()
    # Train Word2Vec vectors on the lists of selected practical courses
    w2v_model = Word2Vec(sentences=docs,
                         vector_size=embed_size,
                         sg=1,
                         window=10,
                         seed=RANDOM_SEED,
                         workers=cpu_count,
                         min_count=1,
                         epochs=20,
                         hs=1)  # hs=1 enables hierarchical softmax
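    # With min_count=1 every subject_id seen in docs is kept in the vocabulary,
    # so the lookup below cannot raise a KeyError.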
    item_w2v_emb_dict = {k: w2v_model.wv[k] for k in tqdm(select_df['subject_id'])}
    w2v_model.save(subjects_item_word2vec_model_path)
    pickle.dump(item_w2v_emb_dict, open(subject_item_w2v_emb_dict, 'wb'))
    return w2v_model, item_w2v_emb_dict

if __name__ == '__main__':
    logger.info("Loading item behaviour data")
    all_select_df = get_all_select_df(offline=offline_mode)
    if not os.path.exists(subjects_item_word2vec_model_path):
        logger.info('Generating item Word2Vec vectors')
        w2v_model, item_w2v_emb_dict = train_item_word2vec(all_select_df, embed_size=embedding_dim)
    else:
        logger.info("Loading item Word2Vec vectors")
        # The model was saved with Word2Vec.save, so load the full Word2Vec model
        w2v_model = Word2Vec.load(subjects_item_word2vec_model_path)
        item_w2v_emb_dict = pickle.load(open(subject_item_w2v_emb_dict, 'rb'))
    logger.info("Generating user Word2Vec vectors")
    build_user_embedding(all_select_df, w2v_model)
    print("success!!!")