You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
21 lines
821 B
21 lines
821 B
import torch
|
|
import torch.nn.functional as F
|
|
from torch_geometric.nn import GATConv, Linear
|
|
import config
|
|
|
|
class ADSB_GAT(torch.nn.Module):
    """Two-layer graph attention network with a linear node-classification head.

    Pipeline: ``GATConv`` (``heads`` heads) -> ELU -> dropout ->
    ``GATConv`` (single head) -> ELU -> ``Linear`` classifier.

    Every hyperparameter left as ``None`` falls back to the project ``config``
    module, read when the model is constructed. Calling ``ADSB_GAT(in_channels)``
    therefore builds the same model as before these parameters existed.

    Args:
        in_channels: Input node-feature size; forwarded as the first argument
            of the first ``GATConv``.
        hidden_channels: Per-head output size of the first layer and output
            size of the second layer. Defaults to ``config.HIDDEN_CHANNELS``.
        heads: Number of attention heads in the first layer.
            Defaults to ``config.HEADS``.
        dropout: Probability passed as ``dropout=`` to both ``GATConv`` layers
            and used for feature dropout between them in ``forward``.
            Defaults to ``config.DROPOUT``.
        num_classes: Number of output logits per node. Defaults to 2 (binary
            classification), matching the original hard-coded head.
        seed: If not ``None``, ``torch.manual_seed(seed)`` is called before the
            layers are created so their initial weights are reproducible.
            NOTE: this reseeds torch's *global* RNG as a construction side
            effect. Pass ``None`` to leave the global RNG untouched.
    """

    def __init__(self, in_channels, hidden_channels=None, heads=None,
                 dropout=None, num_classes=2, seed=1234):
        super().__init__()
        # Resolve config defaults here (not as default-argument values, which
        # would be frozen at import time) so config is read at construction.
        if hidden_channels is None:
            hidden_channels = config.HIDDEN_CHANNELS
        if heads is None:
            heads = config.HEADS
        if dropout is None:
            dropout = config.DROPOUT

        # Seed immediately before building the layers. The creation order below
        # (conv1, conv2, classifier) must stay fixed for identical initial weights.
        if seed is not None:
            torch.manual_seed(seed)

        # Kept so forward() uses the same probability the layers were built with,
        # even if config.DROPOUT is changed later.
        self.dropout = dropout

        # conv2's input size is hidden_channels * heads: conv1 emits one
        # hidden_channels-wide vector per head (concatenated).
        self.conv1 = GATConv(in_channels, hidden_channels,
                             heads=heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels,
                             heads=1, dropout=dropout)
        # Node classifier; num_classes=2 by default (binary classification).
        self.classifier = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """Compute per-node class logits and hidden embeddings.

        Args:
            x: Node feature matrix; presumably ``[num_nodes, in_channels]``.
                TODO: confirm against the data loader.
            edge_index: Graph connectivity in the form ``GATConv`` accepts
                (assumed ``[2, num_edges]`` COO indices; verify with caller).

        Returns:
            Tuple ``(out, h)``. ``out`` holds the raw class logits
            (``num_classes`` per node; no softmax is applied here). ``h`` is
            the node embedding after the second GAT layer and ELU; its last
            dimension is ``hidden_channels``.
        """
        h = F.elu(self.conv1(x, edge_index))
        # Feature dropout only when the module is in training mode.
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.elu(self.conv2(h, edge_index))
        out = self.classifier(h)
        return out, h