You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

151 lines
4.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import csv
import os
import re
import sys
sys.path.append(os.getcwd())
import pandas as pd
import numpy as np
from tqdm import tqdm
from ltp import LTP
import jieba
import torch
from gensim.models import KeyedVectors
from config import subjects_data_path, logger, user_dict_path
from config import subjects_keywords_path
from config import JIEBA_TOKEN, LTP_TOKEN, word2vec_dim
from config import ltp_model_path
from config import word2vec_model_path, subject_faiss_w2v_path, data_parent_path
# Build sentence vectors for practice courses. Each course contributes the
# fields "subject_name", "sub_discipline_name" and "tag_names".
tqdm.pandas()  # registers Series.progress_apply, used by build_subjects_embedding
ltp = LTP(ltp_model_path)
if torch.cuda.is_available():
    ltp.to("cuda")
logger.info("加载Word2Vec词向量")
w2v_model = KeyedVectors.load(word2vec_model_path)
fassi_w2v_model = KeyedVectors.load(subject_faiss_w2v_path)

# Line pattern of a jieba user-dictionary entry: "word [freq] [tag]".
# Group 1 is the word; the optional frequency and POS tag follow it.
_USERDICT_LINE = re.compile(r'^(.+?)( [0-9]+)?( [a-z]+)?$', re.U)

# Load user-defined dictionaries
if os.path.exists(subjects_keywords_path):
    jieba.load_userdict(subjects_keywords_path)
    # Register the same words with LTP. Splitting the whole file on whitespace
    # would also add the frequency numbers and POS tags as "words".
    # utf-8-sig additionally tolerates a leading BOM.
    with open(subjects_keywords_path, 'r', encoding='utf-8-sig') as f:
        user_dict_words = [
            _USERDICT_LINE.match(line).group(1)
            for line in (raw.strip() for raw in f)
            if line
        ]
    ltp.add_words(user_dict_words)
if os.path.exists(user_dict_path):
    # NOTE(review): this file is not loaded via jieba.load_userdict, so it is
    # treated as a plain whitespace-separated word list - confirm its format.
    with open(user_dict_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)
    for word in user_dict_words:
        jieba.add_word(word)
def tokenizer(sent, token_method=JIEBA_TOKEN, verbose=False):
    """
    Segment a Chinese sentence into space-separated words.

    Two backends are supported, selected by ``token_method``:
    ``JIEBA_TOKEN`` uses ``jieba.cut``; ``LTP_TOKEN`` runs the module-level
    LTP model's ``cws`` task.

    sent: the sentence to segment.
    token_method: JIEBA_TOKEN or LTP_TOKEN.
    verbose: when truthy, log the backend and the segmentation result.
    return: the segmented words joined by single spaces.
    raises ValueError: if token_method is not a supported backend.
    """
    if token_method == JIEBA_TOKEN:
        result = ' '.join(jieba.cut(sent))
    elif token_method == LTP_TOKEN:
        # pipeline() takes a batch of sentences; segment a batch of one and
        # keep the word list of its only element.
        seg = ltp.pipeline([sent], tasks=['cws'])['cws']
        result = ' '.join(seg[0])
    else:
        # Without this branch `result` would be unbound and the return below
        # would fail with an UnboundLocalError.
        raise ValueError(f"unsupported token_method: {token_method!r}")
    if verbose:
        logger.info(f"分词方式:{token_method}, 分词结果:{result}")
    return result
def sentence_embedding(sentence, w2v_model, fassi_w2v_model, verbose=False):
    '''
    Build a sentence vector as the mean of its word vectors.

    The sentence is segmented with jieba. For each word the vector from
    ``fassi_w2v_model`` is used if present, otherwise the one from
    ``w2v_model``; a word found in neither model gets a random
    standard-normal vector.

    sentence: the sentence to embed.
    w2v_model: fallback word2vec model (vectors read via ``.wv``).
    fassi_w2v_model: word2vec model consulted first (vectors read via ``.wv``).
    verbose: forwarded to ``tokenizer`` to log the segmentation.
    return: numpy array of shape (1, word2vec_dim); all zeros when the
        sentence produces no words.
    '''
    words = tokenizer(sentence, JIEBA_TOKEN, verbose).split()
    if not words:
        # np.mean over an empty array yields a NaN scalar, not a vector.
        return np.zeros((1, word2vec_dim))
    # key_to_index is a dict (O(1) membership); index_to_key is a list and
    # would be scanned linearly for every word.
    preferred_vocab = fassi_w2v_model.wv.key_to_index
    fallback_vocab = w2v_model.wv.key_to_index
    embedding = []
    for word in words:
        if word in preferred_vocab:
            embedding.append(fassi_w2v_model.wv.get_vector(word))
        elif word in fallback_vocab:
            embedding.append(w2v_model.wv.get_vector(word))
        else:
            # Out-of-vocabulary word: random vector with the same 1-D shape as
            # the model vectors, so np.array() stacks everything into a 2-D
            # matrix instead of a ragged array.
            # NOTE(review): unseeded, so OOV words embed differently per run.
            embedding.append(np.random.randn(word2vec_dim))
    # The mean of all word vectors is the sentence vector.
    return np.mean(np.array(embedding), axis=0).reshape(1, -1)
def build_subjects_embedding():
    '''
    Compute a sentence vector for every course and save the results.

    Reads the tab-separated course file at ``subjects_data_path``, embeds the
    text formed from "subject_name", "sub_discipline_name" and "tag_names"
    with ``sentence_embedding``, casts "subject_id" to int, and hands the
    frame to ``save_embedding_data``.
    '''
    data = pd.read_csv(subjects_data_path, sep='\t', encoding='utf-8')
    # Prepare the text: missing values become empty strings. fillna() returns
    # new Series instead of mutating extracted columns in place, which is
    # chained assignment and unreliable under pandas copy-on-write.
    subject_name = data["subject_name"].fillna("")
    sub_dis_name = data["sub_discipline_name"].fillna("")
    tags_name = data["tag_names"].fillna("")
    # Separate the fields with a space so the tokenizer cannot fuse the end of
    # one field with the start of the next (e.g. "Python" + "Java" would
    # otherwise become the single token "PythonJava").
    subject_text = subject_name + ' ' + sub_dis_name + ' ' + tags_name
    logger.info('生成所有课程向量')
    data['subject_name_vec'] = subject_text.progress_apply(
        lambda x: sentence_embedding(x, w2v_model, fassi_w2v_model))
    logger.info('检测所有课程向量的维度')
    # Defensive unwrap: a vector whose second dimension is not word2vec_dim is
    # assumed to be a (1, 1) container wrapping the real vector (presumably an
    # object array from averaging word vectors of mismatched shapes), so take
    # its single element.
    data['subject_name_vec'] = data['subject_name_vec'].progress_apply(
        lambda x: x[0][0] if x.shape[1] != word2vec_dim else x)
    data['subject_id'] = data['subject_id'].astype(int)
    save_embedding_data(data)
def save_embedding_data(datas):
    """
    Write course vectors to ``subjects_emb.csv`` inside ``data_parent_path``.

    The output is tab-separated. It starts with a header row
    ("subject_id", "emb_0", ..., "emb_<word2vec_dim - 1>"), followed by one row
    per course: its id, then the components of its vector.

    datas: DataFrame with a "subject_id" column and a "subject_name_vec"
        column; element [0] of each vector is written (vectors are expected
        to have shape (1, dim)).
    """
    logger.info("保存生成的课程向量")
    out_path = os.path.join(data_parent_path, 'subjects_emb.csv')
    # The with-block guarantees the file is flushed and closed.
    with open(out_path, 'w', encoding='utf-8', newline="") as subjects_emb_data:
        csv_out = csv.writer(subjects_emb_data, delimiter='\t')
        # Write the field names first. The column count follows the
        # configured dimension instead of a hard-coded 100.
        headers = ['subject_id'] + [f"emb_{i}" for i in range(word2vec_dim)]
        csv_out.writerow(headers)
        # Then write one row per course.
        for subject_id, subject_name_vec in zip(
                datas['subject_id'], datas['subject_name_vec']):
            csv_out.writerow([subject_id, *subject_name_vec[0]])
if __name__ == "__main__":
    # Script entry point: embed every course and save the vectors.
    build_subjects_embedding()