diff --git a/models/experimental.py b/models/experimental.py index 539e7f9..146a61b 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -107,3 +107,15 @@ class MixConv2d(nn.Module): def forward(self, x): return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) + + +class Ensemble(nn.ModuleList): + # Ensemble of models + def __init__(self): + super(Ensemble, self).__init__() + + def forward(self, x, augment=False): + y = [] + for module in self: + y.append(module(x, augment)[0]) + return torch.cat(y, 1), None # ensembled inference output, train output