You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

80 lines
2.9 KiB

import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from torch_geometric.loader import DataLoader
import config
class FocalLoss(torch.nn.Module):
    """Focal loss computed from per-sample cross-entropy.

    For each sample the cross-entropy ``ce`` is scaled by
    ``alpha * (1 - pt) ** gamma`` with ``pt = exp(-ce)``, so samples whose
    cross-entropy is already small (``pt`` close to 1) contribute less.

    Args:
        alpha: Constant factor applied to every sample's loss.
        gamma: Exponent of the ``(1 - pt)`` modulating term. With ``0`` the
            loss reduces to ``alpha`` times plain cross-entropy.
        reduction: How the per-sample losses are combined: ``'mean'``
            (default), ``'sum'``, or ``'none'`` to return them unreduced.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` against ``targets``.

        Both arguments are forwarded unchanged to ``F.cross_entropy``, so
        they must satisfy that function's input requirements.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) inverts the -log in cross-entropy; for class-index targets
        # this is the probability the model assigned to the true class.
        pt = torch.exp(-ce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal
def train_model(model, train_graphs, test_graphs, *, lr=0.005,
                weight_decay=1e-4, eval_every=10):
    """Train ``model`` with focal loss and report the best checkpoint.

    Runs ``config.EPOCHS`` epochs of Adam optimisation over shuffled
    mini-batches of ``train_graphs``; only the entries selected by each
    batch's ``train_mask`` contribute to the loss. Every ``eval_every``
    epochs the model is scored with ``test`` on ``test_graphs``, and whenever
    the F1 score improves the weights are saved to ``config.MODEL_SAVE_PATH``.
    After training, the best saved weights (if any were written) are reloaded
    and evaluated once more; those metrics are printed and returned.

    NOTE(review): checkpoint selection and the final report both use
    ``test_graphs``, so the reported numbers are selected on the same data
    they are measured on.

    Args:
        model: Module called as ``model(x, edge_index)`` and returning a pair
            whose first element holds the class scores.
        train_graphs: Dataset passed to the training ``DataLoader``.
        test_graphs: Dataset passed to the evaluation ``DataLoader``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        eval_every: Evaluate (and possibly checkpoint) every this many epochs.

    Returns:
        The metrics dict produced by the final ``test`` call.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be a positive integer, got {eval_every!r}")

    model = model.to(config.DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = FocalLoss()
    train_loader = DataLoader(train_graphs, batch_size=config.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_graphs, batch_size=config.BATCH_SIZE, shuffle=False)

    # Start below any attainable F1 so the first evaluation always checkpoints.
    best_f1 = -1.0
    print("Start training...\n")
    for epoch in range(1, config.EPOCHS + 1):
        # test() switches the model to eval mode, so re-enable training mode
        # at the start of every epoch.
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(config.DEVICE)
            optimizer.zero_grad()
            out, _ = model(batch.x, batch.edge_index)
            train_out = out[batch.train_mask]
            if train_out.numel() == 0:
                # No training entries in this batch: the mean over an empty
                # selection is NaN, which would poison the running loss and
                # still trigger an optimizer step.
                continue
            loss = criterion(train_out, batch.y[batch.train_mask])
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if epoch % eval_every == 0:
            m = test(model, test_loader)
            print(f"Epoch {epoch:02d} | Loss {total_loss:.3f} | Acc {m['acc']:.3f} | F1 {m['f1']:.3f} | Recall {m['recall']:.3f}")
            if m["f1"] > best_f1:
                best_f1 = m["f1"]
                torch.save(model.state_dict(), config.MODEL_SAVE_PATH)

    print("\nTraining done!")
    if best_f1 >= 0:
        # A checkpoint exists only if at least one evaluation ran.
        # map_location keeps the reloaded tensors on the configured device.
        best_state = torch.load(config.MODEL_SAVE_PATH, map_location=config.DEVICE)
        model.load_state_dict(best_state)

    print("\n===== BEST RESULT =====")
    final = test(model, test_loader)
    for k, v in final.items():
        print(f"{k}: {v:.4f}")
    return final
@torch.no_grad()
def test(model, loader):
    """Score ``model`` on the test-masked entries of every batch in ``loader``.

    The model is put in eval mode and called as ``model(x, edge_index)``; the
    first element of its output is indexed with ``batch.test_mask`` and
    arg-maxed along dim 1 to obtain predicted labels. Predictions and labels
    from all batches are pooled before the metrics are computed with
    scikit-learn's default averaging (``zero_division=0`` for precision,
    recall and F1). The first 20 predictions and labels are printed.

    Returns:
        Dict with keys ``"acc"``, ``"precision"``, ``"recall"`` and ``"f1"``.
    """
    model.eval()

    pred_chunks = []
    label_chunks = []
    for graph_batch in loader:
        graph_batch = graph_batch.to(config.DEVICE)
        logits, _ = model(graph_batch.x, graph_batch.edge_index)
        mask = graph_batch.test_mask
        pred_chunks.append(logits[mask].argmax(dim=1).cpu().numpy())
        label_chunks.append(graph_batch.y[mask].cpu().numpy())

    # Pool the per-batch arrays so the metrics cover the whole loader at once.
    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(label_chunks)

    scores = {
        "acc": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
    }

    print("Sample predictions:", y_pred[:20])  # print the first 20 predictions
    print("Sample labels:", y_true[:20])  # print the first 20 ground-truth labels
    return scores