diff --git a/README.md b/README.md
index de5f2718..d7fa4dda 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ For more technical details, please refer to our CVPR'23 paper:
Please do not hesitate to open an [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper.
Or Join our Discord or WeChat group for a chat:
* Discord: [link](https://discord.gg/Pvd6hbYXRs)
- * WeChat Group (Group size exceeded 400): [QR Code](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3)
+ * WeChat Group [Group 1 (500/500, FULL)](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3), [Group-2](https://github.com/VainF/Torch-Pruning/assets/18592211/4e5f98e9-86b6-46bd-9e9f-3275c5ccc2f4)
## Table of Contents
- [Installation](#installation)
@@ -424,7 +424,7 @@ Please refer to [benchmarks](benchmarks) for more details.
> **DeepCache: Accelerating Diffusion Models for Free** [[Project]](https://github.com/horseee/DeepCache) [[Arxiv]](https://arxiv.org/abs/2312.00858)
> *Xinyin Ma, Gongfan Fang, and Xinchao Wang*
-> Preprint 2023
+> CVPR 2024
> **0.1% Data Makes Segment Anything Slim** [[Project]](https://github.com/czg1225/SlimSAM) [[Arxiv]](https://arxiv.org/abs/2312.05284)
> *Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang*
diff --git a/examples/sparse_training/benchmark_importance_criteria.py b/examples/sparse_training/benchmark_importance_criteria.py
index c26b9e8d..c20cacfd 100644
--- a/examples/sparse_training/benchmark_importance_criteria.py
+++ b/examples/sparse_training/benchmark_importance_criteria.py
@@ -50,8 +50,8 @@ def validate_model(model, val_loader):
# Importance criteria
imp_dict = {
- 'Group Hessian': tp.importance.OBDImportance(group_reduction='mean'),
- 'Single-layer Hessian': tp.importance.OBDImportance(group_reduction='first'),
+ 'Group OBD': tp.importance.OBDImportance(group_reduction='mean'),
+ 'Single-layer OBD': tp.importance.OBDImportance(group_reduction='first'),
'Group Taylor': tp.importance.TaylorImportance(group_reduction='mean'),
'Single-layer Taylor': tp.importance.TaylorImportance(group_reduction='first'),
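*(Editor's note)* For context on how one of the criteria in this dictionary is typically driven, here is a minimal, hedged sketch: the model, ratio, and the `pruning_ratio` keyword are illustrative only (older Torch-Pruning releases spell it `ch_sparsity`), and the gradient-based criteria such as OBD/Taylor need a backward pass before `pruner.step()`.

```python
# Hedged sketch: driving a gradient-based criterion ("Group OBD" above) with a pruner.
import torch
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(num_classes=100)
example_inputs = torch.randn(1, 3, 224, 224)

imp = tp.importance.OBDImportance(group_reduction='mean')
pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=imp,
    pruning_ratio=0.5,          # assumed kwarg; some versions use ch_sparsity
    ignored_layers=[model.fc],  # never prune the classifier head
)

# OBD/Taylor criteria read gradients, so run a backward pass before step().
out = model(example_inputs)
F.cross_entropy(out, torch.randint(0, 100, (1,))).backward()
pruner.step()
```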
diff --git a/examples/sparse_training/main.py b/examples/sparse_training/main.py
index e448baf6..092728b8 100644
--- a/examples/sparse_training/main.py
+++ b/examples/sparse_training/main.py
@@ -18,6 +18,7 @@
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--dataset", type=str, default="cifar100", choices=['cifar10', 'cifar100', 'modelnet40'])
+parser.add_argument('--dataroot', default='data', help='path to your datasets')
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--total-epochs", type=int, default=100)
parser.add_argument("--lr-decay-milestones", default="60,80", type=str, help="milestones for learning rate decay")
@@ -25,6 +26,7 @@
parser.add_argument("--lr", default=0.01, type=float, help="learning rate")
parser.add_argument("--restore", type=str, default=None)
parser.add_argument('--output-dir', default='run', help='path where to save')
+parser.add_argument("--finetune", action="store_true", default=False, help='whether finetune or not')
# For pruning
parser.add_argument("--method", type=str, default=None)
@@ -46,17 +48,34 @@
args = parser.parse_args()
-def progressive_pruning(pruner, model, speed_up, example_inputs):
+def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
model.eval()
base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
current_speed_up = 1
while current_speed_up < speed_up:
- pruner.step(interactive=False)
+        if args.method == "obdc":
+            model.zero_grad()
+            imp = pruner.importance
+            imp._prepare_model(model, pruner)
+            # Estimate the Fisher information over a few mini-batches. Labels are
+            # sampled from the model's own softmax so that the accumulated squared
+            # gradients approximate the true (rather than empirical) Fisher.
+            for k, (imgs, lbls) in enumerate(train_loader):
+                if k >= 10: break
+                imgs = imgs.to(args.device)
+                lbls = lbls.to(args.device)
+                output = model(imgs)
+                sampled_y = torch.multinomial(torch.nn.functional.softmax(output.cpu().data, dim=1),
+                                              1).squeeze().to(args.device)
+                loss_sample = F.cross_entropy(output, sampled_y)
+                loss_sample.backward()
+                imp.step()
+            pruner.step()
+            imp._rm_hooks(model)
+            imp._clear_buffer()
+ else:
+ pruner.step()
pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
current_speed_up = float(base_ops) / pruned_ops
if pruner.current_step == pruner.iterative_steps:
break
- #print(current_speed_up)
return current_speed_up
def eval(model, test_loader, device=None):
@@ -169,9 +188,18 @@ def get_pruner(model, example_inputs):
elif args.method == "l1":
imp = tp.importance.MagnitudeImportance(p=1)
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
+ elif args.method == "l2":
+ imp = tp.importance.MagnitudeImportance(p=2)
+ pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
+ elif args.method == "fpgm":
+ imp = tp.importance.FPGMImportance(p=2)
+ pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
+ elif args.method == "obdc":
+ imp = tp.importance.OBDCImportance(group_reduction='mean', num_classes=args.num_classes)
+ pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "lamp":
imp = tp.importance.LAMPImportance(p=2)
- pruner_entry = partial(tp.pruner.BNScalePruner, global_pruning=args.global_pruning)
+ pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "slim":
args.sparsity_learning = True
imp = tp.importance.BNScaleImportance()
@@ -242,7 +270,7 @@ def main():
# Model & Dataset
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes, train_dst, val_dst, input_size = registry.get_dataset(
- args.dataset, data_root="data"
+ args.dataset, data_root=args.dataroot
)
args.num_classes = num_classes
model = registry.get_model(args.model, num_classes=num_classes, pretrained=True, target_dataset=args.dataset)
@@ -315,7 +343,7 @@ def main():
ori_ops, ori_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
ori_acc, ori_val_loss = eval(model, test_loader, device=args.device)
args.logger.info("Pruning...")
- progressive_pruning(pruner, model, speed_up=args.speed_up, example_inputs=example_inputs)
+ progressive_pruning(pruner, model, speed_up=args.speed_up, example_inputs=example_inputs, train_loader=train_loader)
del pruner # remove reference
args.logger.info(model)
pruned_ops, pruned_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
@@ -340,17 +368,18 @@ def main():
)
# 2. Finetuning
- args.logger.info("Finetuning...")
- train_model(
- model,
- epochs=args.total_epochs,
- lr=args.lr,
- lr_decay_milestones=args.lr_decay_milestones,
- train_loader=train_loader,
- test_loader=test_loader,
- device=args.device,
- save_state_dict_only=False,
- )
+ if args.finetune:
+ args.logger.info("Finetuning...")
+ train_model(
+ model,
+ epochs=args.total_epochs,
+ lr=args.lr,
+ lr_decay_milestones=args.lr_decay_milestones,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ device=args.device,
+ save_state_dict_only=False,
+ )
elif args.mode == "test":
model.eval()
ops, params = tp.utils.count_ops_and_params(
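*(Editor's note)* The `obdc` branch added above follows a fixed protocol: `_prepare_model` installs forward/backward hooks on prunable layers, a handful of mini-batches are run with labels sampled from the model's own softmax (so the accumulated squared gradients estimate the true Fisher information), `imp.step()` folds each batch into the Fisher buffers, and the hooks and buffers are torn down after `pruner.step()`. A condensed sketch of that loop is below; `model`, `train_loader`, and `device` are placeholders, and the pruner is assumed to hold a `tp.importance.OBDCImportance`.

```python
# Hedged sketch of the OBDC scoring protocol used in progressive_pruning above.
import torch
import torch.nn.functional as F

def obdc_step(pruner, model, train_loader, device, num_batches=10):
    imp = pruner.importance              # assumed to be tp.importance.OBDCImportance
    model.zero_grad()
    imp._prepare_model(model, pruner)    # register forward/backward hooks
    for k, (imgs, _) in enumerate(train_loader):
        if k >= num_batches:
            break
        imgs = imgs.to(device)
        logits = model(imgs)
        # Sample labels from the model's own predictive distribution (true Fisher).
        sampled_y = torch.multinomial(F.softmax(logits.detach(), dim=1), 1).squeeze(1)
        F.cross_entropy(logits, sampled_y).backward()
        imp.step()                       # accumulate squared per-sample gradients
    pruner.step()                        # prune using the accumulated Fisher
    imp._rm_hooks(model)
    imp._clear_buffer()
```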
diff --git a/examples/sparse_training/tools/draw.py b/examples/sparse_training/tools/draw.py
index 317705c1..aec999d8 100644
--- a/examples/sparse_training/tools/draw.py
+++ b/examples/sparse_training/tools/draw.py
@@ -7,8 +7,8 @@
plt.style.use('bmh')
color_dict = {
- 'Group Hessian': "C0",
- 'Single-layer Hessian': "C0",
+ 'Group OBD': "C0",
+ 'Single-layer OBD': "C0",
'Random': "C1",
diff --git a/examples/transformers/draw_acc_curve.py b/examples/transformers/draw_acc_curve.py
index f5ec6dff..67c9cf86 100644
--- a/examples/transformers/draw_acc_curve.py
+++ b/examples/transformers/draw_acc_curve.py
@@ -13,7 +13,7 @@ def parse_acc_from_file(file_path):
return acc
log_dict = {
- 'Hessian-uniform': 'output/vit_b_16_pruning_hessian_uniform/train.log',
+ 'OBD-uniform': 'output/vit_b_16_pruning_OBD_uniform/train.log',
'Taylor-uniform': 'output/vit_b_16_pruning_taylor_uniform/train.log',
'Taylor-bottleneck': 'output/vit_b_16_pruning_taylor_bottleneck/train.log',
'L1-uniform': 'output/vit_b_16_pruning_l1_uniform/train.log',
diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py
index 364a4ce1..6b062c90 100644
--- a/examples/transformers/prune_timm_vit.py
+++ b/examples/transformers/prune_timm_vit.py
@@ -20,7 +20,7 @@ def parse_args():
parser.add_argument('--taylor_batchs', default=10, type=int, help='number of batchs for taylor criterion')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='prune ratio')
parser.add_argument('--bottleneck', default=False, action='store_true', help='bottleneck or uniform')
- parser.add_argument('--pruning_type', default='l1', type=str, help='pruning type', choices=['random', 'taylor', 'l2', 'l1', 'hessian'])
+ parser.add_argument('--pruning_type', default='l1', type=str, help='pruning type', choices=['random', 'taylor', 'l2', 'l1', 'OBD'])
parser.add_argument('--test_accuracy', default=False, action='store_true', help='test accuracy')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--prune_num_heads', default=False, action='store_true', help='global pruning')
@@ -111,11 +111,11 @@ def main():
imp = tp.importance.GroupNormImportance(p=2)
elif args.pruning_type == 'l1':
imp = tp.importance.GroupNormImportance(p=1)
- elif args.pruning_type == 'hessian':
+ elif args.pruning_type == 'OBD':
imp = tp.importance.GroupOBDImportance()
else: raise NotImplementedError
- if args.pruning_type in ['taylor', 'hessian'] or args.test_accuracy:
+ if args.pruning_type in ['taylor', 'OBD'] or args.test_accuracy:
train_loader, val_loader = prepare_imagenet(args.data_path, train_batch_size=args.train_batch_size, val_batch_size=args.val_batch_size, use_imagenet_mean_std=args.use_imagenet_mean_std)
# Load the model
diff --git a/examples/transformers/readme.md b/examples/transformers/readme.md
index d5e517df..6e098520 100644
--- a/examples/transformers/readme.md
+++ b/examples/transformers/readme.md
@@ -43,13 +43,13 @@ bash scripts/finetune_timm_vit_b_16_taylor_uniform.sh
```
Pruning results for ImageNet-21K-ft-1K (Timm):
-| | ViT-B/16 (Timm) | ViT_B/32 (Timm) | Group L2 (Uniform) | Group Taylor (Uniform) | Group Taylor (Bottleneck) | Group Hessian (Uniform) |
+| | ViT-B/16 (Timm) | ViT_B/32 (Timm) | Group L2 (Uniform) | Group Taylor (Uniform) | Group Taylor (Bottleneck) | Group OBD (Uniform) |
| :-- | :--: | :--: | :--: | :--: | :--: | :--: |
| **#Params** | 86.57 M | 88.22 M | 22.05 M | 22.05 M | 24.83 M | 22.05 M |
| **MACs** | 17.59 G | 4.41 G | 4.61 G | 4.61 G | 4.62 G | 4.61 G |
| **Acc @ Epoch 300** | 85.21 | 80.68 | 78.11 | 80.19 | 80.06 | 80.15 |
| **Latency (Bs=1, A5000)** | 5.21 ms +- 0.05 ms | 3.87 ms +- 0.05 ms | 3.99 ms +- 0.10 ms | 3.99 ms +- 0.10 ms | 3.87 ms +- 0.14 ms | 3.99 ms +- 0.10 ms |
-| **Checkpoints** | - | - | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_l2_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_bottleneck.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_hessian_uniform.pth) |
+| **Checkpoints** | - | - | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_l2_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_bottleneck.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_OBD_uniform.pth) |
*Notes:*
* Uniform - We apply the same pruning ratio to all layers.
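*(Editor's note)* As a quick check on the table above, the reported compression can be recomputed from the listed numbers (an illustrative snippet, not part of the repo):

```python
# Compression implied by the table: ViT-B/16 (Timm) vs. the uniformly pruned models.
base_params, pruned_params = 86.57, 22.05   # millions of parameters
base_macs, pruned_macs = 17.59, 4.61        # GMACs
print(f"params kept:   {pruned_params / base_params:.1%}")  # ~25.5%
print(f"MACs speed-up: {base_macs / pruned_macs:.2f}x")      # ~3.82x
```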
diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh
index 50012b2c..e210f1cd 100644
--- a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh
+++ b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh
@@ -1,5 +1,5 @@
torchrun --nproc_per_node=8 finetune.py \
- --model "output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth" \
+ --model "output/pruned/vit_base_patch16_224_pruned_OBD_uniform.pth" \
--epochs 300 \
--batch-size 256 \
--opt adamw \
@@ -17,4 +17,4 @@ torchrun --nproc_per_node=8 finetune.py \
--ra-sampler \
--cutmix-alpha 1.0 \
--data-path "data/imagenet" \
- --output-dir output/vit_b_16_pruning_hessian_uniform
\ No newline at end of file
+ --output-dir output/vit_b_16_pruning_OBD_uniform
\ No newline at end of file
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh
index 39e35671..bf3c7e55 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh
@@ -1,10 +1,10 @@
python prune_timm_vit.py \
--model_name vit_base_patch16_224 \
- --pruning_type hessian \
+ --pruning_type OBD \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
--test_accuracy \
--data_path data/imagenet \
--train_batch_size 64 \
--val_batch_size 64 \
- --save_as output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth \
\ No newline at end of file
+ --save_as output/pruned/vit_base_patch16_224_pruned_OBD_uniform.pth \
\ No newline at end of file
diff --git a/tests/test_hessian_importance.py b/tests/test_hessian_importance.py
index 2037ba83..b25887f8 100644
--- a/tests/test_hessian_importance.py
+++ b/tests/test_hessian_importance.py
@@ -2,7 +2,7 @@
from torchvision.models import resnet18
import torch_pruning as tp
-def test_hessian():
+def test_OBD():
model = resnet18(pretrained=True)
# Importance criteria
@@ -52,4 +52,4 @@ def test_hessian():
# ...
if __name__=="__main__":
- test_hessian()
\ No newline at end of file
+ test_OBD()
\ No newline at end of file
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index d7703666..b9f48788 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -18,8 +18,15 @@ def test_pruner():
[tp.importance.GroupNormImportance, tp.pruner.GroupNormPruner],
[tp.importance.BNScaleImportance, tp.pruner.BNScalePruner],
[tp.importance.GroupNormImportance, tp.pruner.GrowingRegPruner],
+ [tp.importance.MagnitudeImportance, tp.pruner.GroupNormPruner],
+ [tp.importance.LAMPImportance, tp.pruner.GroupNormPruner],
+ [tp.importance.OBDCImportance, tp.pruner.GroupNormPruner],
+ [tp.importance.FPGMImportance, tp.pruner.GroupNormPruner],
]:
- imp = imp_cls()
+ if imp_cls == tp.importance.OBDCImportance:
+ imp = imp_cls(num_classes=1000)
+ else:
+ imp = imp_cls()
ignored_layer_outputs = []
# DO NOT prune the final classifier!
for m in model.modules():
@@ -37,7 +44,12 @@ def test_pruner():
)
for i in range(iterative_steps):
- model(example_inputs).sum().backward()
+ if isinstance(imp, tp.importance.OBDCImportance):
+ imp._prepare_model(model, pruner)
+ model(example_inputs).sum().backward()
+ imp.step()
+ else:
+ model(example_inputs).sum().backward()
grad_dict = {}
for p in model.parameters():
if p.grad is not None:
@@ -53,6 +65,9 @@ def test_pruner():
print(name, "has no grad")
for g in pruner.step(interactive=True):
g.prune()
+ if isinstance(imp, tp.importance.OBDCImportance):
+ imp._rm_hooks(model)
+ imp._clear_buffer()
if __name__ == "__main__":
diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py
index 2015ebad..9ea2335e 100644
--- a/torch_pruning/dependency.py
+++ b/torch_pruning/dependency.py
@@ -176,7 +176,7 @@ def prune(self, idxs=None, record_history=True):
self._DG._param_to_name[pruned_parameter] = name
self._DG.module2node[pruned_parameter] = self._DG.module2node.pop(old_parameter)
self._DG.module2node[pruned_parameter].module = pruned_parameter
- else: # prune nn.Module
+ else:
dep(idxs)
if record_history:
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index 4cf3143d..d36ecc3a 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -2,6 +2,8 @@
import torch.nn as nn
import typing, warnings
+from torch_pruning.pruner.importance import OBDCImportance
+
from .scheduler import linear_scheduler
from ..import function
from ... import ops, dependency
@@ -239,6 +241,8 @@ def step(self, interactive=False)-> typing.Union[typing.Generator, None]:
else:
for group in pruning_method():
group.prune()
def estimate_importance(self, group) -> torch.Tensor:
return self.importance(group)
@@ -413,6 +417,9 @@ def prune_local(self) -> typing.Generator:
if len(pruning_idxs)==0: continue
pruning_idxs = torch.unique( torch.cat(pruning_idxs, 0) ).tolist()
+ if isinstance(self.importance, OBDCImportance):
+ self.importance.adjust_fisher(group, pruning_idxs)
+
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_idxs)
@@ -542,6 +549,8 @@ def prune_global(self) -> typing.Generator:
if len(pruning_indices)==0: continue
pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()
+ if isinstance(self.importance, OBDCImportance):
+ self.importance.adjust_fisher(group, pruning_indices)
# create pruning group
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_indices)
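*(Editor's note)* The two `adjust_fisher` calls added above keep the OBDC Fisher buffers aligned with the shrinking layers: when channels are pruned, the corresponding blocks of `kernel_size` columns are dropped from each affected layer's Fisher matrix (see `OBDCImportance.adjust_fisher` in `importance.py` below). A tiny standalone illustration of that column bookkeeping, with made-up shapes:

```python
# Illustration of the Fisher column removal performed by adjust_fisher.
# Made-up shapes: 3 input channels, 2x2 kernel, so kernel_size = 4 columns per channel.
import torch

fisher = torch.arange(2 * 12, dtype=torch.float32).reshape(2, 12)  # (out_dim, in_c*k*k)
pruned_idxs, kernel_size = [1], 4                                   # drop channel 1

keep = [c for c in range(fisher.shape[1])
        if not any(i * kernel_size <= c < (i + 1) * kernel_size for i in pruned_idxs)]
fisher = torch.index_select(fisher, 1, torch.tensor(keep))          # columns 4..7 removed
print(fisher.shape)  # torch.Size([2, 8])
```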
diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py
index 2465e9c4..c7f8c5d4 100644
--- a/torch_pruning/pruner/importance.py
+++ b/torch_pruning/pruner/importance.py
@@ -5,6 +5,12 @@
import typing
from . import function
from ..dependency import Group
+from .._helpers import _FlattenIndexMapping
+from .. import ops
+import math
+import numpy as np
+from collections import OrderedDict
+from ..utils.compute_mat_grad import ComputeMatGrad
__all__ = [
# Base Class
@@ -31,7 +37,6 @@ class Importance(abc.ABC):
It should accept a group as inputs, and return a 1-D tensor with the same length as the number of channels.
All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups.
- Just ignore the ch_groups if you are not familar with grouping.
Example:
```python
@@ -90,13 +95,23 @@ def __init__(self,
self.target_types = target_types
self.bias = bias
- def _lamp(self, imp): # Layer-adaptive Sparsity for the Magnitude-based Pruning
- argsort_idx = torch.argsort(imp, dim=0, descending=True)
- sorted_imp = imp[argsort_idx.tolist()]
- cumsum_imp = torch.cumsum(sorted_imp, dim=0)
- sorted_imp = sorted_imp / cumsum_imp
- inversed_idx = torch.argsort(argsort_idx).tolist() # [0, 1, 2, 3, ..., ]
- return sorted_imp[inversed_idx]
+ def _lamp(self, scores): # Layer-adaptive Sparsity for the Magnitude-based Pruning
+ """
+ Normalizing scheme for LAMP.
+ """
+ # sort scores in an ascending order
+ sorted_scores,sorted_idx = scores.view(-1).sort(descending=False)
+ # compute cumulative sum
+ scores_cumsum_temp = sorted_scores.cumsum(dim=0)
+ scores_cumsum = torch.zeros(scores_cumsum_temp.shape,device=scores.device)
+ scores_cumsum[1:] = scores_cumsum_temp[:len(scores_cumsum_temp)-1]
+ # normalize by cumulative sum
+ sorted_scores /= (scores.sum() - scores_cumsum)
+ # tidy up and output
+ new_scores = torch.zeros(scores_cumsum.shape,device=scores.device)
+ new_scores[sorted_idx] = sorted_scores
+
+ return new_scores.view(scores.shape)
def _normalize(self, group_importance, normalizer):
if normalizer is None:
@@ -295,6 +310,76 @@ def __init__(self, p=2, group_reduction="mean", normalizer='lamp', bias=False):
assert normalizer == 'lamp'
super().__init__(p=p, group_reduction=group_reduction, normalizer=normalizer, bias=bias)
+
+class FPGMImportance(GroupNormImportance):
+ """Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
+ http://openaccess.thecvf.com/content_CVPR_2019/papers/He_Filter_Pruning_via_Geometric_Median_for_Deep_Convolutional_Neural_Networks_CVPR_2019_paper.pdf
+ """
+
+ def __init__(self, p=2, group_reduction="mean", normalizer='mean', bias=False):
+ super().__init__(p=p, group_reduction=group_reduction, normalizer=normalizer, bias=bias)
+
+ @torch.no_grad()
+ def __call__(self, group, **kwargs):
+ group_imp = []
+ group_idxs = []
+ # Iterate over all groups and estimate group importance
+ for i, (dep, idxs) in enumerate(group):
+ layer = dep.layer
+ prune_fn = dep.pruning_fn
+ root_idxs = group[i].root_idxs
+ if not isinstance(layer, tuple(self.target_types)):
+ continue
+ ####################
+ # Conv/Linear Output
+ ####################
+ if prune_fn in [
+ function.prune_conv_out_channels,
+ function.prune_linear_out_channels,
+ ]:
+ if hasattr(layer, "transposed") and layer.transposed:
+ w = layer.weight.data.transpose(1, 0)[idxs].flatten(1)
+ else:
+ w = layer.weight.data[idxs].flatten(1)
+ local_imp = w.abs().pow(self.p)
+ # calculate the euclidean distance as similarity
+ similar_matrix = torch.cdist(local_imp.unsqueeze(0), local_imp.unsqueeze(0), p=2).squeeze(0)
+ similar_sum = torch.sum(torch.abs(similar_matrix), dim=0)
+ group_imp.append(similar_sum)
+ group_idxs.append(root_idxs)
+
+ ####################
+ # Conv/Linear Input
+ ####################
+ elif prune_fn in [
+ function.prune_conv_in_channels,
+ function.prune_linear_in_channels,
+ ]:
+ if hasattr(layer, "transposed") and layer.transposed:
+ w = (layer.weight.data).flatten(1)
+ else:
+ w = (layer.weight.data).transpose(0, 1).flatten(1)
+
+ local_imp = w.abs().pow(self.p)
+
+ # repeat importance for group convolutions
+ if prune_fn == function.prune_conv_in_channels and layer.groups != layer.in_channels and layer.groups != 1:
+ local_imp = local_imp.repeat(layer.groups)
+ local_imp = local_imp[idxs]
+ similar_matrix = torch.cdist(local_imp.unsqueeze(0), local_imp.unsqueeze(0), p=2).squeeze(0)
+ similar_sum = torch.sum(torch.abs(similar_matrix), dim=0)
+ group_imp.append(similar_sum)
+ group_idxs.append(root_idxs)
+
+ # FPGMImportance should not care about BatchNorm and LayerNorm
+
+ if len(group_imp) == 0: # skip groups without parameterized layers
+ return None
+
+ group_imp = self._reduce(group_imp, group_idxs)
+ group_imp = self._normalize(group_imp, self.normalizer)
+ return group_imp
+
class RandomImportance(Importance):
""" Random importance estimator
Example:
@@ -450,6 +535,125 @@ def __call__(self, group):
group_imp = self._normalize(group_imp, self.normalizer)
return group_imp
+class OBDCImportance(GroupNormImportance):
+ """EigenDamage: Structured Pruning in the Kronecker-Factored Eigenbasis:
+ http://proceedings.mlr.press/v97/wang19g/wang19g.pdf
+ """
+ def __init__(self,
+ group_reduction:str="mean",
+ normalizer:str='mean',
+ bias=False,
+ target_types:list=[nn.modules.conv._ConvNd, nn.Linear],
+ num_classes=100):
+ self.group_reduction = group_reduction
+ self.normalizer = normalizer
+ self.target_types = target_types
+ self.bias = bias
+ self.A, self.DS = {}, {}
+ self.Fisher = {}
+ self.MatGradHandler = ComputeMatGrad()
+ self.steps = 0
+ self.eps = 1e-10
+ self.modules = []
+ self.num_classes = num_classes
+ self.known_modules = {'Linear', 'Conv2d'}
+
+ def step(self):
+ with torch.no_grad():
+ for m in self.modules:
+ A, DS = self.A[m], self.DS[m]
+ grad_mat = self.MatGradHandler(A, DS, m)
+ grad_mat *= DS.size(0)
+ if self.steps == 0:
+ self.Fisher[m] = grad_mat.new(grad_mat.size()[1:]).fill_(0)
+ self.Fisher[m] += (grad_mat.pow_(2)).sum(0)
+ self.A[m] = None
+ self.DS[m] = None
+ self.steps += 1
+
+ def adjust_fisher(self, group, idxs):
+ for i, (dep, id) in enumerate(group):
+ layer = dep.target.module
+ if layer in self.modules:
+ if layer.weight.grad is not None:
+ shape = layer.weight.shape
+ if isinstance(layer, nn.modules.conv._ConvNd):
+ kernel_size = shape[2]*shape[3]
+ else:
+ kernel_size = 1
+ indices_to_keep = list(range(self.Fisher[layer].shape[1]))
+ for idx in idxs:
+ indices_to_keep = [i for i in indices_to_keep if not (idx*kernel_size <= i < (idx+1)*kernel_size)]
+ self.Fisher[layer] = torch.index_select(self.Fisher[layer], 1, torch.LongTensor(indices_to_keep).to(self.Fisher[layer].device))
+
+
+    def _rm_hooks(self, model):
+        for h in self._hooks:
+            h.remove()
+        self._hooks = []
+
+ def _save_input(self, module, input):
+ self.A[module] = input[0].data
+
+ def _save_grad_output(self, module, grad_input, grad_output):
+ self.DS[module] = grad_output[0].data
+
+ def _prepare_model(self, model, pruner):
+ self._hooks = []
+ for group in pruner.DG.get_all_groups(ignored_layer_inputs=pruner.ignored_layer_inputs, ignored_layer_outputs=pruner.ignored_layer_outputs, target_layers=pruner.target_layers, target_layer_types=pruner.target_layer_types):
+ group = pruner._downstream_node_as_root_if_attention(group)
+ for i, (dep, idxs) in enumerate(group):
+ layer = dep.target.module
+ if isinstance(layer, tuple(self.target_types)) and dep.handler in [
+ function.prune_conv_out_channels,
+ function.prune_linear_out_channels,
+ ]:
+ self.modules.append(layer)
+ self._hooks.append(layer.register_forward_pre_hook(self._save_input))
+ self._hooks.append(layer.register_backward_hook(self._save_grad_output))
+
+ def _clear_buffer(self):
+ self.Fisher = {}
+ self.modules = []
+ self.steps = 0
+
+ @torch.no_grad()
+ def __call__(self, group):
+ group_imp = []
+ group_idxs = []
+ for i, (dep, idxs) in enumerate(group):
+ idxs.sort()
+ layer = dep.target.module
+ prune_fn = dep.handler
+ root_idxs = group[i].root_idxs
+ if not isinstance(layer, tuple(self.target_types)) or (isinstance(layer, torch.nn.Linear) and layer.out_features == self.num_classes):
+ continue
+ F_diag = (self.Fisher[layer] / self.steps + self.eps)
+ if prune_fn in [
+ function.prune_conv_out_channels,
+ function.prune_linear_out_channels,
+ ]:
+ if layer.weight.grad is not None:
+ if hasattr(layer, "transposed") and layer.transposed:
+ w = layer.weight.data.transpose(1, 0)[idxs].flatten(1)
+ else:
+ w = layer.weight.data[idxs].flatten(1)
+ local_imp = (w ** 2 * F_diag).sum(1)
+ group_imp.append(local_imp)
+ group_idxs.append(root_idxs)
+
+ if self.bias and layer.bias is not None and layer.bias.grad is not None:
+ b = layer.bias.data[idxs]
+ local_imp = (b ** 2 * F_diag).sum(1)
+ group_imp.append(local_imp)
+ group_idxs.append(root_idxs)
+
+ if len(group_imp) == 0: # skip groups without parameterized layers
+ return None
+ group_imp = self._reduce(group_imp, group_idxs)
+ group_imp = self._normalize(group_imp, self.normalizer)
+ return group_imp
+
class GroupOBDImportance(GroupNormImportance):
"""Grouped Optimal Brain Damage:
https://proceedings.neurips.cc/paper/1989/hash/6c9882bbac1c7093bd25041881277658-Abstract.html
@@ -615,4 +819,5 @@ class TaylorImportance(GroupTaylorImportance):
pass
class OBDImportance(GroupOBDImportance):
- pass
\ No newline at end of file
+ pass
+
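*(Editor's note)* Unlike OBDC, the new `FPGMImportance` is gradient-free, so it can be evaluated on a dependency group directly. A minimal sketch (not from the repo docs) scoring the output channels of the first conv of a torchvision ResNet-18:

```python
# Hedged sketch: scoring a single pruning group with the new FPGMImportance.
import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18()
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))

# Group that prunes conv1's output channels and everything coupled to them.
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels,
                             idxs=list(range(model.conv1.out_channels)))

imp = tp.importance.FPGMImportance(p=2)
scores = imp(group)   # one score per channel; filters closest to the geometric
print(scores.shape)   # median of their layer get the lowest scores
```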
diff --git a/torch_pruning/utils/compute_mat_grad.py b/torch_pruning/utils/compute_mat_grad.py
new file mode 100644
index 00000000..e89051ed
--- /dev/null
+++ b/torch_pruning/utils/compute_mat_grad.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def _extract_patches(x, kernel_size, stride, padding):
+ """
+ :param x: The input feature maps. (batch_size, in_c, h, w)
+ :param kernel_size: the kernel size of the conv filter (tuple of two elements)
+ :param stride: the stride of conv operation (tuple of two elements)
+ :param padding: number of paddings. be a tuple of two elements
+ :return: (batch_size, out_h, out_w, in_c*kh*kw)
+ """
+ if padding[0] + padding[1] > 0:
+ x = F.pad(x, (padding[1], padding[1], padding[0],
+ padding[0])).data # Actually check dims
+ x = x.unfold(2, kernel_size[0], stride[0])
+ x = x.unfold(3, kernel_size[1], stride[1])
+ x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
+ x = x.view(
+ x.size(0), x.size(1), x.size(2),
+ x.size(3) * x.size(4) * x.size(5))
+ return x
+
+
+
+def try_contiguous(x):
+ if not x.is_contiguous():
+ x = x.contiguous()
+
+ return x
+
+
+class ComputeMatGrad:
+
+ @classmethod
+ def __call__(cls, input, grad_output, layer):
+ if isinstance(layer, nn.Linear):
+ grad = cls.linear(input, grad_output, layer)
+ elif isinstance(layer, nn.Conv2d):
+ grad = cls.conv2d(input, grad_output, layer)
+ else:
+ raise NotImplementedError
+ return grad
+
+ @staticmethod
+ def linear(input, grad_output, layer):
+ """
+ :param input: batch_size * input_dim
+ :param grad_output: batch_size * output_dim
+ :param layer: [nn.module] output_dim * input_dim
+ :return: batch_size * output_dim * (input_dim + [1 if with bias])
+ """
+ with torch.no_grad():
+ if layer.bias is not None:
+ input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1)
+ input = input.unsqueeze(1)
+ grad_output = grad_output.unsqueeze(2)
+ grad = torch.bmm(grad_output, input)
+ return grad
+
+ @staticmethod
+ def conv2d(input, grad_output, layer):
+ """
+ :param input: batch_size * in_c * in_h * in_w
+ :param grad_output: batch_size * out_c * h * w
+ :param layer: nn.module batch_size * out_c * (in_c*k_h*k_w + [1 if with bias])
+ :return:
+ """
+ with torch.no_grad():
+ input = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding)
+ input = input.view(-1, input.size(-1)) # b * hw * in_c*kh*kw
+ grad_output = grad_output.transpose(1, 2).transpose(2, 3)
+ grad_output = try_contiguous(grad_output).view(grad_output.size(0), -1, grad_output.size(-1))
+ # b * hw * out_c
+ if layer.bias is not None:
+ input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1)
+ input = input.view(grad_output.size(0), -1, input.size(-1)) # b * hw * in_c*kh*kw
+ grad = torch.einsum('abm,abn->amn', (grad_output, input))
+ return grad
\ No newline at end of file
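*(Editor's note)* `ComputeMatGrad` reconstructs per-sample weight gradients (each sample's output gradient outer-multiplied with its input, plus a ones column for the bias), which is what `OBDCImportance.step` squares and accumulates into the Fisher estimate. A small self-check against per-sample autograd, assuming the new module is importable from the working tree:

```python
# Sanity check: ComputeMatGrad.linear vs. per-sample gradients from autograd.
import torch
import torch.nn as nn
from torch_pruning.utils.compute_mat_grad import ComputeMatGrad

torch.manual_seed(0)
layer = nn.Linear(4, 3, bias=True)
x = torch.randn(5, 4)

out = layer(x)
grad_out = torch.autograd.grad(out.sum(), out)[0]      # dL/d(out), all ones here

per_sample = ComputeMatGrad()(x, grad_out, layer)      # (batch, out_dim, in_dim + 1)

for i in range(x.size(0)):
    layer.zero_grad()
    layer(x[i:i + 1]).sum().backward()
    ref = torch.cat([layer.weight.grad, layer.bias.grad.unsqueeze(1)], dim=1)
    assert torch.allclose(per_sample[i], ref, atol=1e-6)
print("per-sample gradients match")
```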