|
|
@ -76,6 +76,26 @@ def find_modules(model, mclass=nn.Conv2d):
|
|
|
|
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
|
|
|
|
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sparsity(model):
|
|
|
|
|
|
|
|
# Return global model sparsity
|
|
|
|
|
|
|
|
a, b = 0., 0.
|
|
|
|
|
|
|
|
for p in model.parameters():
|
|
|
|
|
|
|
|
a += p.numel()
|
|
|
|
|
|
|
|
b += (p == 0).sum()
|
|
|
|
|
|
|
|
return b / a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prune(model, amount=0.3):
|
|
|
|
|
|
|
|
# Prune model to requested global sparsity
|
|
|
|
|
|
|
|
import torch.nn.utils.prune as prune
|
|
|
|
|
|
|
|
print('Pruning model... ', end='')
|
|
|
|
|
|
|
|
for name, m in model.named_modules():
|
|
|
|
|
|
|
|
if isinstance(m, torch.nn.Conv2d):
|
|
|
|
|
|
|
|
prune.l1_unstructured(m, name='weight', amount=amount) # prune
|
|
|
|
|
|
|
|
prune.remove(m, 'weight') # make permanent
|
|
|
|
|
|
|
|
print(' %.3g global sparsity' % sparsity(model))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fuse_conv_and_bn(conv, bn):
|
|
|
|
def fuse_conv_and_bn(conv, bn):
|
|
|
|
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
|
|
|
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
|
|
|
with torch.no_grad():
|
|
|
|
with torch.no_grad():
|
|
|
|