import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader

# Records the validation accuracy of each epoch.
best_acc_list = []


class TextCNN(nn.Module):
    def __init__(self, config):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        if config.use_pretrained_w2v:
            # Initialize the embedding layer with pretrained word vectors and keep it trainable.
            self.embedding.weight.data.copy_(config.embedding_pretrained)
            self.embedding.weight.requires_grad = True
        # One Conv2d per kernel size; each kernel spans k consecutive word embeddings.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.kenel_num, (k, config.embedding_size)) for k in config.kenel_size])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.kenel_num * len(config.kenel_size), config.num_classes)

    def forward(self, x):
        x = self.embedding(x)                                             # (batch, seq_len, embedding_size)
        x = x.unsqueeze(1)                                                # (batch, 1, seq_len, embedding_size)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]          # (batch, kenel_num, seq_len - k + 1) per kernel size
        x = [F.max_pool1d(line, line.size(2)).squeeze(2) for line in x]  # global max pooling over time
        x = torch.cat(x, 1)                                               # (batch, kenel_num * len(kenel_size))
        x = self.dropout(x)
        out = self.fc(x)
        out = F.log_softmax(out, dim=1)                                   # log-probabilities over classes
        return out


class TextCNN_DataLoader(Dataset):
    def __init__(self, train_data, labels):
        self.train_data = train_data
        self.labels = labels

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        sentence = np.array(self.train_data[index])
        label = self.labels[index]
        return sentence, label


class TextCNNHelp():
    def __init__(self, config):
        self.config = config
        self.model = TextCNN(self.config)
        self.model.to(self.config.device)
        self.model_path = '/data/bigfiles/a4530bdd-0f97-44dd-ac9d-6b3da23a8acb.ckpt'

    def load_data(self, X_train, y_train, X_dev, y_dev, X_test, y_test):
        self.train_loader = DataLoader(TextCNN_DataLoader(X_train, y_train), batch_size=self.config.batch_size)
        self.dev_loader = DataLoader(TextCNN_DataLoader(X_dev, y_dev), batch_size=self.config.batch_size)
        self.test_loader = DataLoader(TextCNN_DataLoader(X_test, y_test), batch_size=self.config.batch_size)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.config.lr)
        # The model's forward already returns log-probabilities (log_softmax), so use NLLLoss;
        # CrossEntropyLoss would apply log-softmax a second time.
        self.criterion = nn.NLLLoss()
        self.scheduler = StepLR(self.optimizer, step_size=5)

    def train_epoch(self, epoch):
        self.model.train()
        count = 0
        correct = 0
        loss_sum = 0
        for i, (sentence, label) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            sentence = sentence.type(torch.LongTensor).to(self.config.device)
            label = label.type(torch.LongTensor).to(self.config.device)
            out = self.model(sentence)
            loss = self.criterion(out, label)
            loss_sum += loss.item()
            count += len(sentence)
            correct += (out.argmax(1) == label).float().sum().item()
            # Report the average loss every 100 steps.
            if (i + 1) % 100 == 0:
                print('train epoch: {}, step: {}, loss: {:.5f}'.format(epoch, i + 1, loss_sum / 100))
                loss_sum = 0
            loss.backward()
            self.optimizer.step()
        print('train epoch: {}, train_acc: {}%'.format(epoch, 100 * (correct / count)))
        self.scheduler.step()
        # Save a checkpoint after every epoch.
        torch.save(self.model.state_dict(), './{}.ckpt'.format(epoch))

    def validation(self, epoch):
        self.model.eval()
        count, correct = 0, 0
        val_loss_sum = 0
        with torch.no_grad():
            for i, (sentence, label) in enumerate(self.dev_loader):
                sentence, label = sentence.to(self.config.device), label.to(self.config.device)
                output = self.model(sentence)
                loss = self.criterion(output, label)
                val_loss_sum += loss.item()
                correct += (output.argmax(1) == label).float().sum().item()
                count += len(sentence)
                # Report the average validation loss every 100 steps.
                if (i + 1) % 100 == 0:
                    print('eval epoch: {}, step: {}, loss: {:.5f}'.format(epoch, i + 1, val_loss_sum / 100))
                    val_loss_sum = 0
        print('eval epoch: {}, eval_acc: {}%'.format(epoch, 100 * (correct / count)))
        best_acc_list.append(100 * (correct / count))

    def test(self):
        # Evaluate a saved checkpoint on the test set.
        model = TextCNN(self.config)
        model.to(self.config.device)
        model.load_state_dict(torch.load(self.model_path, map_location='cpu'))
        model.eval()
        correct = 0
        count = 0
        with torch.no_grad():
            for i, (sentence, label) in enumerate(self.test_loader):
                sentence, label = sentence.to(self.config.device), label.to(self.config.device)
                output = model(sentence)
                count += len(sentence)
                correct += (output.argmax(1) == label).float().sum().item()
        print('test acc: {}%'.format(100 * (correct / count)))

    def train_model(self):
        print('Start training:')
        epochs = self.config.epochs
        for i in range(1, epochs + 1):
            self.train_epoch(i)
            self.validation(i)
        print('Start testing:')
        self.test()
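

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal example of how TextCNNHelp might be driven, assuming a simple Config
# object that supplies the attributes referenced above (vocab_size, embedding_size,
# use_pretrained_w2v, embedding_pretrained, kenel_num, kenel_size, dropout,
# num_classes, device, batch_size, lr, epochs). The random token sequences below
# are placeholder data only; real usage would pass tokenized, index-encoded and
# padded sentences.
if __name__ == '__main__':
    class Config:
        vocab_size = 5000
        embedding_size = 128
        use_pretrained_w2v = False      # no pretrained embedding matrix in this sketch
        embedding_pretrained = None
        kenel_num = 100                 # feature maps per kernel size (name kept as in the model)
        kenel_size = [3, 4, 5]
        dropout = 0.5
        num_classes = 2
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        batch_size = 32
        lr = 0.1
        epochs = 2

    config = Config()

    # Placeholder data: random index sequences of length 40 with random labels.
    def fake_split(n, seq_len=40):
        X = np.random.randint(0, config.vocab_size, size=(n, seq_len)).tolist()
        y = np.random.randint(0, config.num_classes, size=n).tolist()
        return X, y

    X_train, y_train = fake_split(300)
    X_dev, y_dev = fake_split(50)
    X_test, y_test = fake_split(50)

    helper = TextCNNHelp(config)
    helper.load_data(X_train, y_train, X_dev, y_dev, X_test, y_test)
    # Point the test-time checkpoint at a file this run actually produces
    # (train_epoch saves './{epoch}.ckpt' after every epoch).
    helper.model_path = './{}.ckpt'.format(config.epochs)
    helper.train_model()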