From 38f5c1ad1d0f4d391544e302498eb81ff234e175 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 5 Jul 2020 13:41:21 -0700 Subject: [PATCH] pruning and sparsity initial commit --- models/yolo.py | 2 +- utils/torch_utils.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/models/yolo.py b/models/yolo.py index b2d09cc..9617f5b 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -48,7 +48,7 @@ class Model(nn.Module): if type(model_cfg) is dict: self.md = model_cfg # model dict else: # is *.yaml - import yaml + import yaml # for torch hub with open(model_cfg) as f: self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict diff --git a/utils/torch_utils.py b/utils/torch_utils.py index fd00b8b..35ef011 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -76,6 +76,26 @@ def find_modules(model, mclass=nn.Conv2d): return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)] +def sparsity(model): + # Return global model sparsity + a, b = 0., 0. + for p in model.parameters(): + a += p.numel() + b += (p == 0).sum() + return b / a + + +def prune(model, amount=0.3): + # Prune model to requested global sparsity + import torch.nn.utils.prune as prune + print('Pruning model... ', end='') + for name, m in model.named_modules(): + if isinstance(m, torch.nn.Conv2d): + prune.l1_unstructured(m, name='weight', amount=amount) # prune + prune.remove(m, 'weight') # make permanent + print(' %.3g global sparsity' % sparsity(model)) + + def fuse_conv_and_bn(conv, bn): # https://tehnokv.com/posts/fusing-batchnorm-and-conv/ with torch.no_grad():