diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py index c126f894..569995d8 100644 --- a/torch_pruning/pruner/importance.py +++ b/torch_pruning/pruner/importance.py @@ -97,6 +97,10 @@ def _normalize(self, group_importance, normalizer): return group_importance / group_importance.max() elif normalizer == 'gaussian': return (group_importance - group_importance.mean()) / (group_importance.std()+1e-8) + elif normalizer.startswith('sentinel'): # normalize the score with the k-th smallest element. e.g. sentinel_0.5 means median normalization + sentinel = float(normalizer.split('_')[1]) * len(group_importance) + sentinel = torch.argsort(group_importance, dim=0, descending=False)[int(sentinel)] + return group_importance / (group_importance[sentinel]+1e-8) elif normalizer=='lamp': return self._lamp(group_importance) else: