You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

230 lines
7.8 KiB

2 years ago
import os
import json
import torch
import os
import gensim
import torch
import json
import multiprocessing
import pandas as pd
import numpy as np
from time import time
from gensim.models import Word2Vec
from gensim.models.phrases import Phraser
from gensim.models.phrases import Phrases
from gensim.models import KeyedVectors
import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter
import warnings
# Silence every warning category process-wide. This also hides pandas
# chained-assignment warnings and library deprecation warnings.
# NOTE(review): consider narrowing this filter so real problems stay visible.
warnings.filterwarnings("ignore")
def _remove_stopwords(text, stopwords):
    """Drop stopword tokens from space-separated ``text``.

    Every kept token is followed by exactly one space, so a non-empty result
    ends with a trailing space. Downstream splitting on ' ' relies on this.
    """
    return ''.join(word + ' ' for word in text.split(' ') if word not in stopwords)


def extract_three_cls_data(data_path, save_path, map_path='map.json',
                           stopwords_path='/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt'):
    """Keep three target categories, build a label map and strip stopwords.

    Reads a tab-separated file with ``label`` and ``text`` columns (``text``
    is expected to be pre-tokenized and space-separated), drops rows with
    missing values, and keeps only the 童书 / 工业技术 / 大中专教材教辅
    categories. It assigns each label an integer id in the ``textcnn_label``
    column, writes the stopword-free text to ``text_seg``, and saves the
    result as CSV.

    Args:
        data_path: path to the raw tab-separated split.
        save_path: destination of the processed CSV (written without index).
        map_path: JSON file receiving the label -> id mapping; it is only
            written when it does not exist yet.
        stopwords_path: text file with one stopword per line.
    """
    target_labels = ['童书', '工业技术', '大中专教材教辅']
    data = pd.read_csv(data_path, sep='\t').dropna()
    # .copy() gives an independent frame, so the column assignments below do
    # not go through a view of ``data`` (no chained-assignment pitfalls).
    cls_data = data[data['label'].isin(target_labels)].copy()
    cls_data.index = range(len(cls_data))
    print(Counter(cls_data['label']))
    print('总共 {} 个类别'.format(len(np.unique(cls_data['label']))))
    # NOTE(review): the map is derived from this split's own labels. If a split
    # lacks one of the categories, its ids differ from train's; consider
    # loading an existing map instead of recomputing it.
    label_map = {key: index for index, key in enumerate(np.unique(cls_data['label']))}
    label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3)
    # Check the map *file*. The original tested the JSON text as a path, so
    # the guard never took effect.
    if not os.path.exists(map_path):
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(label_map_json)
    cls_data['textcnn_label'] = cls_data['label'].map(label_map)
    with open(stopwords_path, 'r', encoding='utf-8') as f:
        # A set gives O(1) membership tests in the per-token filter.
        stopwords = {line.strip() for line in f}
    # Build the whole column at once instead of per-row chained assignment.
    cls_data['text_seg'] = [
        _remove_stopwords(text, stopwords)
        for text in tqdm(cls_data['text'], desc='去除停用词:', total=len(cls_data))
    ]
    cls_data.to_csv(save_path, index=False)
def build_word2id(lists):
    """Map each distinct item of ``lists`` to a consecutive id in first-seen order."""
    # dict.fromkeys de-duplicates while preserving insertion order.
    unique_tokens = dict.fromkeys(lists)
    return {token: position for position, token in enumerate(unique_tokens)}
def build_data(train_data, word2id_map, max_length):
    """Turn the ``text_seg`` column into fixed-length lists of token ids.

    Each entry is split on single spaces, so a trailing space yields a final
    empty-string token that must exist in ``word2id_map``. Every sequence is
    truncated, or right-padded with the id of ``'PAD'``, to exactly
    ``max_length + 1`` ids. ``'PAD'`` is only looked up when padding is needed.

    Returns:
        (list of id lists, the ``textcnn_label`` column as given).
    """
    texts = train_data['text_seg']
    labels = train_data['textcnn_label']
    target_len = max_length + 1
    id_rows = []
    for text in texts:
        token_ids = [word2id_map[token] for token in text.split(' ')]
        token_ids = token_ids[:target_len]
        shortfall = target_len - len(token_ids)
        if shortfall > 0:
            token_ids.extend([word2id_map['PAD']] * shortfall)
        id_rows.append(token_ids)
    return id_rows, labels
def filter_stopwords(data_path, save_path):
    """Produce the cleaned split at ``save_path`` unless it already exists.

    The existing file acts as a cache; otherwise extract_three_cls_data is run
    on ``data_path``.
    """
    if os.path.exists(save_path):
        return
    extract_three_cls_data(data_path, save_path)
def concat_all_data(train_path, dev_path, test_path):
    """Read the three split CSVs and stack them row-wise with a fresh 0..n-1 index."""
    frames = [pd.read_csv(path) for path in (train_path, dev_path, test_path)]
    return pd.concat(frames, ignore_index=True)
def gen_word2id(train_path, dev_path, test_path):
    """Build, or load from cache, the token <-> id vocabularies for all splits.

    When ``word2id.json`` is absent from the working directory, the function
    does the following:

    * splits the ``text_seg`` column of the three CSVs on single spaces;
    * gives every distinct token an id in first-seen order;
    * appends ``'PAD'`` with the next free id;
    * writes ``word2id.json`` and ``id2word.json``.

    Otherwise both JSON files are loaded and the CSVs are not read at all.

    Returns:
        (word2id, id2word): ``word2id`` maps token -> int id. ``id2word`` maps
        int id -> token, with int keys on both the build and the cached path.
    """
    word2id_path = 'word2id.json'
    id2word_path = 'id2word.json'
    if os.path.exists(word2id_path):
        with open(word2id_path, 'r', encoding='utf-8') as f:
            word2id = json.load(f)
        with open(id2word_path, 'r', encoding='utf-8') as f:
            # JSON object keys are always strings; restore the int ids so the
            # cached result has the same shape as a freshly built one.
            id2word = {int(key): word for key, word in json.load(f).items()}
        return word2id, id2word

    # Only read the CSVs when the vocabulary actually has to be built.
    data = concat_all_data(train_path, dev_path, test_path)
    words_list = []
    for line in tqdm(data['text_seg'], desc='gen word2id'):
        words_list.extend(line.split(' '))
    word2id = build_word2id(words_list)
    word2id['PAD'] = len(word2id)
    id2word = {idx: word for word, idx in word2id.items()}
    with open(word2id_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(word2id, ensure_ascii=False, indent=2))
    with open(id2word_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(id2word, ensure_ascii=False, indent=2))
    return word2id, id2word
def process_data(data_path, word2id, max_length):
    """Load a cleaned split CSV and convert it to padded id sequences and labels.

    Delegates to build_data; returns its (id lists, labels) pair.
    """
    frame = pd.read_csv(data_path)
    return build_data(frame, word2id, max_length)
def prepare_data(max_length):
    """Run the preprocessing pipeline and return id-encoded train/dev/test sets.

    The pipeline runs in three steps:

    1. Each raw split is filtered and stripped of stopwords. The results are
       cached as train.csv / dev.csv / test.csv.
    2. A shared vocabulary is built or loaded.
    3. Every split is converted to sequences of ``max_length + 1`` ids.

    Returns:
        X_train, y_train, X_dev, y_dev, X_test, y_test
    """
    split_files = (
        ('/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv', 'train.csv'),
        ('/data/bigfiles/15bf9d03-0569-4c87-88d9-f6c3b8deaf2a.csv', 'dev.csv'),
        ('/data/bigfiles/6d0de77c-d0fd-4d15-84a0-99cfe53790b6.csv', 'test.csv'),
    )
    for raw_path, clean_path in split_files:
        filter_stopwords(raw_path, clean_path)
    clean_paths = [clean_path for _, clean_path in split_files]
    word2id, _ = gen_word2id(*clean_paths)
    outputs = []
    for clean_path in clean_paths:
        features, labels = process_data(clean_path, word2id, max_length)
        outputs.extend((features, labels))
    return tuple(outputs)
def load_w2v():
    """Train Word2Vec on the cleaned splits unless the vectors already exist.

    Despite its name, this function *trains*. When ``w2v_model.txt`` is
    missing, it:

    * reads ``text_seg`` from train.csv / dev.csv / test.csv;
    * applies a gensim ``Phrases``/``Phraser`` bigram transform;
    * trains a 300-dimensional Word2Vec model;
    * saves the full model (``w2v_model.bin``) and the text-format vectors
      (``w2v_model.txt``).

    Returns None.
    """
    model_save_path = 'w2v_model.bin'
    vec_save_path = 'w2v_model.txt'
    if os.path.exists(vec_save_path):
        # Vectors already exported: skip training and do not read the CSVs.
        return
    data = concat_all_data('train.csv', 'dev.csv', 'test.csv')
    # str() guards against non-string cells such as NaN in text_seg.
    sent = [str(row).split(' ') for row in data['text_seg']]
    phrases = Phrases(sent, min_count=5, progress_per=10000)
    bigram = Phraser(phrases)
    sentence = bigram[sent]
    # gensim 4 renamed the constructor args ``size`` -> ``vector_size`` and
    # ``iter`` -> ``epochs``; pick the spelling the installed version accepts.
    gensim_major = int(gensim.__version__.split('.')[0])
    if gensim_major >= 4:
        version_kwargs = {'vector_size': 300, 'epochs': 7}
    else:
        version_kwargs = {'size': 300, 'iter': 7}
    # cpu_count() - 1 is 0 on a single-core machine, which would start no
    # training threads; always use at least one worker.
    workers = max(1, multiprocessing.cpu_count() - 1)
    w2v_model = Word2Vec(
        min_count=2,
        window=2,
        sample=6e-5,
        alpha=0.03,
        min_alpha=0.0007,
        negative=15,
        workers=workers,
        **version_kwargs)
    t0 = time()
    w2v_model.build_vocab(sentence)
    t1 = time()
    print('build vocab cost time: {}s'.format(t1 - t0))
    # The explicit epochs=20 here is the number of passes actually trained;
    # the constructor's epoch value is only the model's stored default.
    w2v_model.train(
        sentence,
        total_examples=w2v_model.corpus_count,
        epochs=20,
        report_delay=1,
    )
    t2 = time()
    print('train w2v model cost time: {}s'.format(t2 - t1))
    w2v_model.save(model_save_path)
    w2v_model.wv.save_word2vec_format(vec_save_path, binary=False)
def get_pretrainde_w2v():
    """Build an embedding matrix aligned with ``word2id.json`` from saved vectors.

    Loads ``w2v_model.txt`` (text word2vec format) and ``word2id.json``. Row
    ``i`` of the result holds the vector of the token whose id is ``i``.
    Vocabulary entries without a Word2Vec vector stay all zeros.

    Returns:
        torch tensor of shape (len(word2id), vector dimension of the file).
    """
    w2v_model = KeyedVectors.load_word2vec_format('w2v_model.txt', binary=False)
    with open('word2id.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    # gensim >= 4 exposes the ordered vocabulary as ``index_to_key``;
    # gensim 3.x uses ``index2word``.
    vocab_words = getattr(w2v_model, 'index_to_key', None)
    if vocab_words is None:
        vocab_words = w2v_model.index2word
    # Size the matrix from the vectors themselves instead of assuming 300.
    weight = torch.zeros(len(word2id), w2v_model.vector_size)
    for word in vocab_words:
        index = word2id.get(word)
        if index is None:
            # Word2Vec token (e.g. a Phrases bigram) not in the vocabulary.
            continue
        weight[index, :] = torch.from_numpy(w2v_model.get_vector(word))
    return weight
class ModelConfig():
    """Hyperparameters plus vocabulary/embedding settings for the TextCNN model.

    NOTE: the class body executes when the class is defined, i.e. at import
    time. It reads ``word2id.json`` from the working directory and, if
    ``use_pretrained_w2v`` is True, may train Word2Vec vectors and load them.
    """
    # Use the GPU when available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Placeholder; reassigned to len(word2id) at the end of the class body.
    vocab_size = 0
    # Matches the three categories kept by extract_three_cls_data.
    num_classes = 3
    # presumably passed to prepare_data/build_data (sequences then hold
    # max_length + 1 ids) — TODO confirm against the caller.
    max_length = 64
    # When True: train vectors if 'w2v_model.txt' is missing, then load them as
    # an embedding matrix aligned with word2id.json.
    use_pretrained_w2v = False
    if use_pretrained_w2v:
        if not os.path.exists('w2v_model.txt'):
            load_w2v()
        embedding_pretrained = get_pretrainde_w2v()
    else:
        embedding_pretrained = None
    # Embedding width: taken from the pretrained matrix if present, else 300.
    embedding_size = embedding_pretrained.size(1) if embedding_pretrained is not None else 300
    # Convolution settings (the 'kenel' spelling is kept because it is the
    # attribute name). Presumably the number of filters per kernel size and
    # the kernel widths — TODO confirm in the model definition.
    kenel_num = 256
    kenel_size = [2,3,4]
    # Training hyperparameters; presumably read by the training loop — TODO confirm.
    lr = 0.001
    epochs = 10
    batch_size = 128
    dropout = 0.5
    # Vocabulary written by gen_word2id; opening raises FileNotFoundError if it
    # has not been generated yet. NOTE: ``f`` also remains as a (closed) class
    # attribute because this ``with`` runs in the class namespace.
    with open('word2id.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    vocab_size = len(word2id)