update experimental.py with Ensemble() module

pull/1/head
Glenn Jocher 5 years ago
parent 38f5c1ad1d
commit 5ba1de0cdc

@ -107,3 +107,15 @@ class MixConv2d(nn.Module):
def forward(self, x): def forward(self, x):
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
class Ensemble(nn.ModuleList):
# Ensemble of models
def __init__(self):
super(Ensemble, self).__init__()
def forward(self, x, augment=False):
y = []
for module in self:
y.append(module(x, augment)[0])
return torch.cat(y, 1), None # ensembled inference output, train output

Loading…
Cancel
Save