diff --git a/utils/activations.py b/utils/activations.py index cf226fe..da3a4c3 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -61,3 +61,16 @@ class Mish(nn.Module): # https://github.com/digantamisra98/Mish @staticmethod def forward(x): return x * F.softplus(x).tanh() + + +# FReLU https://arxiv.org/abs/2007.11824 -------------------------------------- +class FReLU(nn.Module): + def __init__(self, c1, k=3): # ch_in, kernel + super().__init()__() + self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1) + self.bn = nn.BatchNorm2d(c1) + + @staticmethod + def forward(self, x): + return torch.max(x, self.bn(self.conv(x))) +