You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
137 lines
5.2 KiB
137 lines
5.2 KiB
import json
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from torch.optim.lr_scheduler import StepLR
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.data import DataLoader
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
best_acc_list = []
|
|
|
|
class TextCNN(nn.Module):
|
|
def __init__(self, config):
|
|
super(TextCNN, self).__init__()
|
|
self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
|
|
if config.use_pretrained_w2v:
|
|
self.embedding.weight.data.copy_(config.embedding_pretrained)
|
|
self.embedding.weight.requires_grad = True
|
|
|
|
self.convs = nn.ModuleList(
|
|
[nn.Conv2d(1, config.kenel_num, (k, config.embedding_size)) for k in config.kenel_size])
|
|
|
|
self.dropout = nn.Dropout(config.dropout)
|
|
self.fc = nn.Linear(config.kenel_num * len(config.kenel_size), config.num_classes)
|
|
|
|
def forward(self, x):
|
|
x = self.embedding(x)
|
|
x = x.unsqueeze(1)
|
|
x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
|
|
x = [F.max_pool1d(line, line.size(2)).squeeze(2) for line in x]
|
|
x = torch.cat(x, 1)
|
|
x = self.dropout(x)
|
|
out = self.fc(x)
|
|
out = F.log_softmax(out, dim=1)
|
|
return out
|
|
|
|
|
|
|
|
class TextCNN_DataLoader(Dataset):
|
|
def __init__(self, train_data, labels):
|
|
self.train_data = train_data
|
|
self.labels = labels
|
|
|
|
def __len__(self):
|
|
return len(self.train_data)
|
|
|
|
def __getitem__(self, index):
|
|
sentence = np.array(self.train_data[index])
|
|
label = self.labels[index]
|
|
|
|
return sentence, label
|
|
|
|
|
|
class TextCNNHelp():
|
|
def __init__(self, config):
|
|
self.config = config
|
|
self.model = TextCNN(self.config)
|
|
self.model.to(self.config.device)
|
|
self.model_path = '/data/bigfiles/a4530bdd-0f97-44dd-ac9d-6b3da23a8acb.ckpt'
|
|
|
|
def load_data(self, X_train, y_train, X_dev, y_dev, X_test, y_test):
|
|
self.train_loader = DataLoader(TextCNN_DataLoader(X_train, y_train), batch_size=self.config.batch_size)
|
|
self.dev_loader = DataLoader(TextCNN_DataLoader(X_dev, y_dev), batch_size=self.config.batch_size)
|
|
self.test_loader = DataLoader(TextCNN_DataLoader(X_test, y_test), batch_size=self.config.batch_size)
|
|
self.optermizer = torch.optim.SGD(self.model.parameters(), lr=self.config.lr)
|
|
self.criterion = nn.CrossEntropyLoss()
|
|
self.scheduler = StepLR(self.optermizer, step_size=5)
|
|
|
|
def train_epoch(self, epoch):
|
|
self.model.train()
|
|
count = 0
|
|
correct = 0
|
|
loss_sum = 0
|
|
for i, (sentence, label) in enumerate(self.train_loader):
|
|
self.optermizer.zero_grad()
|
|
sentence = sentence.type(torch.LongTensor).to(self.config.device)
|
|
label = label.type(torch.LongTensor).to(self.config.device)
|
|
out = self.model(sentence)
|
|
# print('out: {}'.format(out.argmax(1)))
|
|
loss = self.criterion(out, label)
|
|
loss_sum += loss.item()
|
|
count += len(sentence)
|
|
correct += (out.argmax(1) == label).float().sum().item()
|
|
# print('correct / count: {}'.format(correct/ count))
|
|
if count % 200 == 0:
|
|
print('train epoch: {}, step: {}, loss: {:.5f}'.format(epoch, i + 1, loss_sum / 100))
|
|
loss_sum = 0
|
|
loss.backward()
|
|
self.optermizer.step()
|
|
print('train epoch: {}, train_acc: {}%'.format(epoch, 100 * (correct / count)))
|
|
self.scheduler.step()
|
|
torch.save(self.model.state_dict(), './{}.ckpt'.format(epoch))
|
|
|
|
def validation(self, epoch):
|
|
self.model.eval()
|
|
count, correct = 0, 0
|
|
val_loss_sum = 0
|
|
for i, (sentence, label) in enumerate(self.dev_loader):
|
|
sentence, label = sentence.to(self.config.device), label.to(self.config.device)
|
|
output = self.model(sentence)
|
|
loss = self.criterion(output, label)
|
|
val_loss_sum += loss.item()
|
|
correct += (output.argmax(1) == label).float().sum().item()
|
|
count += len(sentence)
|
|
if count % 200 == 0:
|
|
print('eval epoch: {}, step: {}, loss: {:.5f}'.format(epoch, i + 1, val_loss_sum / 100))
|
|
val_loss_sum = 0
|
|
print('eval epoch: {}, train_acc: {}%'.format(epoch, 100 * (correct / count)))
|
|
best_acc_list.append(100 * (correct / count))
|
|
|
|
def test(self):
|
|
model = TextCNN(self.config)
|
|
model.to(self.config.device)
|
|
model.load_state_dict(torch.load(self.model_path,map_location='cpu'))
|
|
correct = 0
|
|
count = 0
|
|
for i, (sentence, label) in enumerate(self.test_loader):
|
|
sentence, label = sentence.to(self.config.device), label.to(self.config.device)
|
|
output = model(sentence)
|
|
count += len(sentence)
|
|
correct += (output.argmax(1) == label).float().sum().item()
|
|
print('test acc: {}%'.format(100 * (correct / count)))
|
|
|
|
def train_model(self):
|
|
print('开始训练:')
|
|
epochs = self.config.epochs
|
|
for i in range(1, epochs + 1):
|
|
self.train_epoch(i)
|
|
self.validation(i)
|
|
model_path = self.model_path
|
|
print('开始测试:')
|
|
self.test()
|
|
|