import os
import sys

sys.path.append(os.getcwd())

import multiprocessing
from time import time

import jieba
import pandas as pd
import torch
from gensim.models import Word2Vec
from ltp import LTP
from tqdm import tqdm

from config import (JIEBA_TOKEN, LTP_TOKEN, logger, ltp_model_path,
                    subject_faiss_w2v_path, subjects_data_path,
                    subjects_keywords_path, user_dict_path, word2vec_dim)
# Train the word2vec embeddings used for recall. The practice-course data
# contains the fields "subject_name", "sub_discipline_name" and "tag_names".
tqdm.pandas()

# Initialise the LTP segmenter and move it to the GPU when one is available
ltp = LTP(ltp_model_path)
if torch.cuda.is_available():
    ltp.to("cuda")

# Load the user-defined dictionaries so that domain terms are kept as single
# tokens by both jieba and LTP
if os.path.exists(subjects_keywords_path):
    jieba.load_userdict(subjects_keywords_path)
    with open(subjects_keywords_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)

if os.path.exists(user_dict_path):
    with open(user_dict_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)
    for word in user_dict_words:
        jieba.add_word(word)
def tokenizer(sent, token_method=JIEBA_TOKEN):
    """
    Chinese word segmentation, supporting both jieba and LTP.

    Returns the segmented sentence as a single space-separated string.
    """
    if token_method == JIEBA_TOKEN:
        result = ' '.join(jieba.cut(sent))
    elif token_method == LTP_TOKEN:
        seg = ltp.pipeline([sent], tasks=['cws'])['cws']
        result = ' '.join(seg[0])
    else:
        raise ValueError('Unknown token_method: {}'.format(token_method))
    return result
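# Illustrative example only (exact segmentation depends on jieba's built-in
# dictionary and on the user dictionaries loaded above):
#   tokenizer("机器学习实践课程")             -> e.g. "机器学习 实践 课程"
#   tokenizer("机器学习实践课程", LTP_TOKEN)  -> space-separated tokens from LTP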
def read_data(file_path):
    """
    Load the training data and tokenize the text fields.
    """
    logger.info("Loading train data")
    train = pd.read_csv(file_path, sep='\t', encoding='utf-8')
    # Concatenate subject name, sub-discipline name and tag names into one
    # text field, treating missing values as empty strings
    subject_name = train["subject_name"].fillna("")
    sub_dis_name = train["sub_discipline_name"].fillna("")
    tags_name = train["tag_names"].fillna("")
    subject_text = subject_name + sub_dis_name + tags_name
    logger.info("Starting tokenize...")
    train['token_content'] = subject_text.progress_apply(tokenizer)
    return train
def train_w2v(train, to_file):
    """
    Train a word2vec model on the tokenized subject text and save it to disk.
    """
    # All tokenized sentences, each as a list of tokens
    sentences = [row.split() for row in train['token_content']]
    # Number of available CPU cores
    cores = multiprocessing.cpu_count()
    w2v_model = Word2Vec(min_count=1,  # min_count=1 so that rare domain terms do not become OOV
                         window=5,
                         vector_size=word2vec_dim,
                         sample=6e-5,
                         alpha=0.03,
                         min_alpha=0.0007,
                         negative=15,
                         workers=cores // 2,
                         epochs=30)
    t = time()
    w2v_model.build_vocab(sentences)
    logger.info('Time to build vocab: {} mins'.format(round((time() - t) / 60, 2)))
    t = time()
    w2v_model.train(sentences,
                    total_examples=w2v_model.corpus_count,
                    epochs=30,
                    report_delay=1)
    logger.info('Time to train word2vec: {} mins'.format(round((time() - t) / 60, 2)))
    if not os.path.exists(os.path.dirname(to_file)):
        os.makedirs(os.path.dirname(to_file))
    w2v_model.save(to_file)
    logger.info('train word2vec finished.')
if __name__ == "__main__":
    train = read_data(subjects_data_path)
    train_w2v(train, subject_faiss_w2v_path)
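    # Optional sanity check (a minimal sketch, not part of the original
    # pipeline): reload the saved model and log its vocabulary size.
    # Word2Vec.load and model.wv are standard gensim APIs.
    w2v_model = Word2Vec.load(subject_faiss_w2v_path)
    logger.info('word2vec vocabulary size: {}'.format(len(w2v_model.wv)))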