You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
4.8 KiB

3 years ago
import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter
import warnings
# Suppress every warning, process-wide, for the rest of the run.
# NOTE(review): presumably added to hide pandas SettingWithCopyWarning noise from
# assigning into filtered DataFrames; a blanket "ignore" also hides unrelated
# (e.g. deprecation) warnings — consider narrowing it to a specific category.
warnings.filterwarnings("ignore")
def extract_three_cls_data(
        data_path,
        save_path,
        stopwords_path='/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt',
        map_path='map.json',
        labels=('童书', '工业技术', '大中专教材教辅'),
):
    """Keep the target categories, attach integer labels and strip stopwords.

    Reads a tab-separated file (rows containing any missing value are
    dropped), keeps only rows whose ``label`` is in ``labels``, adds an
    integer ``textcnn_label`` column and a ``text_seg`` column holding
    ``text`` with stopwords removed, and writes the result as CSV (without
    the index) to ``save_path``.

    Args:
        data_path: Path of the tab-separated source file; must contain
            ``label`` and ``text`` columns.
        save_path: Destination CSV path.
        stopwords_path: UTF-8 file with one stopword per line.
        map_path: Where the label -> id mapping is written as JSON. An
            existing file is left untouched.
        labels: Categories to keep; ids are assigned in sorted order.

    Note:
        ``text`` is split on single spaces (presumably it is pre-tokenized
        text — confirm against the data source). In ``text_seg`` every kept
        token is followed by one space, including the last one. Downstream
        splitting therefore yields a trailing empty token, and this format is
        kept on purpose because the vocabulary is built from it.
    """
    data = pd.read_csv(data_path, sep='\t').dropna()
    # .copy() gives an independent frame, so the column assignments below do
    # not act on a view of ``data``. Chained assignment into a filtered slice
    # is silently lost under pandas Copy-on-Write.
    cls_data = data[data['label'].isin(labels)].copy()
    cls_data.index = range(len(cls_data))
    print(Counter(cls_data['label']))
    print('总共 {} 个类别'.format(len(np.unique(cls_data['label']))))
    # Ids come from the fixed label set, not from the labels that happen to be
    # present in this split. This keeps train/dev/test consistent even if a
    # split lacks a class. Sorted order reproduces the previous np.unique-based
    # ids whenever all classes are present.
    label_map = {key: index for index, key in enumerate(sorted(labels))}
    label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3)
    if not os.path.exists(map_path):
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(label_map_json)
    cls_data['textcnn_label'] = cls_data['label'].map(label_map)
    with open(stopwords_path, 'r', encoding='utf-8') as f:
        # A set gives O(1) membership tests in the per-word filter below.
        stopwords = {line.strip() for line in f}
    # Build the whole column at once instead of writing cell by cell.
    cls_data['text_seg'] = [
        ''.join(word + ' ' for word in text.split(' ') if word not in stopwords)
        for text in tqdm(cls_data['text'], desc='去除停用词:', total=len(cls_data))
    ]
    cls_data.to_csv(save_path, index=False)
def build_word2id(lists):
    """Map each distinct token to its 0-based rank of first appearance.

    Args:
        lists: Iterable of hashable tokens; duplicates are allowed.

    Returns:
        dict: token -> id, with ids assigned in the order the tokens were
        first seen (the dict's iteration order follows the same order).
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    first_seen = dict.fromkeys(lists)
    return {token: rank for rank, token in enumerate(first_seen)}
def build_data(train_data, word2id_map, max_length):
data = train_data['text_seg']
train_list = []
label_list = train_data['textcnn_label']
for line in data:
train_word_list = line.split(' ')
train_line_id = []
for word in train_word_list:
id = word2id_map[word]
train_line_id.append(id)
length = len(train_line_id)
if length > max_length + 1:
train_line_id = train_line_id[:max_length + 1]
if length < max_length + 1:
train_line_id.extend([word2id_map['PAD']] * (max_length - length + 1))
train_list.append(train_line_id)
return train_list, label_list
def filter_stopwords(data_path, save_path):
    """Produce the stopword-filtered CSV at ``save_path`` unless it already exists.

    Args:
        data_path: Raw tab-separated input, forwarded to
            ``extract_three_cls_data``.
        save_path: Output CSV path. If it already exists, nothing is done.
    """
    if os.path.exists(save_path):
        # Output from an earlier run is reused as-is.
        return
    extract_three_cls_data(data_path, save_path)
def concat_all_data(train_path, dev_path, test_path):
    """Read the three split CSVs and stack them into a single DataFrame.

    Rows keep the order train, dev, test, and the result gets a fresh
    0..n-1 integer index.
    """
    frames = [pd.read_csv(path) for path in (train_path, dev_path, test_path)]
    return pd.concat(frames, ignore_index=True)
def gen_word2id(train_path, dev_path, test_path,
                word2id_path='word2id.json', id2word_path='id2word.json'):
    """Build, or load from the JSON cache, the word <-> id vocabularies.

    The vocabulary covers the ``text_seg`` column of all three splits.

    Args:
        train_path, dev_path, test_path: Stopword-filtered split CSVs.
        word2id_path, id2word_path: Cache files. When both exist they are
            loaded and the CSVs are not read. Otherwise the vocabulary is
            rebuilt and both files are (re)written.

    Returns:
        (word2id, id2word): ``word2id`` maps token -> int id and ``id2word``
        maps int id -> token. Ids follow first appearance across train, dev,
        then test, and ``'PAD'`` receives ``len(word2id)`` after that.
    """
    if not (os.path.exists(word2id_path) and os.path.exists(id2word_path)):
        data = concat_all_data(train_path, dev_path, test_path)
        # Empty ``text_seg`` cells are read back by read_csv as NaN.
        data_lines = data['text_seg'].fillna('')
        words_list = []
        for line in tqdm(data_lines, desc='gen word2id'):
            words_list.extend(line.split(' '))
        word2id = build_word2id(words_list)
        word2id['PAD'] = len(word2id)
        id2word = {idx: w for w, idx in word2id.items()}
        with open(word2id_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(word2id, ensure_ascii=False, indent=2))
        with open(id2word_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(id2word, ensure_ascii=False, indent=2))
    else:
        with open(word2id_path, 'r', encoding='utf-8') as f:
            word2id = json.load(f)
        with open(id2word_path, 'r', encoding='utf-8') as f:
            # JSON object keys are always strings. Restore the int ids so the
            # cached result has the same types as a freshly built one.
            id2word = {int(k): v for k, v in json.load(f).items()}
    return word2id, id2word
def process_data(data_path, word2id, max_length):
    """Load one processed split CSV and encode it with ``build_data``.

    Returns:
        (id_lists, labels), exactly as produced by ``build_data``.
    """
    encoded_rows, labels = build_data(pd.read_csv(data_path), word2id, max_length)
    return encoded_rows, labels
def prepare_data(max_length):
    """Run the full preprocessing pipeline and return encoded splits.

    Each raw split is filtered into a local CSV (skipped if that CSV already
    exists). A shared vocabulary is then built over all three filtered files,
    and each split is encoded to ``max_length + 1`` ids per row.

    Returns:
        (X_train, y_train, X_dev, y_dev, X_test, y_test)
    """
    # (raw source file, filtered CSV written by filter_stopwords), in the
    # order train, dev, test.
    splits = (
        ('/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv', 'train.csv'),
        ('/data/bigfiles/15bf9d03-0569-4c87-88d9-f6c3b8deaf2a.csv', 'dev.csv'),
        ('/data/bigfiles/6d0de77c-d0fd-4d15-84a0-99cfe53790b6.csv', 'test.csv'),
    )
    for raw_path, filtered_path in splits:
        filter_stopwords(raw_path, filtered_path)
    filtered_paths = [filtered for _, filtered in splits]
    word2id, _ = gen_word2id(*filtered_paths)
    outputs = []
    for filtered_path in filtered_paths:
        # Each call contributes (X, y) for one split.
        outputs.extend(process_data(filtered_path, word2id, max_length))
    return tuple(outputs)
if __name__ == '__main__':
    # Encoded rows end up max_length + 1 ids wide (see build_data).
    max_length = 64
    prepare_data(max_length=max_length)