|
|
|
@ -142,14 +142,14 @@ class Model(nn.Module):
|
|
|
|
|
# print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
|
|
|
|
|
|
|
|
|
|
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
|
|
|
|
print('Fusing layers...')
|
|
|
|
|
print('Fusing layers... ', end='')
|
|
|
|
|
for m in self.model.modules():
|
|
|
|
|
if type(m) is Conv:
|
|
|
|
|
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
|
|
|
|
|
m.bn = None # remove batchnorm
|
|
|
|
|
m.forward = m.fuseforward # update forward
|
|
|
|
|
torch_utils.model_info(self)
|
|
|
|
|
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def parse_model(md, ch): # model_dict, input_channels(3)
|
|
|
|
|
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
|
|
|
|
|