diff --git a/models/common.py b/models/common.py index 2c2d600..a23d73e 100644 --- a/models/common.py +++ b/models/common.py @@ -76,12 +76,6 @@ class SPP(nn.Module): return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) -class Flatten(nn.Module): - # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions - def forward(self, x): - return x.view(x.size(0), -1) - - class Focus(nn.Module): # Focus wh information into c-space def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups @@ -100,3 +94,22 @@ class Concat(nn.Module): def forward(self, x): return torch.cat(x, self.d) + + +class Flatten(nn.Module): + # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions + @staticmethod + def forward(x): + return x.view(x.size(0), -1) + + +class Classify(nn.Module): + # Classification head, i.e. x(b,c1,20,20) to x(b,c2) + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups + super(Classify, self).__init__() + self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1) + self.flat = Flatten() + + def forward(self, x): + return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2)