import torch
import torch.nn as nn

from typing import Type


class Adapter(nn.Module):
    """Bottleneck adapter: down-project -> activation -> up-project,
    with an optional residual (skip) connection."""

    def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True):
        super().__init__()
        self.skip_connect = skip_connect
        D_hidden_features = int(D_features * mlp_ratio)  # bottleneck width
        self.act = act_layer()
        self.D_fc1 = nn.Linear(D_features, D_hidden_features)  # down-projection
        self.D_fc2 = nn.Linear(D_hidden_features, D_features)  # up-projection

    def forward(self, x):
        # x is (BT, HW+1, D)
        xs = self.D_fc1(x)
        xs = self.act(xs)
        xs = self.D_fc2(xs)
        if self.skip_connect:
            x = x + xs
        else:
            x = xs
        return x


class MLPBlock(nn.Module):
    """Standard two-layer transformer MLP: expand to mlp_dim, activate, project back."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each spatial position across channels, then apply the
        # learned per-channel scale and shift.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
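

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module): exercises the
# three blocks above with random tensors. The shapes are illustrative
# assumptions only, e.g. a ViT-style token sequence with HW+1 = 197 tokens and
# D = 768 channels, and an 8x8 NCHW feature map for LayerNorm2d.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokens = torch.randn(2, 197, 768)  # (BT, HW+1, D), assumed example shape

    adapter = Adapter(D_features=768)  # bottleneck width = int(768 * 0.25) = 192
    assert adapter(tokens).shape == tokens.shape  # residual path preserves shape

    mlp = MLPBlock(embedding_dim=768, mlp_dim=3072)
    assert mlp(tokens).shape == tokens.shape

    fmap = torch.randn(2, 256, 8, 8)  # (N, C, H, W), assumed example shape
    ln2d = LayerNorm2d(num_channels=256)
    assert ln2d(fmap).shape == fmap.shape

    print("all shape checks passed")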