import os
import sys

sys.path.append(os.getcwd())

import multiprocessing
from time import time

import pandas as pd
import torch
import jieba
from gensim.models import Word2Vec
from ltp import LTP
from tqdm import tqdm

from config import shixuns_keywords_path, ltp_model_path
from config import JIEBA_TOKEN, LTP_TOKEN, user_dict_path, logger
from config import shixuns_data_path, shixun_faiss_w2v_path
from config import word2vec_dim

# Train the word2vec word vectors used for recall; each shixun record uses the
# fields "shixun_name", "language" and "subject_name".

# Enable progress bars for pandas progress_apply
tqdm.pandas()

# Load the LTP segmentation model and move it to the GPU when one is available
ltp = LTP(ltp_model_path)
if torch.cuda.is_available():
    ltp.to("cuda")

# Load the user-defined dictionaries
if os.path.exists(shixuns_keywords_path):
    jieba.load_userdict(shixuns_keywords_path)
    with open(shixuns_keywords_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)

if os.path.exists(user_dict_path):
    with open(user_dict_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)
    for word in user_dict_words:
        jieba.add_word(word)
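
# Note: both dictionary files above are read with f.read().split(), so every
# whitespace-separated token in them is treated as one custom word for both jieba
# and LTP. (Assumption: the files contain plain keywords, not jieba's optional
# "word freq tag" columns.)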


def tokenizer(sent, token_method=JIEBA_TOKEN):
    """
    Chinese word segmentation; supports both jieba and LTP.
    Returns the tokens joined by single spaces.
    """
    if token_method == JIEBA_TOKEN:
        result = ' '.join(jieba.cut(sent))
    elif token_method == LTP_TOKEN:
        seg = ltp.pipeline([sent], tasks=['cws'])['cws']
        result = ' '.join(seg[0])
    else:
        raise ValueError('Unknown token_method: {}'.format(token_method))
    return result
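
# Minimal usage sketch (hypothetical input and output; the actual segmentation
# depends on the loaded dictionaries and models):
#   tokenizer("Python数据分析实训")                          # e.g. "Python 数据分析 实训"
#   tokenizer("Python数据分析实训", token_method=LTP_TOKEN)  # same text, segmented by LTP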


def read_data(file_path):
    """
    Read the raw data and tokenize it.
    """
    logger.info("Loading train data")
    train = pd.read_csv(file_path, sep='\t', encoding='utf-8')

    logger.info("Starting tokenize...")

    # Fields used as the text of a shixun
    shixun_name = train["shixun_name"]
    language = train["language"]
    subject_name = train["subject_name"]

    # Fill missing values: if any of the three fields is NaN,
    # the concatenated text would otherwise be NaN as well
    shixun_name = shixun_name.fillna("")
    language = language.fillna("")
    subject_name = subject_name.fillna("")

    shixun_text = shixun_name + language + subject_name

    train['token_content'] = shixun_text.progress_apply(tokenizer)
    return train
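
# Example of what read_data produces (hypothetical row; the actual segmentation may
# differ): shixun_name="数据分析", language="Python", subject_name="大数据" are
# concatenated without separators into "数据分析Python大数据", and token_content
# becomes something like "数据分析 Python 大数据".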


def train_w2v(train, to_file):
    """
    Train a word2vec model on the tokenized text and save it to to_file.
    """
    # All sentences, each one a list of tokens
    sentences = [row.split() for row in train['token_content']]

    # Number of CPU cores
    cores = multiprocessing.cpu_count()

    w2v_model = Word2Vec(min_count=1,   # min_count=1 keeps rare domain terms from becoming OOV
                         window=5,
                         vector_size=word2vec_dim,
                         sample=6e-5,
                         alpha=0.03,
                         min_alpha=0.0007,
                         negative=15,
                         workers=cores // 2,
                         epochs=20,
                         hs=1)

    t = time()
    w2v_model.build_vocab(sentences)
    logger.info('Time to build vocab: {} mins'.format(round((time() - t) / 60, 2)))

    t = time()
    # Note: the epochs passed here override the value given to the constructor
    w2v_model.train(sentences,
                    total_examples=w2v_model.corpus_count,
                    epochs=30,
                    report_delay=1)
    logger.info('Time to train word2vec: {} mins'.format(round((time() - t) / 60, 2)))

    if not os.path.exists(os.path.dirname(to_file)):
        os.makedirs(os.path.dirname(to_file))

    w2v_model.save(to_file)

    logger.info('train word2vec finished.')


if __name__ == "__main__":
    train = read_data(shixuns_data_path)
    train_w2v(train, shixun_faiss_w2v_path)
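
# A minimal sketch (not executed here) of how the saved vectors could be loaded
# later, e.g. when building the faiss recall index, using gensim's standard API:
#   from gensim.models import Word2Vec
#   w2v_model = Word2Vec.load(shixun_faiss_w2v_path)
#   vector = w2v_model.wv["Python"]   # per-token vector; KeyError for OOV tokens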