You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

59 lines
2.0 KiB

5 months ago
from tqdm import tqdm
import os
import sys
sys.path.append(os.getcwd())
from config import shixuns_data_path,shixuns_bert_em_path
import pandas as pd
import logging
from transformers import AutoTokenizer, TFAutoModel
from config import bert_base_chinese
from utils import finalcut
# Emit INFO-level (and above) log records prefixed with a timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s : %(levelname)s : %(message)s',
)
# Register tqdm's progress-bar helpers on pandas objects.
tqdm.pandas()
# Load the large-scale pretrained BERT model and its tokenizer from the
# name/path configured as `bert_base_chinese`.
bert_model = bert_base_chinese
model = TFAutoModel.from_pretrained(
    bert_model,
    output_hidden_states=True,  # ask the model to also return every layer's hidden states
)
tokenizer = AutoTokenizer.from_pretrained(bert_model)
# Read the tab-separated course table and discard the columns that are not
# carried through to the embedding output file.
dropped_columns = [
    'updated_at', 'status', 'publish_time', 'modify_time',
    'reset_time', 'trainee', 'myshixuns_count', 'disciplines_id',
    'disciplines_name', 'subject_id', 'created_at_ts',
]
shixun = pd.read_csv(shixuns_data_path, sep='\t', encoding='utf-8').drop(columns=dropped_columns)
# Add placeholder columns "bert_em0" ... "bert_em767" (one per embedding
# dimension), all initialised to 0.0. 768 assumes a BERT-base-sized encoder.
# All columns are added in a single concat. Inserting them one by one fragments
# the DataFrame, which slows it down and triggers pandas' PerformanceWarning.
# Any same-named columns already present are replaced, as the original
# per-column assignment overwrote them.
bert_em_columns = ["bert_em" + str(i) for i in range(768)]
shixun = pd.concat(
    [
        shixun.drop(columns=bert_em_columns, errors="ignore"),
        pd.DataFrame(0., index=shixun.index, columns=bert_em_columns),
    ],
    axis=1,
)
# Prepare the data: build one text per row by concatenating the course name,
# language and subject name, treating missing values as empty strings.
# fillna() returns new Series here. In-place fillna on a column pulled out of
# `shixun` is chained mutation: depending on the pandas version and
# Copy-on-Write, it either leaks into `shixun` or silently does not, and under
# CoW it warns. Missing cells are written as empty fields by to_csv either way,
# so the output file is unaffected.
shixun_name = shixun["shixun_name"].fillna("")
language = shixun["language"].fillna("")
subject_name = shixun["subject_name"].fillna("")
shixun_text = shixun_name + language + subject_name
# Remove invalid characters from each sentence (finalcut is a project helper).
# Iterating the Series keeps row order without assuming a 0..N-1 index.
words = [finalcut(text) for text in tqdm(shixun_text)]
def getbert_vec(word_list):
    """Return the BERT sentence vector for one cleaned text as a list of floats.

    The input is tokenized into TensorFlow tensors, padded/truncated to exactly
    64 tokens, and run through the module-level ``model``. The second model
    output (for BERT models, the pooled [CLS] vector) is taken, and the first
    row of the batch is returned as a Python list.

    NOTE(review): if ``word_list`` is a list of tokens rather than a single
    string, the tokenizer treats it as a batch of separate inputs and only the
    first token's vector is returned. Confirm what ``finalcut`` produces.
    """
    encoded = tokenizer(
        word_list,
        return_tensors="tf",
        padding="max_length",
        truncation=True,
        max_length=64,
    )
    sentence_vectors = model(encoded)[1]
    first_row = sentence_vectors.numpy()[0]
    return list(first_row)
# Compute one BERT sentence embedding per cleaned text.
words_list = [getbert_vec(sentence) for sentence in tqdm(words)]
# Write all embedding dimensions in one vectorised step. The previous approach
# issued a separate shixun.loc[i, column] assignment per (row, dimension),
# i.e. N * 768 pandas indexing calls.
# Rows are aligned positionally via index=shixun.index, so this no longer
# assumes a 0..N-1 index. float64 matches the placeholder columns' dtype.
# add_prefix turns the positional column labels 0, 1, ... into "bert_em0",
# "bert_em1", ...
embeddings = pd.DataFrame(words_list, index=shixun.index, dtype="float64").add_prefix("bert_em")
# Replace any same-named placeholder columns with the computed values. Those
# columns sit at the end of the frame, so the column order is unchanged.
shixun = pd.concat(
    [shixun.drop(columns=embeddings.columns, errors="ignore"), embeddings],
    axis=1,
)
shixun.to_csv(shixuns_bert_em_path, sep='\t', index=False, header=True)