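"""Data preparation utilities for a TextCNN book-category classifier.

Pipeline: extract three categories (童书 / 工业技术 / 大中专教材教辅) from the raw
tab-separated data, remove stopwords, build a word-to-id vocabulary shared by the
train/dev/test splits, optionally train a gensim word2vec model for pretrained
embeddings, and collect hyperparameters in ModelConfig.
"""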
import os
import json
import multiprocessing
import warnings
from collections import Counter
from time import time

import numpy as np
import pandas as pd
import torch
from gensim.models import KeyedVectors, Word2Vec
from gensim.models.phrases import Phraser, Phrases
from tqdm import tqdm

warnings.filterwarnings("ignore")


def extract_three_cls_data(data_path, save_path):
    # Keep only the three target categories, map labels to integer ids,
    # and strip stopwords from the pre-tokenized text.
    map_path = 'map.json'
    data = pd.read_csv(data_path, sep='\t').dropna()
    cls_data = data[(data['label'] == '童书') | (data['label'] == '工业技术') | (data['label'] == '大中专教材教辅')]
    cls_data.index = range(len(cls_data))
    print(Counter(cls_data['label']))
    print('{} classes in total'.format(len(np.unique(cls_data['label']))))
    label_map = {key: index for index, key in enumerate(np.unique(cls_data['label']))}
    label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3)
    # Write the label map to disk once.
    if not os.path.exists(map_path):
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(label_map_json)
    cls_data['textcnn_label'] = cls_data['label'].map(label_map)
    with open('/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt', 'r', encoding='utf-8') as f:
        stopwords = f.readlines()
    stopwords = set(i.strip() for i in stopwords)  # set for fast membership checks
    cls_data['text_seg'] = ''
    for idx, row in tqdm(cls_data.iterrows(), desc='Removing stopwords', total=len(cls_data)):
        words = row['text'].split(' ')
        out_str = ''
        for word in words:
            if word not in stopwords:
                out_str += word
                out_str += ' '
        cls_data.loc[idx, 'text_seg'] = out_str
    cls_data.to_csv(save_path, index=False)


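# Input format note: extract_three_cls_data expects a tab-separated file with at
# least a 'label' column and a 'text' column whose contents are already
# whitespace-tokenized (the text is split on single spaces above).

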
def build_word2id(lists):
    maps = {}
    for item in lists:
        if item not in maps:
            maps[item] = len(maps)
    return maps


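# Illustrative example: ids are assigned in first-seen order, e.g.
# build_word2id(['童书', '教材', '童书']) returns {'童书': 0, '教材': 1}.

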
def build_data(train_data, word2id_map, max_length):
    # Convert each space-separated text into a fixed-length list of word ids.
    data = train_data['text_seg']
    train_list = []
    label_list = train_data['textcnn_label']
    for line in data:
        train_word_list = line.split(' ')
        train_line_id = []
        for word in train_word_list:
            word_id = word2id_map[word]
            train_line_id.append(word_id)
        length = len(train_line_id)
        # Truncate or pad every sequence to exactly max_length + 1 ids.
        if length > max_length + 1:
            train_line_id = train_line_id[:max_length + 1]
        if length < max_length + 1:
            train_line_id.extend([word2id_map['PAD']] * (max_length - length + 1))
        train_list.append(train_line_id)
    return train_list, label_list


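# Illustrative example: with max_length = 3, the id sequence [5, 9] becomes
# [5, 9, PAD, PAD] and [1, 2, 3, 4, 5, 6] is truncated to [1, 2, 3, 4];
# every row returned by build_data has exactly max_length + 1 ids.

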
def filter_stopwords(data_path, save_path):
    if not os.path.exists(save_path):
        extract_three_cls_data(data_path, save_path)


def concat_all_data(train_path, dev_path, test_path):
    data_train = pd.read_csv(train_path)
    data_dev = pd.read_csv(dev_path)
    data_test = pd.read_csv(test_path)
    data = pd.concat([data_train, data_dev, data_test])
    data.index = range(len(data))
    return data


def gen_word2id(train_path, dev_path, test_path):
    data = concat_all_data(train_path, dev_path, test_path)
    word2id_path = 'word2id.json'
    id2word_path = 'id2word.json'
    if not os.path.exists(word2id_path):
        # Build the vocabulary from all splits and cache it on disk.
        data_lines = data['text_seg']
        words_list = []
        for line in tqdm(data_lines, desc='gen word2id'):
            words = line.split(' ')
            words_list.extend(words)
        word2id = build_word2id(words_list)
        word2id['PAD'] = len(word2id)
        id2word = {word2id[w]: w for w in word2id}
        with open(word2id_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(word2id, ensure_ascii=False, indent=2))
        with open(id2word_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(id2word, ensure_ascii=False, indent=2))
    else:
        # Reuse the cached vocabulary files. Note that JSON object keys are strings,
        # so id2word loaded from disk is keyed by str ids rather than ints.
        with open(word2id_path, 'r', encoding='utf-8') as f:
            word2id = json.load(f)
        with open(id2word_path, 'r', encoding='utf-8') as f:
            id2word = json.load(f)
    return word2id, id2word


def process_data(data_path, word2id, max_length):
    data = pd.read_csv(data_path)
    train_list, label_list = build_data(data, word2id, max_length)
    return train_list, label_list


def prepare_data(max_length):
    # Filter each raw split down to the three target classes and remove stopwords.
    train_data_path = '/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv'
    train_save_path = 'train.csv'
    filter_stopwords(train_data_path, train_save_path)
    dev_data_path = '/data/bigfiles/15bf9d03-0569-4c87-88d9-f6c3b8deaf2a.csv'
    dev_save_path = 'dev.csv'
    filter_stopwords(dev_data_path, dev_save_path)
    test_data_path = '/data/bigfiles/6d0de77c-d0fd-4d15-84a0-99cfe53790b6.csv'
    test_save_path = 'test.csv'
    filter_stopwords(test_data_path, test_save_path)
    # Build (or load) the shared vocabulary, then convert each split to id sequences.
    word2id, id2word = gen_word2id(train_save_path, dev_save_path, test_save_path)
    X_train, y_train = process_data(train_save_path, word2id, max_length)
    X_dev, y_dev = process_data(dev_save_path, word2id, max_length)
    X_test, y_test = process_data(test_save_path, word2id, max_length)
    return X_train, y_train, X_dev, y_dev, X_test, y_test


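# Usage sketch (illustrative; the raw CSV paths above must exist):
#   X_train, y_train, X_dev, y_dev, X_test, y_test = prepare_data(max_length=64)
#   # each X_* row is a list of max_length + 1 word ids

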
def load_w2v():
    train_save_path = 'train.csv'
    dev_save_path = 'dev.csv'
    test_save_path = 'test.csv'
    data = concat_all_data(train_save_path, dev_save_path, test_save_path)
    model_save_path = 'w2v_model.bin'
    vec_save_path = 'w2v_model.txt'
    if not os.path.exists(vec_save_path):
        sent = [str(row).split(' ') for row in data['text_seg']]
        # Detect frequent bigrams so common two-word phrases are treated as one token.
        phrases = Phrases(sent, min_count=5, progress_per=10000)
        bigram = Phraser(phrases)
        sentence = bigram[sent]
        cores = multiprocessing.cpu_count()
        # Note: 'size' and 'iter' are gensim 3.x parameter names; gensim 4.x renamed
        # them to 'vector_size' and 'epochs' (and 'index2word' to 'index_to_key').
        w2v_model = Word2Vec(
            min_count=2,
            window=2,
            size=300,
            sample=6e-5,
            alpha=0.03,
            min_alpha=0.0007,
            negative=15,
            workers=cores - 1,
            iter=7)
        t0 = time()
        w2v_model.build_vocab(sentence)
        t1 = time()
        print('build vocab cost time: {}s'.format(t1 - t0))
        w2v_model.train(
            sentence,
            total_examples=w2v_model.corpus_count,
            epochs=20,
            report_delay=1
        )
        t2 = time()
        print('train w2v model cost time: {}s'.format(t2 - t1))
        w2v_model.save(model_save_path)
        # Also export plain-text vectors for later loading with KeyedVectors.
        w2v_model.wv.save_word2vec_format(vec_save_path, binary=False)


def get_pretrainde_w2v():
    w2v_path = 'w2v_model.txt'
    w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)

    word2id_path = 'word2id.json'
    id2_word_path = 'id2word.json'
    with open(word2id_path, 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    with open(id2_word_path, 'r', encoding='utf-8') as f:
        id2word = json.load(f)

    # Copy each pretrained vector into an embedding matrix indexed by word2id;
    # rows for words missing from the word2vec vocabulary stay zero.
    vocab_size = len(word2id)
    embedding_size = 300
    weight = torch.zeros(vocab_size, embedding_size)
    for i in range(len(w2v_model.index2word)):
        try:
            index = word2id[w2v_model.index2word[i]]
        except KeyError:
            continue
        weight[index, :] = torch.from_numpy(w2v_model.get_vector(
            id2word[str(word2id[w2v_model.index2word[i]])]))
    # print(weight)
    return weight


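# Usage sketch (illustrative): the returned matrix can initialise an embedding layer,
# e.g. torch.nn.Embedding.from_pretrained(get_pretrainde_w2v(), freeze=False).

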
class ModelConfig():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vocab_size = 0
    num_classes = 3
    max_length = 64
    use_pretrained_w2v = False
    if use_pretrained_w2v:
        # Train (or reuse) the word2vec vectors and load them as the embedding matrix.
        if not os.path.exists('w2v_model.txt'):
            load_w2v()
        embedding_pretrained = get_pretrainde_w2v()
    else:
        embedding_pretrained = None

    embedding_size = embedding_pretrained.size(1) if embedding_pretrained is not None else 300

    kenel_num = 256
    kenel_size = [2, 3, 4]
    lr = 0.001
    epochs = 10
    batch_size = 128
    dropout = 0.5
    # word2id.json is produced by gen_word2id(); it must exist before this class is defined.
    with open('word2id.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    vocab_size = len(word2id)