From a97c3f94ecb4bf818cdd2582c62de4f964ed76ee Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 16 Jul 2020 23:59:51 -0700 Subject: [PATCH] update common.py Classify() --- models/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index a23d73e..7a7272b 100644 --- a/models/common.py +++ b/models/common.py @@ -112,4 +112,5 @@ class Classify(nn.Module): self.flat = Flatten() def forward(self, x): - return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2) + z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list + return self.flat(self.conv(z)) # flatten to x(b,c2)