import os
import sys
sys.path.append(os.getcwd())
import multiprocessing
from time import time
import pandas as pd
from gensim.models import Word2Vec
from config import shixuns_keywords_path, ltp_model_path
from tqdm import tqdm
import jieba
from ltp import LTP
import torch
from config import JIEBA_TOKEN, LTP_TOKEN, user_dict_path, logger
from config import shixuns_data_path, shixun_faiss_w2v_path
from config import word2vec_dim
# Train the word2vec embeddings used for recall; each shixun is represented by the fields "shixun_name", "language" and "subject_name"
tqdm.pandas()
ltp = LTP(ltp_model_path)
if torch.cuda.is_available():
ltp.to("cuda")
# Load the user-defined dictionaries into both jieba and LTP
if os.path.exists(shixuns_keywords_path):
    jieba.load_userdict(shixuns_keywords_path)
    with open(shixuns_keywords_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
        ltp.add_words(user_dict_words)
if os.path.exists(user_dict_path):
    with open(user_dict_path, 'r', encoding='utf-8') as f:
        user_dict_words = f.read().split()
        ltp.add_words(user_dict_words)
        for word in user_dict_words:
            jieba.add_word(word)

def tokenizer(sent, token_method=JIEBA_TOKEN):
    """
    Chinese word segmentation; both jieba and LTP are supported.
    """
    if token_method == JIEBA_TOKEN:
        seg = jieba.cut(sent)
        result = ' '.join(seg)
    elif token_method == LTP_TOKEN:
        # LTP expects a batch of sentences and returns one word list per sentence
        seg = ltp.pipeline([sent], tasks=['cws'])['cws']
        result = ' '.join(seg[0])
    else:
        raise ValueError('Unknown token_method: {}'.format(token_method))
    return result
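
# Illustrative usage of tokenizer (the input string below is made up, and the exact
# segmentation depends on jieba/LTP and on the user dictionaries loaded above):
#   tokenizer('Python程序设计实训')                          # e.g. 'Python 程序设计 实训'
#   tokenizer('Python程序设计实训', token_method=LTP_TOKEN)  # same text segmented by LTP's 'cws' task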
def read_data(file_path):
    """
    Read the data and tokenize it.

    The input is a tab-separated file containing at least the columns
    "shixun_name", "language" and "subject_name".
    """
    logger.info("Loading train data")
    train = pd.read_csv(file_path, sep='\t', encoding='utf-8')
    logger.info("Starting tokenize...")
    # Prepare the text fields.
    # Fill NaNs first: if any of the three fields were NaN, the concatenation
    # of the three texts would also be NaN.
    shixun_name = train["shixun_name"].fillna("")
    language = train["language"].fillna("")
    subject_name = train["subject_name"].fillna("")
    shixun_text = shixun_name + language + subject_name
    train['token_content'] = shixun_text.progress_apply(tokenizer)
    return train

def train_w2v(train, to_file):
    """
    Train a word2vec model on the tokenized text and save it to to_file.
    """
    # All tokenized sentences, one word list per shixun
    sentences = [row.split() for row in train['token_content']]
    # Number of CPU cores
    cores = multiprocessing.cpu_count()
    w2v_model = Word2Vec(min_count=1,  # min_count=1 keeps rare domain terms from becoming OOV
                         window=5,
                         vector_size=word2vec_dim,
                         sample=6e-5,
                         alpha=0.03,
                         min_alpha=0.0007,
                         negative=15,
                         workers=cores // 2,
                         epochs=20,
                         hs=1)
    t = time()
    w2v_model.build_vocab(sentences)
    logger.info('Time to build vocab: {} mins'.format(round((time() - t) / 60, 2)))
    t = time()
    w2v_model.train(sentences,
                    total_examples=w2v_model.corpus_count,
                    epochs=30,
                    report_delay=1)
    logger.info('Time to train word2vec: {} mins'.format(round((time() - t) / 60, 2)))
    if not os.path.exists(os.path.dirname(to_file)):
        os.mkdir(os.path.dirname(to_file))
    w2v_model.save(to_file)
    logger.info('train word2vec finished.')

if __name__ == "__main__":
train = read_data(shixuns_data_path)
train_w2v(train, shixun_faiss_w2v_path)
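
# Sketch of how the saved vectors can be consumed downstream (for example when building
# the faiss recall index). The query word 'Python' is only an illustration; any token
# that appears in the training corpus works the same way:
#   from gensim.models import Word2Vec
#   w2v_model = Word2Vec.load(shixun_faiss_w2v_path)
#   vector = w2v_model.wv['Python']                          # word2vec_dim-dimensional vector
#   similar = w2v_model.wv.most_similar('Python', topn=10)   # nearest neighbours in embedding space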