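"""Data preparation for a three-class TextCNN book classifier.

Keeps three book categories (童书 / 工业技术 / 大中专教材教辅) from
tab-separated raw data, strips stopwords, builds word2id/id2word
vocabularies over all splits, and turns each text into a fixed-length
sequence of token ids for training.
"""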
import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter
import warnings
warnings.filterwarnings("ignore")
def extract_three_cls_data(data_path, save_path):
    """Keep the three target categories, save the label map, and strip stopwords."""
    map_path = 'map.json'
    data = pd.read_csv(data_path, sep='\t').dropna()
    # Keep only the three target classes; .copy() avoids SettingWithCopyWarning
    # when new columns are added below.
    cls_data = data[data['label'].isin(['童书', '工业技术', '大中专教材教辅'])].copy()
    cls_data.index = range(len(cls_data))
    print(Counter(cls_data['label']))
    print('{} classes in total'.format(len(np.unique(cls_data['label']))))
    label_map = {key: index for index, key in enumerate(np.unique(cls_data['label']))}
    label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3)
    # Bug fix: the original tested os.path.exists(label_map_json), i.e. the JSON
    # string itself, so the label map was rewritten on every run.
    if not os.path.exists(map_path):
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(label_map_json)
    cls_data['textcnn_label'] = cls_data['label'].map(label_map)
    with open('/data/bigfiles/52afd73b-7526-4120-b196-4d468218114a.txt', 'r', encoding='utf-8') as f:
        stopwords = {line.strip() for line in f}  # set membership is O(1)
    text_seg = []
    for _, row in tqdm(cls_data.iterrows(), desc='removing stopwords', total=len(cls_data)):
        kept = [word for word in row['text'].split(' ') if word not in stopwords]
        # Join without the trailing space the original appended, which later
        # produced a spurious empty-string token in the vocabulary.
        text_seg.append(' '.join(kept))
    # Assign the whole column at once instead of the original chained
    # cls_data['text_seg'][idx] = ..., which pandas warns about.
    cls_data['text_seg'] = text_seg
    cls_data.to_csv(save_path, index=False)
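# Stopword filtering example: with stopwords {'这', '是', '一'} (hypothetical
# entries, for illustration only), the segmented text '这 是 一 本 童书'
# becomes '本 童书'.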
def build_word2id(lists):
    """Map each unique token to a consecutive integer id, in first-seen order."""
    maps = {}
    for item in lists:
        if item not in maps:
            maps[item] = len(maps)
    return maps
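# For example, build_word2id(['猫', '狗', '猫']) returns {'猫': 0, '狗': 1}:
# ids follow first-appearance order, so the same corpus always yields the
# same mapping.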
def build_data(train_data, word2id_map, max_length):
    """Convert segmented texts into id sequences of exactly max_length + 1 ids."""
    train_list = []
    label_list = train_data['textcnn_label']
    for line in train_data['text_seg']:
        train_line_id = [word2id_map[word] for word in line.split(' ')]
        length = len(train_line_id)
        if length > max_length + 1:
            train_line_id = train_line_id[:max_length + 1]
        elif length < max_length + 1:
            train_line_id.extend([word2id_map['PAD']] * (max_length - length + 1))
        train_list.append(train_line_id)
    return train_list, label_list
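# Example: with max_length = 3 and PAD id 9, the ids [5, 7] are padded to
# [5, 7, 9, 9] and [1, 2, 3, 4, 5] is truncated to [1, 2, 3, 4]; every
# sequence ends up with exactly max_length + 1 entries.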
def filter_stopwords(data_path, save_path):
    # Preprocess once; skip if the cleaned CSV already exists.
    if not os.path.exists(save_path):
        extract_three_cls_data(data_path, save_path)
def concat_all_data(train_path, dev_path, test_path):
    """Stack train/dev/test so the vocabulary covers every split."""
    data_train = pd.read_csv(train_path)
    data_dev = pd.read_csv(dev_path)
    data_test = pd.read_csv(test_path)
    data = pd.concat([data_train, data_dev, data_test])
    data.index = range(len(data))
    return data
def gen_word2id(train_path, dev_path, test_path):
    """Build (or reload) the word2id/id2word vocabularies over all splits."""
    data = concat_all_data(train_path, dev_path, test_path)
    word2id_path = 'word2id.json'
    id2word_path = 'id2word.json'
    if not os.path.exists(word2id_path):
        words_list = []
        for line in tqdm(data['text_seg'], desc='gen word2id'):
            words_list.extend(line.split(' '))
        word2id = build_word2id(words_list)
        word2id['PAD'] = len(word2id)  # the padding token gets the last id
        id2word = {word2id[w]: w for w in word2id}
        with open(word2id_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(word2id, ensure_ascii=False, indent=2))
        with open(id2word_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(id2word, ensure_ascii=False, indent=2))
    else:
        with open(word2id_path, 'r', encoding='utf-8') as f:
            word2id = json.load(f)
        with open(id2word_path, 'r', encoding='utf-8') as f:
            # JSON object keys are always strings, so restore the int keys
            # that the freshly built id2word uses.
            id2word = {int(k): v for k, v in json.load(f).items()}
    return word2id, id2word
def process_data(data_path, word2id, max_length):
    data = pd.read_csv(data_path)
    train_list, label_list = build_data(data, word2id, max_length)
    return train_list, label_list
def prepare_data(max_length):
    train_data_path = '/data/bigfiles/caa7144d-40ea-4b67-91b4-ca4bf2b4c04d.csv'
    train_save_path = 'train.csv'
    filter_stopwords(train_data_path, train_save_path)
    dev_data_path = '/data/bigfiles/15bf9d03-0569-4c87-88d9-f6c3b8deaf2a.csv'
    dev_save_path = 'dev.csv'
    filter_stopwords(dev_data_path, dev_save_path)
    test_data_path = '/data/bigfiles/6d0de77c-d0fd-4d15-84a0-99cfe53790b6.csv'
    test_save_path = 'test.csv'
    filter_stopwords(test_data_path, test_save_path)
    word2id, id2word = gen_word2id(train_save_path, dev_save_path, test_save_path)
    X_train, y_train = process_data(train_save_path, word2id, max_length)
    X_dev, y_dev = process_data(dev_save_path, word2id, max_length)
    X_test, y_test = process_data(test_save_path, word2id, max_length)
    return X_train, y_train, X_dev, y_dev, X_test, y_test
if __name__ == '__main__':
    max_length = 64
    prepare_data(max_length)
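# Hypothetical downstream usage (an assumption, not part of this script):
# if the TextCNN is trained with PyTorch, the returned id lists convert
# directly into tensors of shape (N, max_length + 1):
#
#     import torch
#     X_train, y_train, X_dev, y_dev, X_test, y_test = prepare_data(64)
#     X = torch.tensor(X_train, dtype=torch.long)
#     y = torch.tensor(list(y_train), dtype=torch.long)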