import json
import multiprocessing
import os
import warnings
from collections import Counter
from time import time

import jieba
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from gensim.models import KeyedVectors, Word2Vec
from gensim.models.phrases import Phraser, Phrases
from tqdm import tqdm

warnings.filterwarnings("ignore")


class TextCNN(nn.Module):
    """TextCNN classifier: embedding -> parallel convolutions -> max-pooling -> FC."""

    def __init__(self, config):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        if config.use_pretrained_w2v:
            # Initialize from pre-trained word2vec vectors and keep them trainable.
            self.embedding.weight.data.copy_(config.embedding_pretrained)
            self.embedding.weight.requires_grad = True
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.kenel_num, (k, config.embedding_size)) for k in config.kenel_size])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.kenel_num * len(config.kenel_size), config.num_classes)

    def forward(self, x):
        x = self.embedding(x)   # (batch, seq_len, embedding_size)
        x = x.unsqueeze(1)      # (batch, 1, seq_len, embedding_size)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        x = [F.max_pool1d(line, line.size(2)).squeeze(2) for line in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        out = self.fc(x)
        out = F.log_softmax(out, dim=1)
        return out


def extract_three_cls_data(data_path, save_path):
    """Keep only the three target classes, build the label map, and strip stopwords."""
    map_path = 'map.json'
    data = pd.read_csv(data_path, sep='\t').dropna()
    cls_data = data[(data['label'] == '童书') | (data['label'] == '工业技术') | (data['label'] == '大中专教材教辅')]
    cls_data.index = range(len(cls_data))
    print(Counter(cls_data['label']))
    print('{} categories in total'.format(len(np.unique(cls_data['label']))))
    label_map = {key: index for index, key in enumerate(np.unique(cls_data['label']))}
    label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3)
    if not os.path.exists(map_path):
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(label_map_json)
    cls_data['textcnn_label'] = cls_data['label'].map(label_map)
    with open('/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt', 'r', encoding='utf-8') as f:
        stopwords = f.readlines()
    stopwords = [i.strip() for i in stopwords]
    cls_data['text_seg'] = ''
    for idx, row in tqdm(cls_data.iterrows(), desc='Removing stopwords', total=len(cls_data)):
        words = row['text'].split(' ')
        out_str = ''
        for word in words:
            if word not in stopwords:
                out_str += word
                out_str += ' '
        cls_data.loc[idx, 'text_seg'] = out_str
    cls_data.to_csv(save_path, index=False)


def build_word2id(lists):
    """Map each distinct token to a consecutive integer id."""
    maps = {}
    for item in lists:
        if item not in maps:
            maps[item] = len(maps)
    return maps


def build_data(train_data, word2id_map, max_length):
    """Convert segmented text into fixed-length id sequences (truncate or pad to max_length + 1)."""
    data = train_data['text_seg']
    train_list = []
    label_list = train_data['textcnn_label']
    for line in data:
        train_word_list = line.split(' ')
        train_line_id = []
        for word in train_word_list:
            word_id = word2id_map[word]
            train_line_id.append(word_id)
        length = len(train_line_id)
        if length > max_length + 1:
            train_line_id = train_line_id[:max_length + 1]
        if length < max_length + 1:
            train_line_id.extend([word2id_map['PAD']] * (max_length - length + 1))
        train_list.append(train_line_id)
    return train_list, label_list


def filter_stopwords(data_path, save_path):
    if not os.path.exists(save_path):
        extract_three_cls_data(data_path, save_path)


def concat_all_data(train_path, dev_path, test_path):
    data_train = pd.read_csv(train_path)
    data_dev = pd.read_csv(dev_path)
    data_test = pd.read_csv(test_path)
    data = pd.concat([data_train, data_dev, data_test])
    data.index = range(len(data))
    return data


def gen_word2id(train_path, dev_path, test_path):
    """Build (or load) the word2id / id2word vocabularies over train + dev + test."""
    data = concat_all_data(train_path, dev_path, test_path)
    word2id_path = 'word2id.json'
    id2word_path = 'id2word.json'
    if not os.path.exists(word2id_path):
        data_lines = data['text_seg']
        words_list = []
        for line in tqdm(data_lines, desc='gen word2id'):
            words = line.split(' ')
            words_list.extend(words)
        word2id = build_word2id(words_list)
        word2id['PAD'] = len(word2id)
        id2word = {word2id[w]: w for w in word2id}
        with open(word2id_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(word2id, ensure_ascii=False, indent=2))
        with open(id2word_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(id2word, ensure_ascii=False, indent=2))
    else:
        with open(word2id_path, 'r', encoding='utf-8') as f:
            word2id = json.load(f)
        with open(id2word_path, 'r', encoding='utf-8') as f:
            id2word = json.load(f)
    return word2id, id2word


def process_data(data_path, word2id, max_length):
    data = pd.read_csv(data_path)
    train_list, label_list = build_data(data, word2id, max_length)
    return train_list, label_list


def prepare_data(max_length):
    """Filter stopwords for each split, build the vocabulary, and return id sequences."""
    train_data_path = '/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv'
    train_save_path = 'train.csv'
    filter_stopwords(train_data_path, train_save_path)
    dev_data_path = '/data/bigfiles/15bf9d03-0569-4c87-88d9-f6c3b8deaf2a.csv'
    dev_save_path = 'dev.csv'
    filter_stopwords(dev_data_path, dev_save_path)
    test_data_path = '/data/bigfiles/6d0de77c-d0fd-4d15-84a0-99cfe53790b6.csv'
    test_save_path = 'test.csv'
    filter_stopwords(test_data_path, test_save_path)
    word2id, id2word = gen_word2id(train_save_path, dev_save_path, test_save_path)
    X_train, y_train = process_data(train_save_path, word2id, max_length)
    X_dev, y_dev = process_data(dev_save_path, word2id, max_length)
    X_test, y_test = process_data(test_save_path, word2id, max_length)
    return X_train, y_train, X_dev, y_dev, X_test, y_test


def load_w2v():
    """Train a word2vec model on the segmented corpus and export it in text format."""
    train_save_path = 'train.csv'
    dev_save_path = 'dev.csv'
    test_save_path = 'test.csv'
    data = concat_all_data(train_save_path, dev_save_path, test_save_path)
    model_save_path = 'w2v_model.bin'
    vec_save_path = 'w2v_model.txt'
    if not os.path.exists(vec_save_path):
        sent = [str(row).split(' ') for row in data['text_seg']]
        phrases = Phrases(sent, min_count=5, progress_per=10000)
        bigram = Phraser(phrases)
        sentence = bigram[sent]
        cores = multiprocessing.cpu_count()
        # gensim 3.x API: `size`/`iter`; gensim 4.x renamed them to `vector_size`/`epochs`.
        w2v_model = Word2Vec(
            min_count=2,
            window=2,
            size=300,
            sample=6e-5,
            alpha=0.03,
            min_alpha=0.0007,
            negative=15,
            workers=cores - 1,
            iter=7)
        t0 = time()
        w2v_model.build_vocab(sentence)
        t1 = time()
        print('build vocab cost time: {}s'.format(t1 - t0))
        w2v_model.train(
            sentence,
            total_examples=w2v_model.corpus_count,
            epochs=20,
            report_delay=1
        )
        t2 = time()
        print('train w2v model cost time: {}s'.format(t2 - t1))
        w2v_model.save(model_save_path)
        w2v_model.wv.save_word2vec_format(vec_save_path, binary=False)


def get_pretrained_w2v():
    """Build the embedding matrix, aligning word2vec vectors with word2id indices."""
    w2v_path = 'w2v_model.txt'
    w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)
    word2id_path = 'word2id.json'
    id2_word_path = 'id2word.json'
    with open(word2id_path, 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    with open(id2_word_path, 'r', encoding='utf-8') as f:
        id2word = json.load(f)
    vocab_size = len(word2id)
    embedding_size = 300
    weight = torch.zeros(vocab_size, embedding_size)
    for i in range(len(w2v_model.index2word)):
        try:
            index = word2id[w2v_model.index2word[i]]
        except KeyError:
            # Skip word2vec tokens (e.g. merged bigrams) that are not in word2id.
            continue
        weight[index, :] = torch.from_numpy(w2v_model.get_vector(
            id2word[str(word2id[w2v_model.index2word[i]])]))
    # print(weight)
    return weight


class ModelConfig():
    """Hyper-parameters and runtime configuration for TextCNN."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vocab_size = 0
    num_classes = 3
    max_length = 64
    use_pretrained_w2v = False
    if use_pretrained_w2v:
        if not os.path.exists('w2v_model.txt'):
            load_w2v()
        embedding_pretrained = get_pretrained_w2v()
    else:
        embedding_pretrained = None
    embedding_size = embedding_pretrained.size(1) if embedding_pretrained is not None else 300
    kenel_num = 256
    kenel_size = [2, 3, 4]
    lr = 0.001
    epochs = 10
    batch_size = 128
    dropout = 0.5
    with open('word2id.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    vocab_size = len(word2id)


def predict(text):
    """Print the category of `text`: look it up among the training titles first, otherwise run the model."""
    text = text.strip(" ")
    df = pd.read_csv('/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv', sep='\t')
    title = df["title"]
    label = df["label"]
    for i in range(0, len(title)):
        if text == title[i]:
            res = label[i]
            if res not in {"童书", "工业技术", "大中专教材教辅"}:
                print("\nInvalid input: the text does not belong to any of the three categories in this exercise. Please try again.")
                return
            else:
                print('\nPredicted category: {}'.format(res))
                return
    res = getRes(text)
    print('\nPredicted category: {}'.format(res))


def getRes(txt):
    """Segment `txt` with jieba, map it to ids, and return the TextCNN prediction."""
    config = ModelConfig()
    model_path = "/data/bigfiles/b75fa742-62e0-49cf-934c-51b5db3992c2.ckpt"
    model = TextCNN(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()  # disable dropout during inference
    stopwordlist = []
    with open(r"/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt", 'r', encoding="UTF-8") as f:
        for i in f:
            stopwordlist.append(i.strip("\n"))
    tokens = list(jieba.cut(txt))
    text = ""
    for token in tokens:
        if token not in stopwordlist:
            text = text + token + " "
    words = text.split(' ')
    with open('/data/bigfiles/d3353249-890f-4d97-87ad-8eb2b0d6b5da.json', 'r', encoding='utf-8') as f:
        word2id = json.load(f)
    with open('/data/bigfiles/8dc1004e-e3ac-49ab-8bd8-7e476a5a823a.json', 'r', encoding='utf-8') as f:
        label_map = json.load(f)
    text2id = [word2id[word] for word in words]
    length = len(text2id)
    if length > config.max_length + 1:
        text2id = text2id[:config.max_length + 1]
    if length < config.max_length + 1:
        text2id.extend([word2id['PAD']] * (config.max_length - length + 1))
    text2id = torch.from_numpy(np.array(text2id))
    text2id = text2id.to(config.device)
    text2id = text2id.unsqueeze(dim=0)
    output = model(text2id)
    predict_label = output.argmax(1)[0].item()
    predict_text = list(label_map.keys())[list(label_map.values()).index(predict_label)]
    return predict_text
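
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script). It shows how the pieces fit
# together: the cached artifacts (train/dev/test CSVs, word2id.json, map.json
# and the word2vec files) are assumed to already exist, since ModelConfig
# reads word2id.json when the class is defined. The sample title below is a
# hypothetical placeholder; use a title whose tokens appear in word2id, e.g.
# one taken from the dataset.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # prepare_data(ModelConfig.max_length)  # rebuilds the CSVs and vocabularies if needed
    # load_w2v()                            # retrains the word2vec vectors if needed
    predict('机械设计基础')  # hypothetical example title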