You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

21 lines
821 B

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, Linear
import config
class ADSB_GAT(torch.nn.Module):
    """Two-layer graph attention network with a linear classification head.

    Pipeline: GATConv (multi-head) -> ELU -> dropout -> GATConv (one head)
    -> ELU -> Linear(hidden, num_classes).

    Every hyper-parameter defaults to the value in the project ``config``
    module, so ``ADSB_GAT(in_channels)`` builds the same model as before;
    pass explicit values to override them without editing ``config``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        heads=None,
        dropout=None,
        num_classes=2,
        seed=1234,
    ):
        """Create the two attention layers and the classifier.

        Args:
            in_channels: Size of each input node feature vector.
            hidden_channels: Per-head output size of both GAT layers;
                defaults to ``config.HIDDEN_CHANNELS``.
            heads: Number of attention heads in the first GAT layer;
                defaults to ``config.HEADS``.
            dropout: Dropout probability used for the attention coefficients
                inside each GATConv and for the features between the two
                layers; defaults to ``config.DROPOUT``.
            num_classes: Number of output scores per node (default 2, i.e.
                binary classification).
            seed: Passed to ``torch.manual_seed`` before the layers are built
                so weight initialisation is reproducible. This reseeds the
                *global* torch RNG; pass ``None`` to leave it untouched.
        """
        super().__init__()
        if hidden_channels is None:
            hidden_channels = config.HIDDEN_CHANNELS
        if heads is None:
            heads = config.HEADS
        if dropout is None:
            dropout = config.DROPOUT
        if seed is not None:
            # Seed before creating any parameters so their init is deterministic.
            torch.manual_seed(seed)

        # Captured once so forward() uses the same value the layers were built with.
        self.dropout_p = dropout
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # conv1 concatenates its heads (GATConv's default concat=True), so the
        # second layer receives hidden_channels * heads features per node.
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, dropout=dropout)
        self.classifier = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """Run the network on one graph.

        Args:
            x: Node feature matrix; presumably shaped
                ``[num_nodes, in_channels]`` as GATConv expects — confirm
                against the caller.
            edge_index: Graph connectivity in the format GATConv accepts
                (typically ``[2, num_edges]``).

        Returns:
            A tuple ``(out, h)``. ``out`` holds the classifier's raw,
            un-normalised scores per node (no softmax is applied here).
            ``h`` is the post-ELU embedding from the second GAT layer that
            was fed to the classifier.
        """
        h = F.elu(self.conv1(x, edge_index))
        # Feature dropout is only active in training mode.
        h = F.dropout(h, p=self.dropout_p, training=self.training)
        h = F.elu(self.conv2(h, edge_index))
        out = self.classifier(h)
        return out, h