from torch import nn

import config


class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection and LayerNorm.

    Applies ``Linear(input_dim -> d_ff1) -> ReLU -> Linear(d_ff1 -> input_dim)``
    to the last dimension of the input. It then adds the input back (residual
    connection) and layer-normalises the sum over the last dimension.

    Sizes and the bias flag are read from the project ``config`` module:
    ``config.input_dim``, ``config.d_ff1``, ``config.bias``.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(config.input_dim, config.d_ff1, bias=config.bias),
            nn.ReLU(),
            nn.Linear(config.d_ff1, config.input_dim, bias=config.bias),
        )
        # Registering the LayerNorm once as a submodule means:
        # - its affine weight/bias are learned,
        # - it moves with model.to(...) just like self.fc,
        # - it is saved in state_dict.
        # Previously a fresh LayerNorm was built on every forward call, so those
        # parameters were always weight=1 / bias=0 and never trained.
        # NOTE(review): this adds "layer_norm.weight" / "layer_norm.bias" to
        # state_dict. Checkpoints saved before this change need
        # load_state_dict(..., strict=False).
        self.layer_norm = nn.LayerNorm(config.input_dim)

    def forward(self, inputs):
        """Apply the feed-forward block, add the residual, then normalise.

        Args:
            inputs: tensor whose last dimension is ``config.input_dim``. The
                original author annotated it as [batch_size, seq_len, d_model].

        Returns:
            A tensor with the same shape as ``inputs``.
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)