import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Validation accuracy (percent) of every call to TextCNNHelp.validation,
# appended in call order.
best_acc_list = []

# Checkpoint evaluated by TextCNNHelp.test when no explicit path is given.
DEFAULT_MODEL_PATH = '/data/bigfiles/a4530bdd-0f97-44dd-ac9d-6b3da23a8acb.ckpt'

# A running-loss line is printed whenever the number of samples seen so far
# in the epoch is a multiple of this value.
LOG_EVERY_SAMPLES = 200


class TextCNN(nn.Module):
    """Convolutional sentence classifier.

    Token ids are embedded, one ``Conv2d`` per height in ``config.kenel_size``
    slides over the full embedding width, each feature map is max-pooled over
    the whole sequence, and the pooled features are concatenated, passed
    through dropout and a linear layer.

    Config attributes read: ``vocab_size``, ``embedding_size``,
    ``use_pretrained_w2v``, ``embedding_pretrained`` (only when
    ``use_pretrained_w2v`` is truthy), ``kenel_num``, ``kenel_size``
    (iterable of kernel heights), ``dropout``, ``num_classes``.
    """

    def __init__(self, config):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        if config.use_pretrained_w2v:
            # assumes embedding_pretrained is a (vocab_size, embedding_size)
            # tensor (or broadcastable to it) — TODO confirm with caller
            self.embedding.weight.data.copy_(config.embedding_pretrained)
            # Explicitly keep the (pretrained) vectors trainable.
            self.embedding.weight.requires_grad = True
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.kenel_num, (k, config.embedding_size))
             for k in config.kenel_size])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.kenel_num * len(config.kenel_size), config.num_classes)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: integer tensor of shape (batch, seq_len); seq_len must be at
                least the largest kernel height.

        Returns:
            Log-probabilities of shape (batch, num_classes).
        """
        x = self.embedding(x)           # (batch, seq_len, emb)
        x = x.unsqueeze(1)              # (batch, 1, seq_len, emb): single input channel
        # Kernel spans the full embedding width, so the last dim becomes 1.
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        # Max over the entire remaining length of each feature map.
        x = [F.max_pool1d(line, line.size(2)).squeeze(2) for line in x]
        x = torch.cat(x, 1)             # (batch, kenel_num * len(kenel_size))
        x = self.dropout(x)
        out = self.fc(x)
        # NOTE: CrossEntropyLoss applies log_softmax again; log_softmax is
        # idempotent, so the loss is unaffected.
        out = F.log_softmax(out, dim=1)
        return out


class TextCNN_DataLoader(Dataset):
    """Map-style ``Dataset`` pairing each sample with its label (despite the name,
    this is a Dataset, not a DataLoader)."""

    def __init__(self, train_data, labels):
        self.train_data = train_data
        self.labels = labels

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        """Return ``(np.array(sample), label)`` for position ``index``."""
        sentence = np.array(self.train_data[index])
        label = self.labels[index]
        return sentence, label


class TextCNNHelp():
    """Train / validate / test driver around :class:`TextCNN`.

    Config attributes read in addition to TextCNN's: ``device``,
    ``batch_size``, ``lr``, ``epochs``.
    """

    def __init__(self, config, model_path=None):
        """Build the model on ``config.device``.

        Args:
            config: configuration object (see class docstring).
            model_path: checkpoint loaded by :meth:`test`; defaults to
                ``DEFAULT_MODEL_PATH``.
        """
        self.config = config
        self.model = TextCNN(self.config)
        self.model.to(self.config.device)
        self.model_path = DEFAULT_MODEL_PATH if model_path is None else model_path

    def load_data(self, X_train, y_train, X_dev, y_dev, X_test, y_test):
        """Create train/dev/test loaders (unshuffled) plus optimizer, loss and LR scheduler."""
        batch_size = self.config.batch_size
        self.train_loader = DataLoader(TextCNN_DataLoader(X_train, y_train), batch_size=batch_size)
        self.dev_loader = DataLoader(TextCNN_DataLoader(X_dev, y_dev), batch_size=batch_size)
        self.test_loader = DataLoader(TextCNN_DataLoader(X_test, y_test), batch_size=batch_size)
        # Attribute name kept as-is (sic) for compatibility with existing callers.
        self.optermizer = torch.optim.SGD(self.model.parameters(), lr=self.config.lr)
        self.criterion = nn.CrossEntropyLoss()
        # Decays the LR by StepLR's default gamma every 5 scheduler steps (one step per epoch).
        self.scheduler = StepLR(self.optermizer, step_size=5)

    def _to_device(self, sentence, label):
        """Cast a batch to int64 (required by the embedding / loss) and move it to the device."""
        device = self.config.device
        return sentence.long().to(device), label.long().to(device)

    def train_epoch(self, epoch):
        """Run one SGD pass over the training loader.

        Prints the running loss every ``LOG_EVERY_SAMPLES`` samples and the
        epoch's training accuracy, steps the LR scheduler, and saves the
        model weights to ``./{epoch}.ckpt``.
        """
        self.model.train()
        count = 0
        correct = 0
        loss_sum = 0
        batches_since_log = 0
        for i, (sentence, label) in enumerate(self.train_loader):
            self.optermizer.zero_grad()
            sentence, label = self._to_device(sentence, label)
            out = self.model(sentence)
            loss = self.criterion(out, label)
            loss_sum += loss.item()
            batches_since_log += 1
            count += len(sentence)
            correct += (out.argmax(1) == label).float().sum().item()
            if count % LOG_EVERY_SAMPLES == 0:
                # Mean batch loss since the previous log line.
                print('train epoch: {}, step: {}, loss: {:.5f}'.format(
                    epoch, i + 1, loss_sum / batches_since_log))
                loss_sum = 0
                batches_since_log = 0
            loss.backward()
            self.optermizer.step()
        print('train epoch: {}, train_acc: {}%'.format(epoch, 100 * (correct / count)))
        self.scheduler.step()
        torch.save(self.model.state_dict(), './{}.ckpt'.format(epoch))

    def validation(self, epoch):
        """Evaluate the current model on the dev loader and append its accuracy to ``best_acc_list``."""
        self.model.eval()
        count, correct = 0, 0
        val_loss_sum = 0
        batches_since_log = 0
        with torch.no_grad():
            for i, (sentence, label) in enumerate(self.dev_loader):
                sentence, label = self._to_device(sentence, label)
                output = self.model(sentence)
                loss = self.criterion(output, label)
                val_loss_sum += loss.item()
                batches_since_log += 1
                correct += (output.argmax(1) == label).float().sum().item()
                count += len(sentence)
                if count % LOG_EVERY_SAMPLES == 0:
                    print('eval epoch: {}, step: {}, loss: {:.5f}'.format(
                        epoch, i + 1, val_loss_sum / batches_since_log))
                    val_loss_sum = 0
                    batches_since_log = 0
        acc = 100 * (correct / count)
        print('eval epoch: {}, val_acc: {}%'.format(epoch, acc))
        best_acc_list.append(acc)

    def test(self):
        """Load the checkpoint at ``self.model_path`` into a fresh model and print its test accuracy."""
        model = TextCNN(self.config)
        model.to(self.config.device)
        # NOTE: torch.load unpickles the file; only point model_path at trusted checkpoints.
        model.load_state_dict(torch.load(self.model_path, map_location='cpu'))
        # Without eval() dropout stays active and the score is random / too low.
        model.eval()
        correct = 0
        count = 0
        with torch.no_grad():
            for sentence, label in self.test_loader:
                sentence, label = self._to_device(sentence, label)
                output = model(sentence)
                count += len(sentence)
                correct += (output.argmax(1) == label).float().sum().item()
        print('test acc: {}%'.format(100 * (correct / count)))

    def train_model(self):
        """Train and validate for ``config.epochs`` epochs, then run :meth:`test`.

        NOTE(review): test() evaluates ``self.model_path``, not the weights
        trained here (those are saved as ``./{epoch}.ckpt``) — confirm intent.
        """
        print('开始训练:')
        for i in range(1, self.config.epochs + 1):
            self.train_epoch(i)
            self.validation(i)
        print('开始测试:')
        self.test()