import os
import sys
sys.path.append(os.getcwd())
import multiprocessing
from time import time
import pandas as pd
from gensim.models import Word2Vec
from config import shixuns_keywords_path, ltp_model_path
from tqdm import tqdm
import jieba
from ltp import LTP
import torch
from config import JIEBA_TOKEN, LTP_TOKEN, user_dict_path, logger
from config import shixuns_data_path, shixun_faiss_w2v_path
from config import word2vec_dim
# Train the word2vec embeddings used for recall; each shixun record combines the fields "shixun_name", "language", "subject_name"
tqdm.pandas()
ltp = LTP(ltp_model_path)
if torch.cuda.is_available():
    ltp.to("cuda")
# Load the user-defined dictionaries so domain terms are kept as single tokens
if os.path.exists(shixuns_keywords_path):
    jieba.load_userdict(shixuns_keywords_path)
    with open(shixuns_keywords_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)

if os.path.exists(user_dict_path):
    with open(user_dict_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
    ltp.add_words(user_dict_words)
    for word in user_dict_words:
        jieba.add_word(word)

def tokenizer(sent, token_method=JIEBA_TOKEN):
    """
    Chinese word segmentation; supports both jieba and LTP
    """
    if token_method == JIEBA_TOKEN:
        seg = jieba.cut(sent)
        result = ' '.join(seg)
    elif token_method == LTP_TOKEN:
        # LTP expects a list of sentences; take the segmentation of the first (and only) one
        seg = ltp.pipeline([sent], tasks=['cws'])['cws']
        result = ' '.join(seg[0])
    return result
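
# Illustrative usage (hypothetical input; the exact segmentation depends on the loaded dictionaries):
#   tokenizer('Python机器学习实训', token_method=JIEBA_TOKEN)  ->  e.g. 'Python 机器 学习 实训'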

def read_data(file_path):
    """
    Read the training data and tokenize it
    """
    logger.info("Loading train data")
    train = pd.read_csv(file_path, sep='\t', encoding='utf-8')

    logger.info("Starting tokenize...")
    # Prepare the text fields
    shixun_name = train["shixun_name"]
    language = train["language"]
    subject_name = train["subject_name"]

    # Fill missing values; otherwise a NaN in any single field makes the whole concatenation NaN
    shixun_name = shixun_name.fillna(value="")
    language = language.fillna(value="")
    subject_name = subject_name.fillna(value="")

    shixun_text = shixun_name + language + subject_name
    train['token_content'] = shixun_text.progress_apply(tokenizer)
    return train
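
# Note: the TSV behind shixuns_data_path is assumed (not checked here) to contain at least the
# columns shixun_name, language and subject_name; an illustrative row could look like
#   "Python入门实训\tPython\t程序设计基础"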

def train_w2v(train, to_file):
    """
    Train the word2vec model on the tokenized text and save it
    """
    # All tokenized sentences, split back into lists of tokens
    sentences = [row.split() for row in train['token_content']]

    # Number of CPU cores
    cores = multiprocessing.cpu_count()

    w2v_model = Word2Vec(min_count=1,  # min_count=1 keeps rare domain terms from becoming OOV
                         window=5,
                         vector_size=word2vec_dim,
                         sample=6e-5,
                         alpha=0.03,
                         min_alpha=0.0007,
                         negative=15,
                         workers=cores // 2,
                         epochs=20,
                         hs=1)

    t = time()
    w2v_model.build_vocab(sentences)
    logger.info('Time to build vocab: {} mins'.format(round((time() - t) / 60, 2)))

    t = time()
    w2v_model.train(sentences,
                    total_examples=w2v_model.corpus_count,
                    epochs=30,
                    report_delay=1)
    logger.info('Time to train word2vec: {} mins'.format(round((time() - t) / 60, 2)))

    if not os.path.exists(os.path.dirname(to_file)):
        os.makedirs(os.path.dirname(to_file))
    w2v_model.save(to_file)
    logger.info('train word2vec finished.')

if __name__ == "__main__":
    train = read_data(shixuns_data_path)
    train_w2v(train, shixun_faiss_w2v_path)
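
# Minimal follow-up sketch (assumption: run after training has finished and the queried token
# appears in the vocabulary) showing how the saved model could be loaded and queried when
# building the Faiss recall index:
#
#   w2v_model = Word2Vec.load(shixun_faiss_w2v_path)
#   vector = w2v_model.wv['Python']                        # embedding for one token
#   similar = w2v_model.wv.most_similar('Python', topn=5)  # nearest neighbours in the vocab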