import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, Linear

import config


class ADSB_GAT(torch.nn.Module):
    """Two-layer graph attention network (GAT) with a linear classification head.

    Hidden width, number of attention heads and dropout rate are read from the
    project-level ``config`` module (``HIDDEN_CHANNELS``, ``HEADS``, ``DROPOUT``).
    """

    def __init__(self, in_channels, *, num_classes=2, seed=1234):
        """Build the network.

        Args:
            in_channels: Size of each input node feature vector.
            num_classes: Number of output classes produced by the classifier
                head. Defaults to 2 (binary classification), matching the
                previously hard-coded head.
            seed: Value passed to ``torch.manual_seed`` before the layers are
                created, so that weight initialisation is reproducible. This
                reseeds PyTorch's *global* RNG. Pass ``None`` to leave the
                global RNG untouched.
        """
        super().__init__()
        if seed is not None:
            # Seed before the layers are built so their initial weights are
            # deterministic across runs.
            torch.manual_seed(seed)
        # First layer: multi-head attention. GATConv concatenates head outputs
        # by default, so this layer emits HIDDEN_CHANNELS * HEADS features.
        self.conv1 = GATConv(
            in_channels,
            config.HIDDEN_CHANNELS,
            heads=config.HEADS,
            dropout=config.DROPOUT,
        )
        # Second layer: consumes the concatenated heads and reduces back to
        # HIDDEN_CHANNELS with a single head.
        self.conv2 = GATConv(
            config.HIDDEN_CHANNELS * config.HEADS,
            config.HIDDEN_CHANNELS,
            heads=1,
            dropout=config.DROPOUT,
        )
        # Classification head (binary by default).
        self.classifier = Linear(config.HIDDEN_CHANNELS, num_classes)

    def forward(self, x, edge_index):
        """Run the two GAT layers and the classifier.

        Args:
            x: Node feature matrix. Per the PyG ``GATConv`` convention this is
                presumably ``[num_nodes, in_channels]``; confirm with the caller.
            edge_index: Graph connectivity in PyG COO format, presumably
                ``[2, num_edges]``.

        Returns:
            A tuple ``(out, h)``. ``out`` holds the classifier's raw scores, and
            no softmax is applied here. ``h`` holds the node embeddings produced
            by the second GAT layer after ELU, returned so callers can inspect
            or reuse them.
        """
        h = self.conv1(x, edge_index)
        h = F.elu(h)
        # Feature dropout between layers, active only in training mode. This is
        # separate from the attention-coefficient dropout inside GATConv.
        h = F.dropout(h, p=config.DROPOUT, training=self.training)
        h = self.conv2(h, edge_index)
        h = F.elu(h)
        out = self.classifier(h)
        return out, h