from torch import nn

import config


class PoswiseFeedForwardNet(nn.Module):
    """Feed-forward sub-layer followed by a residual add and layer normalization.

    The network is ``Linear(input_dim -> d_ff1) -> ReLU -> Linear(d_ff1 ->
    input_dim)``; layer sizes and the ``bias`` flag are read from the
    project ``config`` module when the instance is constructed.

    NOTE: the normalization has no learnable scale/shift. The original code
    built a brand-new ``nn.LayerNorm`` (weight = 1, bias = 0) on every call and
    threw it away, so its affine parameters were never trained or saved. That
    behavior is kept on purpose so existing checkpoints still load; if a
    learnable LayerNorm is wanted, register ``nn.LayerNorm(config.input_dim)``
    in ``__init__`` instead (this adds new ``state_dict`` keys).
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(config.input_dim, config.d_ff1, bias=config.bias),
            nn.ReLU(),
            nn.Linear(config.d_ff1, config.input_dim, bias=config.bias),
        )

    def forward(self, inputs):
        """Apply the feed-forward block, add the residual, and normalize.

        Args:
            inputs: tensor whose last dimension is ``config.input_dim``
                (the original comment gives the shape as
                ``[batch_size, seq_len, d_model]``).

        Returns:
            A tensor with the same shape as ``inputs``, normalized over the
            last dimension, on the same device as ``inputs``.
        """
        residual = inputs
        output = self.fc(inputs)

        # Functional layer norm is numerically identical to a freshly
        # constructed nn.LayerNorm (identity affine, default eps) but runs on
        # whatever device the input lives on, instead of forcing 'cuda:0' and
        # allocating a new module on every forward pass.
        return nn.functional.layer_norm(output + residual, (config.input_dim,))