diff --git a/models/yolo.py b/models/yolo.py
index b2d09cc..9617f5b 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -48,7 +48,7 @@ class Model(nn.Module):
         if type(model_cfg) is dict:
             self.md = model_cfg  # model dict
         else:  # is *.yaml
-            import yaml
+            import yaml  # for torch hub
             with open(model_cfg) as f:
                 self.md = yaml.load(f, Loader=yaml.FullLoader)  # model dict
 
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index fd00b8b..35ef011 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -76,6 +76,26 @@ def find_modules(model, mclass=nn.Conv2d):
     return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
 
 
+def sparsity(model):
+    # Return global model sparsity
+    a, b = 0., 0.
+    for p in model.parameters():
+        a += p.numel()
+        b += (p == 0).sum()
+    return b / a
+
+
+def prune(model, amount=0.3):
+    # Prune model to requested global sparsity
+    import torch.nn.utils.prune as prune
+    print('Pruning model... ', end='')
+    for name, m in model.named_modules():
+        if isinstance(m, torch.nn.Conv2d):
+            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
+            prune.remove(m, 'weight')  # make permanent
+    print(' %.3g global sparsity' % sparsity(model))
+
+
 def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
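
Usage sketch (not part of the commit): a minimal example of exercising the new helpers after applying the patch. The checkpoint path, the 'model' checkpoint key, and the 0.3 sparsity target are assumptions for illustration; only prune() and sparsity() come from the diff above.

    # prune_example.py -- illustrative only, assumes a YOLOv5-style checkpoint dict
    import torch
    from utils.torch_utils import prune, sparsity

    ckpt = torch.load('weights/yolov5s.pt', map_location='cpu')  # assumed checkpoint path
    model = ckpt['model'].float()                                # assumed nn.Module stored under 'model'

    print('sparsity before: %.3g' % sparsity(model))  # fraction of parameters that are exactly zero
    prune(model, amount=0.3)                           # zeroes 30% of each Conv2d's weights (L1 unstructured) and prints global sparsity

Note that prune() makes the pruning permanent via prune.remove(), so the zeroed weights are baked into the tensors rather than kept behind a mask.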