|
|
@ -107,3 +107,15 @@ class MixConv2d(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
|
|
|
|
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Ensemble(nn.ModuleList):
|
|
|
|
|
|
|
|
# Ensemble of models
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
|
|
super(Ensemble, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, augment=False):
|
|
|
|
|
|
|
|
y = []
|
|
|
|
|
|
|
|
for module in self:
|
|
|
|
|
|
|
|
y.append(module(x, augment)[0])
|
|
|
|
|
|
|
|
return torch.cat(y, 1), None # ensembled inference output, train output
|
|
|
|