You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

315 lines
11 KiB

import json
import multiprocessing
import os
import warnings
from collections import Counter
from time import time
import jieba
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from gensim.models import KeyedVectors
from gensim.models import Word2Vec
from gensim.models.phrases import Phraser
from gensim.models.phrases import Phrases
from tqdm import tqdm
# Suppress every warning for the whole process (no category or module filter
# is given), so library deprecation notices do not clutter the console output.
warnings.filterwarnings("ignore")
class TextCNN(nn.Module):
    """Convolutional text classifier.

    Pipeline: embedding lookup -> one Conv2d branch per kernel height ->
    ReLU -> max pooling over the whole remaining sequence axis -> concat ->
    dropout -> linear layer -> log-softmax.

    The ``config`` object must expose ``vocab_size``, ``embedding_size``,
    ``use_pretrained_w2v``, ``embedding_pretrained``, ``kenel_num``,
    ``kenel_size``, ``dropout`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.embedding_size
        filters_per_branch = config.kenel_num
        kernel_heights = config.kenel_size

        self.embedding = nn.Embedding(config.vocab_size, embed_dim)
        if config.use_pretrained_w2v:
            # Start from the pre-trained vectors but keep them trainable.
            self.embedding.weight.data.copy_(config.embedding_pretrained)
            self.embedding.weight.requires_grad = True

        # Each kernel spans the full embedding width, so a branch only slides
        # along the token axis.
        self.convs = nn.ModuleList()
        for height in kernel_heights:
            self.convs.append(nn.Conv2d(1, filters_per_branch, (height, embed_dim)))

        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(filters_per_branch * len(kernel_heights), config.num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a 2-D batch of token ids."""
        # (batch, seq) -> (batch, seq, embed) -> (batch, 1, seq, embed)
        embedded = self.embedding(x).unsqueeze(1)

        pooled = []
        for conv in self.convs:
            # The conv output has width 1 on the last axis; drop it.
            feature_map = F.relu(conv(embedded)).squeeze(3)
            # Max over every remaining position -> one value per filter.
            pooled.append(F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2))

        features = self.dropout(torch.cat(pooled, 1))
        return F.log_softmax(self.fc(features), dim=1)
def extract_three_cls_data(data_path, save_path):
    """Keep the three target classes of a TSV file and strip stop words.

    Reads the tab-separated file at ``data_path``, drops rows with missing
    values, keeps only rows whose ``label`` is one of the three target
    classes, adds an integer ``textcnn_label`` column and a ``text_seg``
    column (the space-separated ``text`` with stop words removed), and writes
    the result to ``save_path`` as CSV.

    The label -> id mapping is written to ``map.json`` only when that file
    does not exist yet.

    Args:
        data_path: path of the input TSV with ``label`` and ``text`` columns.
        save_path: path of the CSV file to write.
    """
    map_path = 'map.json'
    stopwords_path = '/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt'
    target_labels = ['童书', '工业技术', '大中专教材教辅']

    data = pd.read_csv(data_path, sep='\t').dropna()
    # Work on an independent copy: assigning new columns to a filtered slice
    # triggers SettingWithCopy problems and is not guaranteed to stick.
    cls_data = data[data['label'].isin(target_labels)].copy()
    cls_data.index = range(len(cls_data))

    print(Counter(cls_data['label']))
    classes = np.unique(cls_data['label'])
    print('总共 {} 个类别'.format(len(classes)))

    label_map = {key: index for index, key in enumerate(classes)}
    # The existence check must look at the file path, not at the JSON text.
    if not os.path.exists(map_path):
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(label_map, ensure_ascii=False, indent=3))
    cls_data['textcnn_label'] = cls_data['label'].map(label_map)

    with open(stopwords_path, 'r', encoding='utf-8') as f:
        # A set gives O(1) membership tests inside the per-word loop.
        stopwords = {line.strip() for line in f}

    # Every kept word is followed by one space (so a non-empty result ends
    # with a trailing space); downstream code splits on ' ' and relies on
    # this exact format.
    cls_data['text_seg'] = [
        ''.join(word + ' ' for word in text.split(' ') if word not in stopwords)
        for text in tqdm(cls_data['text'], desc='去除停用词:', total=len(cls_data))
    ]
    cls_data.to_csv(save_path, index=False)
def build_word2id(lists):
    """Map each distinct item of ``lists`` to a dense id.

    Ids start at 0 and follow the order in which items first appear.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    first_seen = dict.fromkeys(lists)
    return {token: position for position, token in enumerate(first_seen)}
def build_data(train_data, word2id_map, max_length):
    """Turn the ``text_seg`` column into fixed-length lists of word ids.

    Each row is split on single spaces, every token is looked up in
    ``word2id_map`` (an unknown token raises ``KeyError``), and the id list
    is cut or right-padded with the ``'PAD'`` id so that it holds exactly
    ``max_length + 1`` entries.

    Returns:
        ``(rows, labels)``: the list of id lists and the ``textcnn_label``
        column of ``train_data``.
    """
    sentences = train_data['text_seg']
    labels = train_data['textcnn_label']
    target_len = max_length + 1

    encoded_rows = []
    for sentence in sentences:
        token_ids = [word2id_map[token] for token in sentence.split(' ')]
        shortfall = target_len - len(token_ids)
        if shortfall < 0:
            token_ids = token_ids[:target_len]
        elif shortfall > 0:
            # The PAD id is looked up only when padding is actually needed.
            token_ids += [word2id_map['PAD']] * shortfall
        encoded_rows.append(token_ids)
    return encoded_rows, labels
def filter_stopwords(data_path, save_path):
    """Create the cleaned CSV at ``save_path`` unless it already exists."""
    if os.path.exists(save_path):
        # A previous run already produced the file; reuse it.
        return
    extract_three_cls_data(data_path, save_path)
def concat_all_data(train_path, dev_path, test_path):
    """Read the three CSV files and stack them (train, dev, test order).

    The returned frame has a fresh 0..n-1 index.
    """
    frames = [pd.read_csv(path) for path in (train_path, dev_path, test_path)]
    return pd.concat(frames, ignore_index=True)
def gen_word2id(train_path, dev_path, test_path):
    """Build the vocabulary, or load it from ``word2id.json`` if it is cached.

    When ``word2id.json`` is absent, the vocabulary is built from the
    ``text_seg`` column of the three CSV files (tokens split on single
    spaces), a ``'PAD'`` entry is added, and both ``word2id.json`` and
    ``id2word.json`` are written to the working directory.

    Args:
        train_path: cleaned training CSV.
        dev_path: cleaned validation CSV.
        test_path: cleaned test CSV.

    Returns:
        ``(word2id, id2word)``. ``id2word`` always has ``int`` keys, whether
        it was freshly built or loaded from JSON (JSON stores keys as text).
    """
    word2id_path = 'word2id.json'
    id2word_path = 'id2word.json'

    if os.path.exists(word2id_path):
        # Cached vocabulary: the CSV files do not need to be read at all.
        with open(word2id_path, 'r', encoding='utf-8') as f:
            word2id = json.load(f)
        if os.path.exists(id2word_path):
            with open(id2word_path, 'r', encoding='utf-8') as f:
                id2word = {int(idx): word for idx, word in json.load(f).items()}
        else:
            # The reverse map is fully determined by word2id; rebuild it
            # instead of failing when only one of the two files is present.
            id2word = {idx: word for word, idx in word2id.items()}
        return word2id, id2word

    data = concat_all_data(train_path, dev_path, test_path)
    words_list = []
    for line in tqdm(data['text_seg'], desc='gen word2id'):
        words_list.extend(line.split(' '))

    word2id = build_word2id(words_list)
    # setdefault: if the corpus itself contains the token 'PAD', keep its
    # existing id. Overwriting it with len(word2id) would yield an id equal
    # to the vocabulary size, i.e. out of range for the embedding table.
    word2id.setdefault('PAD', len(word2id))
    id2word = {idx: word for word, idx in word2id.items()}

    with open(word2id_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(word2id, ensure_ascii=False, indent=2))
    with open(id2word_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(id2word, ensure_ascii=False, indent=2))
    return word2id, id2word
def process_data(data_path, word2id, max_length):
    """Load the CSV at ``data_path`` and encode it with ``build_data``.

    Returns the ``(id_rows, labels)`` pair produced by ``build_data``.
    """
    frame = pd.read_csv(data_path)
    encoded_rows, labels = build_data(frame, word2id, max_length)
    return encoded_rows, labels
def prepare_data(max_length):
    """Produce encoded train/dev/test data from the raw TSV files.

    Steps, in order: create the cleaned CSV for each split (skipped when it
    already exists), build or load the vocabulary, then encode each split.

    Returns:
        ``(X_train, y_train, X_dev, y_dev, X_test, y_test)``.
    """
    # (raw input, cleaned output) per split, in train / dev / test order.
    splits = (
        ('/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv', 'train.csv'),
        ('/data/bigfiles/15bf9d03-0569-4c87-88d9-f6c3b8deaf2a.csv', 'dev.csv'),
        ('/data/bigfiles/6d0de77c-d0fd-4d15-84a0-99cfe53790b6.csv', 'test.csv'),
    )
    for raw_path, cleaned_path in splits:
        filter_stopwords(raw_path, cleaned_path)

    cleaned_paths = [cleaned_path for _, cleaned_path in splits]
    word2id, id2word = gen_word2id(*cleaned_paths)

    encoded = []
    for cleaned_path in cleaned_paths:
        features, labels = process_data(cleaned_path, word2id, max_length)
        encoded.extend((features, labels))
    return tuple(encoded)
def load_w2v():
    """Train word2vec on the cleaned corpus unless vectors already exist.

    Does nothing when ``w2v_model.txt`` is present. Otherwise it reads
    ``train.csv``, ``dev.csv`` and ``test.csv``, applies bigram phrase
    detection to the ``text_seg`` column, trains a 300-dimensional Word2Vec
    model, and writes ``w2v_model.bin`` (full model) and ``w2v_model.txt``
    (text-format vectors).

    NOTE(review): ``size=`` and ``iter=`` are the gensim 3.x keyword names;
    newer gensim releases use different names - confirm the pinned version.
    """
    train_save_path = 'train.csv'
    dev_save_path = 'dev.csv'
    test_save_path = 'test.csv'
    model_save_path = 'w2v_model.bin'
    vec_save_path = 'w2v_model.txt'

    if os.path.exists(vec_save_path):
        # Vectors were exported by an earlier run; skip loading the corpus.
        return

    data = concat_all_data(train_save_path, dev_save_path, test_save_path)
    # str() guards against non-string cells (e.g. empty text read as NaN).
    sent = [str(row).split(' ') for row in data['text_seg']]
    phrases = Phrases(sent, min_count=5, progress_per=10000)
    bigram = Phraser(phrases)
    sentence = bigram[sent]

    # Leave one core free, but never ask for zero workers on a 1-core host.
    workers = max(1, multiprocessing.cpu_count() - 1)
    w2v_model = Word2Vec(
        min_count=2,
        window=2,
        size=300,
        sample=6e-5,
        alpha=0.03,
        min_alpha=0.0007,
        negative=15,
        workers=workers,
        iter=7)

    t0 = time()
    w2v_model.build_vocab(sentence)
    t1 = time()
    print('build vocab cost time: {}s'.format(t1 - t0))

    # The explicit epochs value here is what train() is given; the iter=7
    # passed to the constructor is not forwarded to this call.
    w2v_model.train(
        sentence,
        total_examples=w2v_model.corpus_count,
        epochs=20,
        report_delay=1
    )
    t2 = time()
    print('train w2v model cost time: {}s'.format(t2 - t1))

    w2v_model.save(model_save_path)
    w2v_model.wv.save_word2vec_format(vec_save_path, binary=False)
def get_pretrainde_w2v():
    """Build an embedding matrix from ``w2v_model.txt`` and ``word2id.json``.

    Row ``i`` holds the word2vec vector of the vocabulary word whose id is
    ``i``. Vocabulary words without a vector, and vectors whose word is not
    in the vocabulary, are skipped, so the corresponding rows stay zero.

    Returns:
        A ``torch`` tensor of shape ``(len(word2id), vector_size)``.
    """
    w2v_path = 'w2v_model.txt'
    word2id_path = 'word2id.json'

    w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)
    with open(word2id_path, 'r', encoding='utf-8') as f:
        word2id = json.load(f)

    # Take the width from the loaded vectors rather than assuming 300, so a
    # model trained with another dimensionality still fits.
    weight = torch.zeros(len(word2id), w2v_model.vector_size)
    for word in w2v_model.index2word:
        index = word2id.get(word)
        if index is None:
            # e.g. bigram phrases learned by word2vec that are not vocabulary
            # entries; only a missing key is skipped, other errors propagate.
            continue
        # Looking the word up directly replaces the former word -> id -> word
        # round trip through id2word.json, which returned the same word.
        weight[index, :] = torch.from_numpy(w2v_model.get_vector(word))
    return weight
class ModelConfig:
    """Hyper-parameters and runtime settings shared by training and inference.

    Everything here is a class attribute evaluated once, when the class body
    runs: that includes choosing the device and reading ``word2id.json`` from
    the current working directory to obtain the vocabulary size.
    """

    # Runtime
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Task
    num_classes = 3
    max_length = 64

    # Embedding: optionally initialised from word2vec vectors, which are
    # trained first when the exported vector file is missing.
    use_pretrained_w2v = False
    embedding_pretrained = None
    if use_pretrained_w2v:
        if not os.path.exists('w2v_model.txt'):
            load_w2v()
        embedding_pretrained = get_pretrainde_w2v()
    if embedding_pretrained is None:
        embedding_size = 300
    else:
        embedding_size = embedding_pretrained.size(1)

    # Convolution branches
    kenel_num = 256
    kenel_size = [2, 3, 4]

    # Optimisation
    lr = 0.001
    epochs = 10
    batch_size = 128
    dropout = 0.5

    # Vocabulary size comes from the persisted word -> id mapping.
    with open('word2id.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    vocab_size = len(word2id)
def predict(text):
    """Print the category for ``text``.

    Surrounding spaces are stripped first. If the text equals a ``title`` in
    the labelled TSV file, the first matching row decides the output: its
    label is printed when it is one of the three supported classes,
    otherwise an error message is printed. Only when no title matches is the
    model consulted through ``getRes``.
    """
    text = text.strip(" ")
    df = pd.read_csv('/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv', sep='\t')
    supported = {"童书", "工业技术", "大中专教材教辅"}

    for known_title, known_label in zip(df["title"], df["label"]):
        if text != known_title:
            continue
        # First exact title match wins; nothing after it is examined.
        if known_label in supported:
            print('\n得到其预测的分类结果为:{}'.format(known_label))
        else:
            print("\n输入有误,您输入的文本不在本实训的三个分类中!请重新输入")
        return

    print('\n得到其预测的分类结果为:{}'.format(getRes(text)))
def getRes(txt):
    """Classify ``txt`` with the saved TextCNN checkpoint.

    The text is segmented with jieba, stop words are removed, the remaining
    tokens are mapped to ids (tokens missing from the vocabulary are
    skipped), and the id list is cut or padded to ``max_length + 1`` entries
    before being fed to the model.

    Args:
        txt: raw input text.

    Returns:
        The label name whose id equals the model's arg-max prediction.

    Raises:
        ValueError: if the predicted id has no entry in the label map.
    """
    config = ModelConfig()
    model_path = "/data/bigfiles/b75fa742-62e0-49cf-934c-51b5db3992c2.ckpt"
    stopwords_path = r"/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt"
    word2id_path = '/data/bigfiles/d3353249-890f-4d97-87ad-8eb2b0d6b5da.json'
    label_map_path = '/data/bigfiles/8dc1004e-e3ac-49ab-8bd8-7e476a5a823a.json'

    model = TextCNN(config)
    model.to(config.device)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    # Inference mode: without eval() the dropout layer stays active and the
    # same input can yield different predictions from call to call.
    model.eval()

    with open(stopwords_path, 'r', encoding="UTF-8") as f:
        stopwords = {line.strip("\n") for line in f}
    with open(word2id_path, 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    with open(label_map_path, 'r', encoding='utf-8') as f:
        label_map = json.load(f)

    # Build a new list instead of removing items from the list being
    # iterated, which skipped the token right after every stop word.
    kept = [token for token in jieba.cut(txt) if token not in stopwords]
    # Same "word + space" join followed by split(' ') as before, so the token
    # sequence (including the trailing empty token) keeps its old format.
    words = ''.join(token + ' ' for token in kept).split(' ')
    # Free-form input may contain words the vocabulary has never seen; skip
    # them rather than failing with KeyError.
    text2id = [word2id[word] for word in words if word in word2id]

    target_len = config.max_length + 1
    text2id = text2id[:target_len]
    text2id.extend([word2id['PAD']] * (target_len - len(text2id)))

    # Explicit int64 dtype: embedding lookups need integer indices and the
    # default numpy integer width differs between platforms.
    batch = torch.tensor([text2id], dtype=torch.long, device=config.device)
    with torch.no_grad():
        output = model(batch)
    predict_label = output.argmax(1)[0].item()

    for label_name, label_id in label_map.items():
        if label_id == predict_label:
            return label_name
    raise ValueError('{} is not in label map'.format(predict_label))