You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
80 lines
2.9 KiB
80 lines
2.9 KiB
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
|
from torch_geometric.loader import DataLoader
|
|
import config
|
|
|
|
class FocalLoss(torch.nn.Module):
    """Focal loss for class-imbalanced classification.

    Down-weights easy examples via the modulating factor (1 - pt)**gamma
    and scales the result by a scalar balancing factor alpha, following
    Lin et al., "Focal Loss for Dense Object Detection".
    """

    def __init__(self, alpha=0.25, gamma=2):
        """
        Args:
            alpha: scalar weight applied uniformly to the loss term.
            gamma: focusing exponent; gamma=0 reduces to alpha-scaled CE.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss over the batch.

        Args:
            inputs: raw logits, shape (N, C).
            targets: integer class labels, shape (N,).
        """
        # Per-sample cross-entropy (no reduction so each sample can be
        # re-weighted individually below).
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # pt = softmax probability of the true class, recovered from CE.
        pt = torch.exp(-per_sample_ce)
        focal_term = (1 - pt) ** self.gamma
        return (self.alpha * focal_term * per_sample_ce).mean()
|
|
|
|
def train_model(model, train_graphs, test_graphs):
    """Train a GNN node classifier with focal loss and best-F1 checkpointing.

    Runs for config.EPOCHS epochs; every 10 epochs the model is evaluated on
    the test loader and the weights are saved to config.MODEL_SAVE_PATH
    whenever the F1 score improves. After training, the best checkpoint is
    reloaded and its final metrics are printed.

    Args:
        model: a module whose forward is ``model(x, edge_index)`` and returns
            ``(logits, extra)`` — only the logits are used here.
        train_graphs: dataset of graphs with ``train_mask`` node masks.
        test_graphs: dataset of graphs with ``test_mask`` node masks.
    """
    model = model.to(config.DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
    criterion = FocalLoss()

    train_loader = DataLoader(train_graphs, batch_size=config.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_graphs, batch_size=config.BATCH_SIZE, shuffle=False)

    best_f1 = -1.0  # sentinel: stays negative until the first evaluation saves a checkpoint
    print("Start training...\n")

    for epoch in range(1, config.EPOCHS + 1):
        model.train()
        epoch_loss = 0

        for batch in train_loader:
            batch = batch.to(config.DEVICE)
            opt.zero_grad()
            logits, _ = model(batch.x, batch.edge_index)
            # Loss is computed only on nodes flagged by the training mask.
            loss = criterion(logits[batch.train_mask], batch.y[batch.train_mask])
            loss.backward()
            opt.step()
            epoch_loss += loss.item()

        # Evaluate (and possibly checkpoint) every 10th epoch only.
        if epoch % 10 == 0:
            metrics = test(model, test_loader)
            print(f"Epoch {epoch:02d} | Loss {epoch_loss:.3f} | Acc {metrics['acc']:.3f} | F1 {metrics['f1']:.3f} | Recall {metrics['recall']:.3f}")

            if metrics["f1"] > best_f1:
                best_f1 = metrics["f1"]
                torch.save(model.state_dict(), config.MODEL_SAVE_PATH)

    print("\nTraining done!")
    # A non-negative best_f1 means at least one checkpoint was written;
    # reload it and report the final metrics.
    if best_f1 >= 0:
        model.load_state_dict(torch.load(config.MODEL_SAVE_PATH))
        print("\n===== BEST RESULT =====")
        final_metrics = test(model, test_loader)
        for name, value in final_metrics.items():
            print(f"{name}: {value:.4f}")
|
|
|
|
@torch.no_grad()
def test(model, loader):
    """Evaluate the model on the test-masked nodes of every batch in *loader*.

    Args:
        model: module returning ``(logits, extra)`` from ``model(x, edge_index)``.
        loader: DataLoader yielding graph batches with a ``test_mask``.

    Returns:
        dict with keys ``acc``, ``precision``, ``recall``, ``f1``.
    """
    model.eval()

    all_preds = []
    all_labels = []
    for batch in loader:
        batch = batch.to(config.DEVICE)
        logits, _ = model(batch.x, batch.edge_index)
        mask = batch.test_mask
        # Predicted class per test node = argmax over the logit dimension.
        batch_preds = logits[mask].argmax(dim=1)
        all_preds.append(batch_preds.cpu().numpy())
        all_labels.append(batch.y[mask].cpu().numpy())

    preds = np.concatenate(all_preds)
    labels = np.concatenate(all_labels)

    # zero_division=0 keeps sklearn from raising when a class is never predicted.
    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, zero_division=0)
    recall = recall_score(labels, preds, zero_division=0)
    f1 = f1_score(labels, preds, zero_division=0)

    print("Sample predictions:", preds[:20])  # print the first 20 predictions
    print("Sample labels:", labels[:20])  # print the first 20 ground-truth labels

    return {"acc": acc, "precision": precision, "recall": recall, "f1": f1}