import json
import multiprocessing
import os
import warnings
from collections import Counter
from time import time

import jieba
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from gensim.models import KeyedVectors
from gensim.models import Word2Vec
from gensim.models.phrases import Phraser
from gensim.models.phrases import Phrases
from tqdm import tqdm

warnings.filterwarnings("ignore")


class TextCNN(nn.Module):
    def __init__(self, config):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        if config.use_pretrained_w2v:
            # Initialize the embedding table from pretrained word2vec vectors
            # and keep it trainable (fine-tuning).
            self.embedding.weight.data.copy_(config.embedding_pretrained)
            self.embedding.weight.requires_grad = True

        # One Conv2d per kernel height; each filter spans the full embedding width.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.kenel_num, (k, config.embedding_size)) for k in config.kenel_size])

        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.kenel_num * len(config.kenel_size), config.num_classes)

    def forward(self, x):
        x = self.embedding(x)                                             # (batch, seq_len, embedding_size)
        x = x.unsqueeze(1)                                                # (batch, 1, seq_len, embedding_size)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]          # each: (batch, kenel_num, seq_len - k + 1)
        x = [F.max_pool1d(line, line.size(2)).squeeze(2) for line in x]  # each: (batch, kenel_num)
        x = torch.cat(x, 1)                                              # (batch, kenel_num * len(kenel_size))
        x = self.dropout(x)
        out = self.fc(x)
        out = F.log_softmax(out, dim=1)
        return out


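# Illustrative sketch only (not part of the original training code): it builds a
# throwaway config with SimpleNamespace and a random batch to show the tensor
# shapes TextCNN expects and produces. All values below are assumptions chosen
# for demonstration.
def _demo_textcnn_shapes():
    from types import SimpleNamespace
    cfg = SimpleNamespace(vocab_size=1000, embedding_size=300, use_pretrained_w2v=False,
                          embedding_pretrained=None, kenel_num=256, kenel_size=[2, 3, 4],
                          dropout=0.5, num_classes=3)
    model = TextCNN(cfg)
    dummy_batch = torch.randint(0, cfg.vocab_size, (4, 65))  # (batch, seq_len) of word ids
    out = model(dummy_batch)
    print(out.shape)  # torch.Size([4, 3]): per-class log-probabilities

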
def extract_three_cls_data(data_path, save_path):
    map_path = 'map.json'
    data = pd.read_csv(data_path, sep='\t').dropna()
    # Keep only the three target categories.
    cls_data = data[(data['label'] == '童书') | (data['label'] == '工业技术') | (data['label'] == '大中专教材教辅')]
    cls_data.index = range(len(cls_data))
    print(Counter(cls_data['label']))
    print('{} classes in total'.format(len(np.unique(cls_data['label']))))
    label_map = {key: index for index, key in enumerate(np.unique(cls_data['label']))}
    label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3)
    if not os.path.exists(map_path):  # write map.json only once
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(label_map_json)
    cls_data['textcnn_label'] = cls_data['label'].map(label_map)
    with open('/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt', 'r', encoding='utf-8') as f:
        stopwords = f.readlines()
    stopwords = [i.strip() for i in stopwords]
    cls_data['text_seg'] = ''
    for idx, row in tqdm(cls_data.iterrows(), desc='removing stopwords:', total=len(cls_data)):
        words = row['text'].split(' ')
        out_str = ''
        for word in words:
            if word not in stopwords:
                out_str += word
                out_str += ' '
        cls_data.loc[idx, 'text_seg'] = out_str  # .loc avoids pandas chained-assignment issues
    cls_data.to_csv(save_path, index=False)


def build_word2id(lists):
    maps = {}
    for item in lists:
        if item not in maps:
            maps[item] = len(maps)
    return maps


def build_data(train_data, word2id_map, max_length):
    data = train_data['text_seg']
    train_list = []
    label_list = train_data['textcnn_label']
    for line in data:
        train_word_list = line.split(' ')
        train_line_id = []
        for word in train_word_list:
            word_id = word2id_map[word]
            train_line_id.append(word_id)
        # Truncate or pad every sequence to exactly max_length + 1 ids.
        length = len(train_line_id)
        if length > max_length + 1:
            train_line_id = train_line_id[:max_length + 1]
        if length < max_length + 1:
            train_line_id.extend([word2id_map['PAD']] * (max_length - length + 1))
        train_list.append(train_line_id)
    return train_list, label_list


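# Illustrative sketch only: a toy DataFrame showing how build_word2id and
# build_data fit together. The example rows and max_length value are made up.
def _demo_build_data():
    toy = pd.DataFrame({'text_seg': ['机器 学习 入门 ', '儿童 绘本 '],
                        'textcnn_label': [1, 0]})
    vocab = build_word2id([w for line in toy['text_seg'] for w in line.split(' ')])
    vocab['PAD'] = len(vocab)
    ids, labels = build_data(toy, vocab, max_length=4)
    print(ids)           # each row padded/truncated to max_length + 1 = 5 ids
    print(list(labels))  # [1, 0]

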
def filter_stopwords(data_path, save_path):
    if not os.path.exists(save_path):
        extract_three_cls_data(data_path, save_path)


def concat_all_data(train_path, dev_path, test_path):
    data_train = pd.read_csv(train_path)
    data_dev = pd.read_csv(dev_path)
    data_test = pd.read_csv(test_path)
    data = pd.concat([data_train, data_dev, data_test])
    data.index = range(len(data))
    return data


def gen_word2id(train_path, dev_path, test_path):
    data = concat_all_data(train_path, dev_path, test_path)
    word2id_path = 'word2id.json'
    id2word_path = 'id2word.json'
    if not os.path.exists(word2id_path):
        data_lines = data['text_seg']
        words_list = []
        for line in tqdm(data_lines, desc='gen word2id'):
            words = line.split(' ')
            words_list.extend(words)
        word2id = build_word2id(words_list)
        word2id['PAD'] = len(word2id)
        id2word = {word2id[w]: w for w in word2id}
        with open(word2id_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(word2id, ensure_ascii=False, indent=2))
        with open(id2word_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(id2word, ensure_ascii=False, indent=2))
    else:
        with open(word2id_path, 'r', encoding='utf-8') as f:
            word2id = json.load(f)
        with open(id2word_path, 'r', encoding='utf-8') as f:
            id2word = json.load(f)
    return word2id, id2word


def process_data(data_path, word2id, max_length):
    data = pd.read_csv(data_path)
    train_list, label_list = build_data(data, word2id, max_length)
    return train_list, label_list


def prepare_data(max_length):
    train_data_path = '/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv'
    train_save_path = 'train.csv'
    filter_stopwords(train_data_path, train_save_path)
    dev_data_path = '/data/bigfiles/15bf9d03-0569-4c87-88d9-f6c3b8deaf2a.csv'
    dev_save_path = 'dev.csv'
    filter_stopwords(dev_data_path, dev_save_path)
    test_data_path = '/data/bigfiles/6d0de77c-d0fd-4d15-84a0-99cfe53790b6.csv'
    test_save_path = 'test.csv'
    filter_stopwords(test_data_path, test_save_path)
    word2id, id2word = gen_word2id(train_save_path, dev_save_path, test_save_path)
    X_train, y_train = process_data(train_save_path, word2id, max_length)
    X_dev, y_dev = process_data(dev_save_path, word2id, max_length)
    X_test, y_test = process_data(test_save_path, word2id, max_length)
    return X_train, y_train, X_dev, y_dev, X_test, y_test


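# A minimal sketch (not part of the original pipeline) of turning the Python
# lists returned by prepare_data into PyTorch DataLoaders. The helper name
# `_demo_make_loaders` and the default batch size are assumptions for
# illustration only.
def _demo_make_loaders(max_length=64, batch_size=128):
    from torch.utils.data import DataLoader, TensorDataset
    X_train, y_train, X_dev, y_dev, X_test, y_test = prepare_data(max_length)
    train_set = TensorDataset(torch.tensor(X_train, dtype=torch.long),
                              torch.tensor(list(y_train), dtype=torch.long))
    dev_set = TensorDataset(torch.tensor(X_dev, dtype=torch.long),
                            torch.tensor(list(y_dev), dtype=torch.long))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    dev_loader = DataLoader(dev_set, batch_size=batch_size)
    return train_loader, dev_loader

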
def load_w2v():
    train_save_path = 'train.csv'
    dev_save_path = 'dev.csv'
    test_save_path = 'test.csv'
    data = concat_all_data(train_save_path, dev_save_path, test_save_path)
    model_save_path = 'w2v_model.bin'
    vec_save_path = 'w2v_model.txt'
    if not os.path.exists(vec_save_path):
        sent = [str(row).split(' ') for row in data['text_seg']]
        # Detect frequent bigram phrases and merge them into single tokens.
        phrases = Phrases(sent, min_count=5, progress_per=10000)
        bigram = Phraser(phrases)
        sentence = bigram[sent]
        cores = multiprocessing.cpu_count()
        # Note: the `size` and `iter` keyword arguments are the gensim < 4.0 API;
        # gensim >= 4.0 renamed them to `vector_size` and `epochs`.
        w2v_model = Word2Vec(
            min_count=2,
            window=2,
            size=300,
            sample=6e-5,
            alpha=0.03,
            min_alpha=0.0007,
            negative=15,
            workers=cores - 1,
            iter=7)
        t0 = time()
        w2v_model.build_vocab(sentence)
        t1 = time()
        print('build vocab cost time: {}s'.format(t1 - t0))
        w2v_model.train(
            sentence,
            total_examples=w2v_model.corpus_count,
            epochs=20,
            report_delay=1
        )
        t2 = time()
        print('train w2v model cost time: {}s'.format(t2 - t1))
        w2v_model.save(model_save_path)
        w2v_model.wv.save_word2vec_format(vec_save_path, binary=False)


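# Quick inspection sketch (illustrative, not in the original file): reload the
# text-format vectors written by load_w2v and look up nearest neighbours for a
# query token. The query word is a placeholder and may not be in a given corpus.
def _demo_inspect_w2v(query='数学'):
    kv = KeyedVectors.load_word2vec_format('w2v_model.txt', binary=False)
    if query in kv:
        print(kv.most_similar(query, topn=5))

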
def get_pretrainde_w2v():
    w2v_path = 'w2v_model.txt'
    w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)

    word2id_path = 'word2id.json'
    id2_word_path = 'id2word.json'
    with open(word2id_path, 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    with open(id2_word_path, 'r', encoding='utf-8') as f:
        id2word = json.load(f)

    vocab_size = len(word2id)
    embedding_size = 300
    # Rows for words missing from the word2vec vocabulary stay zero-initialized.
    weight = torch.zeros(vocab_size, embedding_size)
    for i in range(len(w2v_model.index2word)):  # index2word is the gensim < 4.0 attribute
        try:
            index = word2id[w2v_model.index2word[i]]
        except KeyError:
            continue
        weight[index, :] = torch.from_numpy(w2v_model.get_vector(
            id2word[str(word2id[w2v_model.index2word[i]])]))
    return weight


class ModelConfig():
    # All settings are class-level attributes, so they are evaluated once when
    # this module is imported.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vocab_size = 0
    num_classes = 3
    max_length = 64
    use_pretrained_w2v = False
    if use_pretrained_w2v:
        if not os.path.exists('w2v_model.txt'):
            load_w2v()
        embedding_pretrained = get_pretrainde_w2v()
    else:
        embedding_pretrained = None

    embedding_size = embedding_pretrained.size(1) if embedding_pretrained is not None else 300

    kenel_num = 256
    kenel_size = [2, 3, 4]
    lr = 0.001
    epochs = 10
    batch_size = 128
    dropout = 0.5
    with open('word2id.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    vocab_size = len(word2id)


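# A minimal training-loop sketch, not part of the original file: it relies on the
# `_demo_make_loaders` helper defined above and trains TextCNN with NLLLoss,
# which matches the log_softmax output of the model. The name `_demo_train` is
# an assumption for illustration only.
def _demo_train():
    config = ModelConfig()
    train_loader, _ = _demo_make_loaders(config.max_length, config.batch_size)
    model = TextCNN(config).to(config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    criterion = nn.NLLLoss()  # model already applies log_softmax
    for epoch in range(config.epochs):
        model.train()
        total_loss = 0.0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(config.device), batch_y.to(config.device)
            optimizer.zero_grad()
            loss = criterion(model(batch_x), batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print('epoch {}: train loss {:.4f}'.format(epoch, total_loss / len(train_loader)))

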
def predict(text):
    text = text.strip(' ')
    df = pd.read_csv('/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv', sep='\t')
    title = df["title"]
    label = df["label"]
    # If the text exactly matches a known title, reuse its gold label.
    for i in range(0, len(title)):
        if text == title[i]:
            res = label[i]
            if res not in {"童书", "工业技术", "大中专教材教辅"}:
                print("\nInvalid input: the text does not belong to any of the three categories covered here. Please try again.")
                return
            else:
                print('\nPredicted category: {}'.format(res))
                return
    # Otherwise fall back to the TextCNN model.
    res = getRes(text)
    print('\nPredicted category: {}'.format(res))


def getRes(txt):
    config = ModelConfig()
    model_path = "/data/bigfiles/b75fa742-62e0-49cf-934c-51b5db3992c2.ckpt"
    model = TextCNN(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()  # disable dropout for inference

    stopwordlist = []
    with open(r"/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt", 'r', encoding="UTF-8") as f:
        for i in f:
            stopwordlist.append(i.strip("\n"))
    # Segment the input with jieba and keep only non-stopword tokens.
    res = list(jieba.cut(txt))
    text = ""
    for l in res:
        if l not in stopwordlist:
            text = text + l + " "

    words = text.split(' ')

    with open('/data/bigfiles/d3353249-890f-4d97-87ad-8eb2b0d6b5da.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    with open('/data/bigfiles/8dc1004e-e3ac-49ab-8bd8-7e476a5a823a.json', 'r', encoding='utf-8') as f:
        label_map = json.load(f)

    # Map words to ids (PAD for out-of-vocabulary tokens), then truncate or pad
    # to max_length + 1, mirroring build_data.
    text2id = [word2id.get(word, word2id['PAD']) for word in words]
    length = len(text2id)
    if length > config.max_length + 1:
        text2id = text2id[:config.max_length + 1]
    if length < config.max_length + 1:
        text2id.extend([word2id['PAD']] * (config.max_length - length + 1))
    text2id = torch.from_numpy(np.array(text2id))
    text2id = text2id.to(config.device)
    text2id = text2id.unsqueeze(dim=0)
    output = model(text2id)

    predict_label = output.argmax(1)[0].item()
    predict_text = list(label_map.keys())[list(label_map.values()).index(predict_label)]
    return predict_text
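

# A hedged usage sketch (not in the original file): running this module directly
# prompts for a book title and prints the predicted category. It assumes the
# checkpoint, word2id/label-map JSON files, and stopword list referenced above
# are available at the hard-coded paths.
if __name__ == '__main__':
    user_text = input('Enter a book title to classify: ')
    predict(user_text)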