pruning and sparsity initial commit

pull/1/head
Glenn Jocher 5 years ago
parent 997ba7b346
commit 38f5c1ad1d

@ -48,7 +48,7 @@ class Model(nn.Module):
if type(model_cfg) is dict:
self.md = model_cfg # model dict
else: # is *.yaml
import yaml
import yaml # for torch hub
with open(model_cfg) as f:
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict

@ -76,6 +76,26 @@ def find_modules(model, mclass=nn.Conv2d):
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
def sparsity(model):
# Return global model sparsity
a, b = 0., 0.
for p in model.parameters():
a += p.numel()
b += (p == 0).sum()
return b / a
def prune(model, amount=0.3):
# Prune model to requested global sparsity
import torch.nn.utils.prune as prune
print('Pruning model... ', end='')
for name, m in model.named_modules():
if isinstance(m, torch.nn.Conv2d):
prune.l1_unstructured(m, name='weight', amount=amount) # prune
prune.remove(m, 'weight') # make permanent
print(' %.3g global sparsity' % sparsity(model))
def fuse_conv_and_bn(conv, bn):
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
with torch.no_grad():

Loading…
Cancel
Save