You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

17 lines
623 B

from torch import nn
import config
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward sub-layer with residual connection and LayerNorm.

    Computes ``LayerNorm(x + Linear2(ReLU(Linear1(x))))``. Both linear layers
    and the norm act on the last dimension of the input, so any leading
    dimensions (e.g. batch and sequence) pass through unchanged.

    By default the sizes come from the project ``config`` module
    (``config.input_dim``, ``config.d_ff1``, ``config.bias``); each can be
    overridden per instance with the keyword-only constructor arguments.

    NOTE: the LayerNorm is a registered submodule (``layer_norm``), so its
    affine parameters are trained, saved in ``state_dict`` and moved by
    ``model.to(device)`` together with ``fc``. Checkpoints written by the
    earlier version (which rebuilt the norm inside ``forward``) lack the
    ``layer_norm.weight`` / ``layer_norm.bias`` keys; load those with
    ``load_state_dict(..., strict=False)``.
    """

    def __init__(self, *, input_dim=None, d_ff=None, bias=None):
        """Build the two-layer MLP and the output LayerNorm.

        Args:
            input_dim: Size of the last input/output dimension. Defaults to
                ``config.input_dim``.
            d_ff: Hidden width of the MLP. Defaults to ``config.d_ff1``.
            bias: Whether the linear layers carry a bias term. Defaults to
                ``config.bias``.
        """
        super().__init__()
        # Consult the global config only for values the caller did not pass
        # (``is None`` so that an explicit ``bias=False`` is respected).
        if input_dim is None:
            input_dim = config.input_dim
        if d_ff is None:
            d_ff = config.d_ff1
        if bias is None:
            bias = config.bias
        self.fc = nn.Sequential(
            nn.Linear(input_dim, d_ff, bias=bias),
            nn.ReLU(),
            nn.Linear(d_ff, input_dim, bias=bias),
        )
        # Built once here rather than in forward(): a LayerNorm created per
        # call has unregistered, freshly initialised (ones/zeros) parameters
        # that never learn, and costs an allocation + device copy every pass.
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, inputs):
        """Apply the MLP, add the residual, then normalise.

        Args:
            inputs: Tensor whose last dimension is ``input_dim``, e.g.
                ``[batch_size, seq_len, input_dim]``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)