import os
import sys
sys.path.append(os.getcwd())
import multiprocessing
from time import time
import pandas as pd
from gensim.models import Word2Vec
from config import subjects_keywords_path, ltp_model_path
from tqdm import tqdm
import jieba
from ltp import LTP
import torch
from config import JIEBA_TOKEN, LTP_TOKEN, user_dict_path, logger
from config import subjects_data_path, subject_faiss_w2v_path
from config import word2vec_dim
# Train the word2vec embeddings used for recall. The practice-course data
# contains the fields "subject_name", "sub_discipline_name" and "tag_names".
tqdm.pandas()
ltp = LTP(ltp_model_path)
if torch.cuda.is_available():
    ltp.to("cuda")
# Load the user-defined dictionaries into both jieba and LTP
if os.path.exists(subjects_keywords_path):
    jieba.load_userdict(subjects_keywords_path)
    with open(subjects_keywords_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)
if os.path.exists(user_dict_path):
    with open(user_dict_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)
    for word in user_dict_words:
        jieba.add_word(word)
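# Both dictionary files are assumed to be plain whitespace/newline-separated
# word lists (that is what f.read().split() above expects), one term per line.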
def tokenizer(sent, token_method=JIEBA_TOKEN):
    """
    Chinese word segmentation; supports both jieba and LTP.
    """
    if token_method == JIEBA_TOKEN:
        result = ' '.join(jieba.cut(sent))
    elif token_method == LTP_TOKEN:
        # LTP's pipeline takes a batch of sentences; 'cws' is word segmentation
        seg = ltp.pipeline([sent], tasks=['cws'])['cws']
        result = ' '.join(seg[0])
    else:
        raise ValueError(f"Unknown token_method: {token_method}")
    return result
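# A minimal usage sketch (actual token boundaries depend on the loaded models
# and dictionaries; the example strings here are hypothetical):
#   tokenizer("自然语言处理")             # e.g. -> "自然语言 处理"
#   tokenizer("自然语言处理", LTP_TOKEN)  # space-joined LTP segmentation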
def read_data(file_path):
    """
    Read the training data and tokenize it.
    """
    logger.info("Loading train data")
    train = pd.read_csv(file_path, sep='\t', encoding='utf-8')
    # Prepare the text; fillna via assignment (rather than inplace on a
    # column slice) avoids pandas' SettingWithCopyWarning
    subject_name = train["subject_name"].fillna("")
    sub_dis_name = train["sub_discipline_name"].fillna("")
    tags_name = train["tag_names"].fillna("")
    subject_text = subject_name + sub_dis_name + tags_name
    logger.info("Starting tokenize...")
    train['token_content'] = subject_text.progress_apply(tokenizer)
    return train
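# read_data returns the original frame plus a 'token_content' column of
# space-separated tokens; train_w2v below splits that column back into
# per-row word lists for gensim.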
def train_w2v(train, to_file):
    # Split each tokenized row back into a list of words
    sentences = [row.split() for row in train['token_content']]
    # Number of CPU cores
    cores = multiprocessing.cpu_count()
    w2v_model = Word2Vec(min_count=1,  # min_count=1 keeps rare domain terms from becoming OOV
                         window=5,
                         vector_size=word2vec_dim,
                         sample=6e-5,
                         alpha=0.03,
                         min_alpha=0.0007,
                         negative=15,
                         workers=cores // 2,
                         epochs=30)
    t = time()
    w2v_model.build_vocab(sentences)
    logger.info('Time to build vocab: {} mins'.format(round((time() - t) / 60, 2)))
    t = time()
    w2v_model.train(sentences,
                    total_examples=w2v_model.corpus_count,
                    epochs=30,
                    report_delay=1)
    logger.info('Time to train word2vec: {} mins'.format(round((time() - t) / 60, 2)))
    if not os.path.exists(os.path.dirname(to_file)):
        os.makedirs(os.path.dirname(to_file))
    w2v_model.save(to_file)
    logger.info('train word2vec finished.')
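# The saved model can later be reloaded for the recall step, e.g.:
#   model = Word2Vec.load(subject_faiss_w2v_path)
#   model.wv.most_similar("机器学习", topn=10)  # hypothetical query term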
if __name__ == "__main__":
    train = read_data(subjects_data_path)
    train_w2v(train, subject_faiss_w2v_path)