forked from hnu202409060624/python
parent
c1370d0e81
commit
9ef2e7ccde
@ -1,16 +0,0 @@
|
|||||||
from torch import nn
|
|
||||||
import config
|
|
||||||
class PoswiseFeedForwardNet(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(PoswiseFeedForwardNet, self).__init__()
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Linear(config.input_dim, config.d_ff1, bias=config.bias),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(config.d_ff1, config.input_dim, bias=config.bias))
|
|
||||||
|
|
||||||
def forward(self, inputs): # inputs: [batch_size, seq_len, d_model]
|
|
||||||
residual = inputs
|
|
||||||
|
|
||||||
output = self.fc(inputs)
|
|
||||||
|
|
||||||
return nn.LayerNorm(config.input_dim).to('cuda:0')(output + residual) # [batch_size, seq_len, d_model]
|
|
Loading…
Reference in new issue